SOLR-11250: A new DefaultWrapperModel class for loading of large and/or externally stored LTRScoringModel definitions. (Yuki Yano, shalin, Christine Poerschke)

This commit is contained in:
Christine Poerschke 2017-11-28 14:55:57 +00:00
parent a06e685642
commit 64b3a5bb4b
11 changed files with 878 additions and 7 deletions

View File

@ -99,6 +99,9 @@ New Features
* SOLR-11202: Implement a set-property command for AutoScaling API. (ab, shalin)
* SOLR-11250: A new DefaultWrapperModel class for loading of large and/or externally stored
LTRScoringModel definitions. (Yuki Yano, shalin, Christine Poerschke)
Bug Fixes
----------------------

View File

@ -25,6 +25,11 @@
<import file="../contrib-build.xml"/>
<path id="test.classpath">
<path refid="solr.test.base.classpath"/>
<fileset dir="${test.lib.dir}" includes="*.jar"/>
</path>
<target name="compile-core" depends=" solr-contrib-build.compile-core"/>
</project>

View File

@ -24,9 +24,10 @@
</configurations>
<dependencies>
<dependency org="org.slf4j" name="jcl-over-slf4j" rev="${/org.slf4j/jcl-over-slf4j}" conf="test"/>
<dependency org="org.mockito" name="mockito-core" rev="${/org.mockito/mockito-core}" conf="test"/>
<dependency org="net.bytebuddy" name="byte-buddy" rev="${/net.bytebuddy/byte-buddy}" conf="test"/>
<dependency org="org.objenesis" name="objenesis" rev="${/org.objenesis/objenesis}" conf="test"/>
<exclude org="*" ext="*" matcher="regexp" type="${ivy.exclude.types}"/>
</dependencies>
</ivy-module>

View File

@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.noggit.JSONParser;
import org.noggit.ObjectBuilder;
/**
* A scoring model that fetches the wrapped model from {@link SolrResourceLoader}.
*
* <p>This model uses {@link SolrResourceLoader#openResource(String)} for fetching the wrapped model.
* If a relative path is given for {@code params/resource}, this model will try to load the wrapped model from
* the instance directory (i.e. ${solr.solr.home}). Otherwise, the classpath is searched.
*
* <p>Example configuration:
* <pre>{
"class": "org.apache.solr.ltr.model.DefaultWrapperModel",
"name": "myWrapperModelName",
"params": {
"resource": "models/myModel.json"
}
}</pre>
*
* @see SolrResourceLoader#openResource(String)
*/
public class DefaultWrapperModel extends WrapperModel {

  /** Name of the resource holding the wrapped model's JSON definition; assigned through {@link #setResource(String)}. */
  private String resource;

  /**
   * Creates the wrapper; all arguments are passed unchanged to {@link WrapperModel}.
   */
  public DefaultWrapperModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
      List<Feature> allFeatures, Map<String, Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);
  }

  /**
   * Sets the name of the resource from which the wrapped model is read.
   *
   * @param resource resource name, later handed to {@link SolrResourceLoader#openResource(String)}
   */
  public void setResource(String resource) {
    this.resource = resource;
  }

  /**
   * Runs the {@link WrapperModel} validation and additionally requires a resource to be configured.
   *
   * @throws ModelException if the parent validation fails or no resource has been set
   */
  @Override
  protected void validate() throws ModelException {
    super.validate();
    if (resource == null) {
      throw new ModelException("no resource configured for model "+name);
    }
  }

  /**
   * Opens the configured resource and parses it into the wrapped model's definition map.
   *
   * @return the parsed JSON object
   * @throws ModelException if the resource cannot be opened, read, or is not a JSON object
   */
  @Override
  public Map<String, Object> fetchModelMap() throws ModelException {
    try (InputStream in = openInputStream()) {
      return parseInputStream(in);
    } catch (IOException e) {
      throw new ModelException("Failed to fetch the wrapper model from given resource (" + resource + ")", e);
    }
  }

  /**
   * Opens the configured resource through the {@link SolrResourceLoader} held by the parent class.
   *
   * @throws IOException if the resource loader cannot open the resource
   */
  protected InputStream openInputStream() throws IOException {
    return solrResourceLoader.openResource(resource);
  }

  /**
   * Decodes the stream as UTF-8 JSON and returns its root object.
   *
   * @param in stream to read; it is closed when this method returns
   * @return the root JSON object as a map
   * @throws IOException if reading fails or the JSON root is not an object
   */
  @SuppressWarnings("unchecked")
  protected Map<String, Object> parseInputStream(InputStream in) throws IOException {
    try (Reader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
      final Object parsed = new ObjectBuilder(new JSONParser(reader)).getVal();
      // Reject arrays, scalars and null up front so the caller sees a ModelException
      // (via the IOException) instead of a ClassCastException or a null model map.
      if (!(parsed instanceof Map)) {
        throw new IOException("Expected a JSON object at the root of resource (" + resource + ")");
      }
      return (Map<String, Object>) parsed;
    }
  }

  /**
   * Describes this wrapper by simple class name, name, resource and wrapped model.
   * Safe to call before a wrapped model has been set; the model is then rendered as {@code null}.
   */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
    sb.append("(name=").append(getName());
    sb.append(",resource=").append(resource);
    // append(Object) goes through String.valueOf, so a not-yet-initialised model no longer throws NPE
    sb.append(",model=(").append(model).append(")");
    return sb.toString();
  }
}

View File

@ -0,0 +1,169 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
/**
* A scoring model that wraps the other model.
*
* <p>This model loads a model from an external resource during the initialization.
* How the wrapped model is fetched depends on
* the implementation of {@link WrapperModel#fetchModelMap()}.
*
* <p>This model doesn't hold the actual parameters of the wrapped model,
* thus it can manage large models which are difficult to upload to ZooKeeper.
*
* <p>Example configuration:
* <pre>{
"class": "...",
"name": "myModelName",
"params": {
...
}
}</pre>
*
* <p>NOTE: no "features" are configured in the wrapper model
* because the wrapped model's features will be used instead.
* Also note that if a "store" is configured for the wrapper
* model then it must match the "store" of the wrapped model.
*/
public abstract class WrapperModel extends LTRScoringModel {

  /** Resource loader made available to subclasses; assigned through {@link #setSolrResourceLoader(SolrResourceLoader)}. */
  protected SolrResourceLoader solrResourceLoader;

  /** The wrapped model; stays {@code null} until {@link #updateModel(LTRScoringModel)} is called. */
  protected LTRScoringModel model;

  /**
   * Creates the wrapper; all arguments are passed unchanged to {@link LTRScoringModel}.
   */
  public WrapperModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
      List<Feature> allFeatures, Map<String, Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);
  }

  /** Combines the parent hash with the hashes of the wrapped model and the resource loader. */
  @Override
  public int hashCode() {
    final int prime = 31;
    int result = super.hashCode();
    result = prime * result + ((model == null) ? 0 : model.hashCode());
    result = prime * result + ((solrResourceLoader == null) ? 0 : solrResourceLoader.hashCode());
    return result;
  }

  /**
   * Two wrappers are equal when the parent considers them equal, they have the same runtime class,
   * and both their wrapped models and their resource loaders are equal (or both {@code null}).
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) return true;
    if (!super.equals(obj)) return false;
    if (getClass() != obj.getClass()) return false;
    WrapperModel other = (WrapperModel) obj;
    if (model == null) {
      if (other.model != null) return false;
    } else if (!model.equals(other.model)) return false;
    if (solrResourceLoader == null) {
      if (other.solrResourceLoader != null) return false;
    } else if (!solrResourceLoader.equals(other.solrResourceLoader)) return false;
    return true;
  }

  /**
   * Checks that the wrapper declares neither features nor norms of its own and, once a wrapped model
   * is present, that the wrapped model is valid and uses the same feature store as the wrapper.
   *
   * @throws ModelException if any of these checks fails
   */
  @Override
  protected void validate() throws ModelException {
    // Read the fields directly: the getters below delegate to the wrapped model.
    if (!features.isEmpty()) {
      throw new ModelException("features must be empty for the wrapper model " + name);
    }
    if (!norms.isEmpty()) {
      throw new ModelException("norms must be empty for the wrapper model " + name);
    }

    // The remaining checks need the wrapped model, which is absent until updateModel() is called.
    // NOTE(review): super.validate() presumably goes through getFeatures()/getNorms(), which would
    // dereference a null model - confirm before moving it out of this guard.
    if (model != null) {
      super.validate();
      model.validate();
      // check feature store names match
      final String wrappedFeatureStoreName = model.getFeatureStoreName();
      if (wrappedFeatureStoreName == null || !wrappedFeatureStoreName.equals(this.getFeatureStoreName())) {
        throw new ModelException("wrapper feature store name ("+this.getFeatureStoreName() +")"
            + " must match the "
            + "wrapped feature store name ("+wrappedFeatureStoreName+")");
      }
    }
  }

  /** Stores the resource loader for use by subclasses. */
  public void setSolrResourceLoader(SolrResourceLoader solrResourceLoader) {
    this.solrResourceLoader = solrResourceLoader;
  }

  /**
   * Installs the wrapped model and re-runs {@link #validate()} against it.
   *
   * @param model the model to delegate to
   */
  public void updateModel(LTRScoringModel model) {
    this.model = model;
    validate();
  }

  /*
   * The child classes must implement how to fetch the definition of the wrapped model.
   */
  public abstract Map<String, Object> fetchModelMap() throws ModelException;

  // ---- everything below delegates to the wrapped model and therefore requires updateModel() first ----

  /** Returns the wrapped model's norms. */
  @Override
  public List<Normalizer> getNorms() {
    return model.getNorms();
  }

  /** Returns the wrapped model's features. */
  @Override
  public List<Feature> getFeatures() {
    return model.getFeatures();
  }

  /** Returns the wrapped model's complete feature collection. */
  @Override
  public Collection<Feature> getAllFeatures() {
    return model.getAllFeatures();
  }

  /** Scores with the wrapped model. */
  @Override
  public float score(float[] modelFeatureValuesNormalized) {
    return model.score(modelFeatureValuesNormalized);
  }

  /** Returns the wrapped model's explanation. */
  @Override
  public Explanation explain(LeafReaderContext context, int doc, float finalScore,
      List<Explanation> featureExplanations) {
    return model.explain(context, doc, finalScore, featureExplanations);
  }

  /** Normalizes the given values with the wrapped model. */
  @Override
  public void normalizeFeaturesInPlace(float[] modelFeatureValues) {
    model.normalizeFeaturesInPlace(modelFeatureValues);
  }

  /** Returns the wrapped model's normalizer explanation. */
  @Override
  public Explanation getNormalizerExplanation(Explanation e, int idx) {
    return model.getNormalizerExplanation(e, idx);
  }

  /**
   * Describes this wrapper by simple class name, name and wrapped model.
   * Safe to call before a wrapped model has been set; the model is then rendered as {@code null}.
   */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
    sb.append("(name=").append(getName());
    // append(Object) goes through String.valueOf, so a not-yet-initialised model no longer throws NPE
    sb.append(",model=(").append(model).append(")");
    return sb.toString();
  }
}

View File

@ -27,6 +27,7 @@ import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.WrapperModel;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
@ -231,7 +232,7 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
}
}
return LTRScoringModel.getInstance(solrResourceLoader,
final LTRScoringModel ltrScoringModel = LTRScoringModel.getInstance(solrResourceLoader,
(String) modelMap.get(CLASS_KEY), // modelClassName
(String) modelMap.get(NAME_KEY), // modelName
features,
@ -239,6 +240,28 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
featureStore.getName(),
featureStore.getFeatures(),
(Map<String,Object>) modelMap.get(PARAMS_KEY));
if (ltrScoringModel instanceof WrapperModel) {
initWrapperModel(solrResourceLoader, (WrapperModel)ltrScoringModel, managedFeatureStore);
}
return ltrScoringModel;
}
/**
 * Completes the set-up of a freshly created wrapper model: hands it the resource loader,
 * asks it for the wrapped model's definition, builds that model and installs it.
 *
 * @param solrResourceLoader loader given to the wrapper and used to build the wrapped model
 * @param wrapperModel the wrapper to initialise
 * @param managedFeatureStore feature store used when building the wrapped model
 */
private static void initWrapperModel(SolrResourceLoader solrResourceLoader,
    WrapperModel wrapperModel, ManagedFeatureStore managedFeatureStore) {
  wrapperModel.setSolrResourceLoader(solrResourceLoader);
  final LTRScoringModel model = fromLTRScoringModelMap(
      solrResourceLoader,
      wrapperModel.fetchModelMap(),
      managedFeatureStore);
  if (model instanceof WrapperModel) {
    // The factory call above already runs initWrapperModel on any WrapperModel it creates,
    // so the nested wrapper is fully initialised here. Initialising it again would only
    // re-fetch and re-validate the same resource, so we just warn.
    log.warn("It is unusual for one WrapperModel ({}) to wrap another WrapperModel ({})",
        wrapperModel.getName(),
        model.getName());
  }
  wrapperModel.updateModel(model);
}
private static LinkedHashMap<String,Object> toLTRScoringModelMap(LTRScoringModel model) {
@ -249,10 +272,12 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
modelMap.put(STORE_KEY, model.getFeatureStoreName());
final List<Map<String,Object>> features = new ArrayList<>();
final List<Feature> featuresList = model.getFeatures();
final List<Normalizer> normsList = model.getNorms();
for (int ii=0; ii<featuresList.size(); ++ii) {
features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii)));
if (!(model instanceof WrapperModel)) {
final List<Feature> featuresList = model.getFeatures();
final List<Normalizer> normsList = model.getNorms();
for (int ii = 0; ii < featuresList.size(); ++ii) {
features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii)));
}
}
modelMap.put(FEATURES_KEY, features);

View File

@ -16,6 +16,9 @@
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.RAMDirectoryFactory}" />
<!-- for use with the DefaultWrapperModel class -->
<lib dir="${solr.solr.home:.}/models" />
<schemaFactory class="ClassicIndexSchemaFactory" />
<requestDispatcher>

View File

@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FieldValueFeature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.store.rest.ManagedModelStore;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestDefaultWrapperModel extends TestRerankBase {

  final private static String featureStoreName = "test";
  private static String baseModelJson = null;
  private static File baseModelFile = null;

  static List<Feature> features = null;

  /**
   * Indexes five documents, registers the "popularity" and "const" features and writes a
   * linear base model to a file in the test configuration directory.
   */
  @BeforeClass
  public static void setupBeforeClass() throws Exception {
    setuptest(false);
    assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", "1"));
    assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", "2"));
    assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", "3"));
    assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", "4"));
    assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", "5"));
    assertU(commit());

    loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), "test", "{\"field\":\"popularity\"}");
    loadFeature("const", ValueFeature.class.getCanonicalName(), "test", "{\"value\":5}");
    features = new ArrayList<>();
    features.add(getManagedFeatureStore().getFeatureStore("test").get("popularity"));
    features.add(getManagedFeatureStore().getFeatureStore("test").get("const"));

    baseModelJson = getModelInJson("linear", LinearModel.class.getCanonicalName(),
        new String[] {"popularity", "const"},
        featureStoreName,
        "{\"weights\":{\"popularity\":-1.0, \"const\":1.0}}");

    // prepare the base model as a file resource
    baseModelFile = writeModelFile("baseModel.json", baseModelJson);
  }

  /**
   * Writes a model definition as UTF-8 into the test configuration directory and registers
   * the file for deletion at JVM exit.
   */
  private static File writeModelFile(String fileName, String modelJson) throws Exception {
    final File modelFile = new File(tmpConfDir, fileName);
    try (BufferedWriter writer = new BufferedWriter(
        new OutputStreamWriter(new FileOutputStream(modelFile), StandardCharsets.UTF_8))) {
      writer.write(modelJson);
    }
    modelFile.deleteOnExit();
    return modelFile;
  }

  /** Builds the JSON for a {@link DefaultWrapperModel} in this test's feature store. */
  private static String getDefaultWrapperModelInJson(String wrapperModelName, String[] features, String params) {
    return getModelInJson(wrapperModelName, DefaultWrapperModel.class.getCanonicalName(),
        features, featureStoreName, params);
  }

  /** Builds the {@code params} JSON that points a wrapper model at the given file. */
  private static String resourceParams(File modelFile) {
    return "{\"resource\":\"" + modelFile.getName() + "\"}";
  }

  /**
   * Re-ranks the top three documents with the named model and checks the order and scores
   * that the base linear model (weights -1.0 for popularity, 1.0 for const) is expected to give.
   */
  private static void assertRerankedByBaseModel(String modelName) throws Exception {
    final SolrQuery query = new SolrQuery();
    query.setQuery("{!func}pow(popularity,2)");
    query.add("rows", "3");
    query.add("fl", "*,score");
    query.add("rq", "{!ltr reRankDocs=3 model=" + modelName + "}");
    assertJQ("/query" + query.toQueryString(),
        "/response/docs/[0]/id==\"3\"", "/response/docs/[0]/score==2.0",
        "/response/docs/[1]/id==\"4\"", "/response/docs/[1]/score==1.0",
        "/response/docs/[2]/id==\"5\"", "/response/docs/[2]/score==0.0");
  }

  @Test
  public void testLoadModelFromResource() throws Exception {
    String wrapperModelJson = getDefaultWrapperModelInJson("fileWrapper",
        new String[0],
        resourceParams(baseModelFile));
    assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0");

    assertRerankedByBaseModel("fileWrapper");
  }

  @Test
  public void testLoadNestedWrapperModel() throws Exception {
    String otherWrapperModelJson = getDefaultWrapperModelInJson("otherNestedWrapper",
        new String[0],
        resourceParams(baseModelFile));
    // written through the shared helper so this file is cleaned up like the base model file
    File otherWrapperModelFile = writeModelFile("nestedWrapperModel.json", otherWrapperModelJson);

    String wrapperModelJson = getDefaultWrapperModelInJson("nestedWrapper",
        new String[0],
        resourceParams(otherWrapperModelFile));
    assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0");

    assertRerankedByBaseModel("nestedWrapper");
  }

  @Test
  public void testLoadModelFromUnknownResource() throws Exception {
    String wrapperModelJson = getDefaultWrapperModelInJson("unknownWrapper",
        new String[0],
        "{\"resource\":\"unknownModel.json\"}");
    assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson,
        "/responseHeader/status==400",
        "/error/msg==\"org.apache.solr.ltr.model.ModelException: "
            + "Failed to fetch the wrapper model from given resource (unknownModel.json)\"");
  }

  @Test
  public void testLoadModelWithEmptyParams() throws Exception {
    String wrapperModelJson = getDefaultWrapperModelInJson("invalidWrapper",
        new String[0],
        "{}");
    assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson,
        "/responseHeader/status==400",
        "/error/msg==\"org.apache.solr.ltr.model.ModelException: "
            + "no resource configured for model invalidWrapper\"");
  }
}

View File

@ -0,0 +1,290 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.store.FeatureStore;
import org.junit.Test;
import org.mockito.Mockito;
/**
 * Unit tests for {@link WrapperModel}: its validation rules and the split between methods it
 * answers itself and methods it delegates to the wrapped model.
 */
public class TestWrapperModel extends TestRerankBase {
/**
 * Minimal concrete {@link WrapperModel} whose {@link #fetchModelMap()} returns {@code null};
 * the wrapped model is supplied by the tests through {@code updateModel}.
 */
private static class StubWrapperModel extends WrapperModel {
// wrapper with no features and no norms
private StubWrapperModel(String name) {
this(name, Collections.emptyList(), Collections.emptyList());
}
// the given features are used both as the model's features and as "all features"
private StubWrapperModel(String name, List<Feature> features, List<Normalizer> norms) {
super(name, features, norms, FeatureStore.DEFAULT_FEATURE_STORE_NAME, features, Collections.emptyMap());
}
@Override
public Map<String, Object> fetchModelMap() throws ModelException {
return null;
}
}
/**
 * Creates a Mockito mock of {@link LTRScoringModel} that reports the given feature store name,
 * features and norms; every other method keeps Mockito's default behaviour.
 */
private static LTRScoringModel createMockWrappedModel(String featureStoreName,
List<Feature> features, List<Normalizer> norms) {
LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class);
Mockito.doReturn(featureStoreName).when(wrappedModel).getFeatureStoreName();
Mockito.doReturn(features).when(wrappedModel).getFeatures();
Mockito.doReturn(norms).when(wrappedModel).getNorms();
return wrappedModel;
}
/**
 * Exercises {@code validate()} first without a wrapped model (empty wrapper, wrapper with
 * features, wrapper with norms) and then via {@code updateModel} with valid and invalid mocks.
 */
@Test
public void testValidate() throws Exception {
WrapperModel wrapperModel = new StubWrapperModel("testModel");
try {
wrapperModel.validate();
} catch (ModelException e) {
fail("Validation must succeed if no wrapped model is set");
}
// wrapper model with features
WrapperModel wrapperModelWithFeatures = new StubWrapperModel("testModel",
Collections.singletonList(new ValueFeature("val", Collections.emptyMap())), Collections.emptyList());
try {
wrapperModelWithFeatures.validate();
fail("Validation must fail if features of the wrapper model isn't empty");
} catch (ModelException e) {
assertEquals("features must be empty for the wrapper model testModel", e.getMessage());
}
// wrapper model with norms
WrapperModel wrapperModelWithNorms = new StubWrapperModel("testModel",
Collections.emptyList(), Collections.singletonList(IdentityNormalizer.INSTANCE));
try {
wrapperModelWithNorms.validate();
fail("Validation must fail if norms of the wrapper model isn't empty");
} catch (ModelException e) {
assertEquals("norms must be empty for the wrapper model testModel", e.getMessage());
}
// everything below relies on Mockito mocks of the wrapped model
assumeWorkingMockito();
// update valid model
{
LTRScoringModel wrappedModel =
createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
Arrays.asList(
new ValueFeature("v1", Collections.emptyMap()),
new ValueFeature("v2", Collections.emptyMap())),
Arrays.asList(
IdentityNormalizer.INSTANCE,
IdentityNormalizer.INSTANCE)
);
try {
wrapperModel.updateModel(wrappedModel);
} catch (ModelException e) {
fail("Validation must succeed if the wrapped model is valid");
}
}
// update invalid model (feature store mismatch)
{
LTRScoringModel wrappedModel =
createMockWrappedModel("wrappedFeatureStore",
Arrays.asList(
new ValueFeature("v1", Collections.emptyMap()),
new ValueFeature("v2", Collections.emptyMap())),
Arrays.asList(
IdentityNormalizer.INSTANCE,
IdentityNormalizer.INSTANCE)
);
try {
wrapperModel.updateModel(wrappedModel);
fail("Validation must fail if wrapped model feature store differs from wrapper model feature store");
} catch (ModelException e) {
assertEquals("wrapper feature store name (_DEFAULT_) must match the wrapped feature store name (wrappedFeatureStore)", e.getMessage());
}
}
// update invalid model (no features)
{
LTRScoringModel wrappedModel =
createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
Collections.emptyList(),
Arrays.asList(
IdentityNormalizer.INSTANCE,
IdentityNormalizer.INSTANCE)
);
try {
wrapperModel.updateModel(wrappedModel);
fail("Validation must fail if the wrapped model is invalid");
} catch (ModelException e) {
assertEquals("no features declared for model testModel", e.getMessage());
}
}
// update invalid model (no norms)
{
LTRScoringModel wrappedModel =
createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME,
Arrays.asList(
new ValueFeature("v1", Collections.emptyMap()),
new ValueFeature("v2", Collections.emptyMap())),
Collections.emptyList()
);
try {
wrapperModel.updateModel(wrappedModel);
fail("Validation must fail if the wrapped model is invalid");
} catch (ModelException e) {
assertEquals("counted 2 features and 0 norms in model testModel", e.getMessage());
}
}
}
/**
 * Cross-checks the two helpers below: the number of overridable {@link LTRScoringModel} methods
 * found by reflection must equal the number of methods accounted for in the delegation test,
 * so a method added to the superclass fails this test until both helpers are updated.
 */
@Test
public void testMethodOverridesAndDelegation() throws Exception {
assumeWorkingMockito();
final int overridableMethodCount = testOverwrittenMethods();
final int methodCount = testDelegateMethods();
assertEquals("method count mismatch", overridableMethodCount, methodCount);
}
/**
 * Walks the methods declared by {@link LTRScoringModel}, skipping final and static ones, and
 * checks that {@link WrapperModel} declares an override with the same return type for each of
 * them, except for the three listed methods, which it must not override.
 *
 * @return the number of non-final, non-static methods declared by {@link LTRScoringModel}
 */
private int testOverwrittenMethods() throws Exception {
int overridableMethodCount = 0;
for (final Method superClassMethod : LTRScoringModel.class.getDeclaredMethods()) {
final int modifiers = superClassMethod.getModifiers();
if (Modifier.isFinal(modifiers)) continue;
if (Modifier.isStatic(modifiers)) continue;
++overridableMethodCount;
if (Arrays.asList(
"getName", // the wrapper model's name is its own name i.e. _not_ the name of the wrapped model
"getFeatureStoreName", // wrapper and wrapped model feature store should match, so need not override
"getParams" // the wrapper model's params are its own params i.e. _not_ the params of the wrapped model
).contains(superClassMethod.getName())) {
// these must NOT be declared by WrapperModel: finding a declaration is the failure case
try {
final Method subClassMethod = WrapperModel.class.getDeclaredMethod(
superClassMethod.getName(),
superClassMethod.getParameterTypes());
fail(WrapperModel.class + " need not override\n'" + superClassMethod + "'"
+ " but it does override\n'" + subClassMethod + "'");
} catch (NoSuchMethodException e) {
// ok
}
} else {
// every other method must be declared by WrapperModel with an identical return type
try {
final Method subClassMethod = WrapperModel.class.getDeclaredMethod(
superClassMethod.getName(),
superClassMethod.getParameterTypes());
assertEquals("getReturnType() difference",
superClassMethod.getReturnType(),
subClassMethod.getReturnType());
} catch (NoSuchMethodException e) {
fail(WrapperModel.class + " needs to override '" + superClassMethod + "'");
}
}
}
return overridableMethodCount;
}
/**
 * Calls each accounted-for method on a spied wrapper and verifies on the mocked wrapped model
 * whether the call was forwarded (exactly once) or not at all. The mock is reset before every
 * check so that earlier interactions do not leak into later verifications.
 *
 * @return the number of methods accounted for, including those counted without verification
 */
private int testDelegateMethods() throws Exception {
int methodCount = 0;
WrapperModel wrapperModel = Mockito.spy(new StubWrapperModel("testModel"));
// ignore validate in this test case
Mockito.doNothing().when(wrapperModel).validate();
++methodCount;
LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class);
wrapperModel.updateModel(wrappedModel);
// cannot be stubbed or verified
++methodCount; // toString
++methodCount; // hashCode
++methodCount; // equals
// getFeatureStoreName : not delegate
Mockito.reset(wrappedModel);
wrapperModel.getFeatureStoreName();
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(0)).getFeatureStoreName();
// getName : not delegate
Mockito.reset(wrappedModel);
wrapperModel.getName();
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(0)).getName();
// getParams : not delegate
Mockito.reset(wrappedModel);
wrapperModel.getParams();
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(0)).getParams();
// getNorms : delegate
Mockito.reset(wrappedModel);
wrapperModel.getNorms();
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).getNorms();
// getFeatures : delegate
Mockito.reset(wrappedModel);
wrapperModel.getFeatures();
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).getFeatures();
// getAllFeatures : delegate
Mockito.reset(wrappedModel);
wrapperModel.getAllFeatures();
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).getAllFeatures();
// score : delegate
Mockito.reset(wrappedModel);
wrapperModel.score(null);
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).score(null);
// normalizeFeaturesInPlace : delegate
Mockito.reset(wrappedModel);
wrapperModel.normalizeFeaturesInPlace(null);
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).normalizeFeaturesInPlace(null);
// getNormalizerExplanation : delegate
Mockito.reset(wrappedModel);
wrapperModel.getNormalizerExplanation(null, 0);
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).getNormalizerExplanation(null, 0);
// explain : delegate
Mockito.reset(wrappedModel);
wrapperModel.explain(null, 0, 0.0f, null);
++methodCount;
Mockito.verify(wrappedModel, Mockito.times(1)).explain(null, 0, 0.0f, null);
return methodCount;
}
}

View File

@ -16,12 +16,19 @@
*/
package org.apache.solr.ltr.store.rest;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.feature.FieldValueFeature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.DefaultWrapperModel;
import org.apache.solr.ltr.model.LinearModel;
import org.apache.solr.ltr.store.FeatureStore;
import org.junit.BeforeClass;
@ -185,4 +192,73 @@ public class TestModelManagerPersistence extends TestRerankBase {
assertJQ(ManagedFeatureStore.REST_END_POINT + "/" + FeatureStore.DEFAULT_FEATURE_STORE_NAME,
"/features/==[]");
}
/**
 * Asserts that the model store lists exactly the wrapper model at index 0, with the expected
 * class, feature store, empty feature list and only the {@code resource} parameter.
 *
 * @param modelName name expected for the registered wrapper model
 * @param featureStoreName feature store expected for the wrapper model
 * @param baseModelFileName value expected for the wrapper's {@code resource} parameter
 */
private static void doWrapperModelPersistenceChecks(String modelName,
    String featureStoreName, String baseModelFileName) throws Exception {
  // The caller gives the wrapper and the wrapped model the same name, so a second entry
  // carrying that name would mean the wrapped model had been registered as well.
  final String wrappedModelNotRegistered = "!/models/[1]/name=='" + modelName + "'";

  // the wrapper model itself should be registered
  final String wrapperName = "/models/[0]/name=='" + modelName + "'";
  final String wrapperClass = "/models/[0]/class=='" + DefaultWrapperModel.class.getCanonicalName() + "'";
  final String wrapperStore = "/models/[0]/store=='" + featureStoreName + "'";

  // the wrapper must not carry the wrapped model's definition, only its own parameters
  final String wrapperFeatures = "/models/[0]/features/==[]";
  final String wrapperParams = "/models/[0]/params=={resource:'" + baseModelFileName + "'}";

  assertJQ(ManagedModelStore.REST_END_POINT,
      wrappedModelNotRegistered,
      wrapperName,
      wrapperClass,
      wrapperStore,
      wrapperFeatures,
      wrapperParams);
}
/**
 * Registers a wrapper model that points at a file-based linear model and verifies that the
 * wrapper is listed identically right after upload, after a core reload and after a Jetty
 * restart; finally removes the model and the feature store again.
 */
@Test
public void testWrapperModelPersistence() throws Exception {
  final String modelName = "linear";
  final String FS_NAME = "testWrapper";
  final String featureStoreEndPoint = ManagedFeatureStore.REST_END_POINT + "/" + FS_NAME;

  // precondition: neither models nor features exist yet
  assertJQ(ManagedModelStore.REST_END_POINT,
      "/models/==[]");
  assertJQ(featureStoreEndPoint,
      "/features/==[]");

  // features referenced by the base model
  loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), FS_NAME, "{\"field\":\"popularity\"}");
  loadFeature("const", ValueFeature.class.getCanonicalName(), FS_NAME, "{\"value\":5}");

  // base model, written to a file in the test configuration directory
  final String linearModelJson = getModelInJson(modelName, LinearModel.class.getCanonicalName(),
      new String[] {"popularity", "const"}, FS_NAME,
      "{\"weights\":{\"popularity\":-1.0, \"const\":1.0}}");
  final File linearModelFile = new File(tmpConfDir, "baseModelForPersistence.json");
  try (BufferedWriter out = new BufferedWriter(
      new OutputStreamWriter(new FileOutputStream(linearModelFile), StandardCharsets.UTF_8))) {
    out.write(linearModelJson);
  }
  linearModelFile.deleteOnExit();
  final String resourceName = linearModelFile.getName();

  // wrapper model referring to that file by name
  final String wrapperJson = getModelInJson(modelName, DefaultWrapperModel.class.getCanonicalName(),
      new String[0], FS_NAME,
      "{\"resource\":\"" + resourceName + "\"}");
  assertJPut(ManagedModelStore.REST_END_POINT, wrapperJson, "/responseHeader/status==0");
  doWrapperModelPersistenceChecks(modelName, FS_NAME, resourceName);

  // the same listing must survive a reload ...
  restTestHarness.reload();
  doWrapperModelPersistenceChecks(modelName, FS_NAME, resourceName);

  // ... and a restart
  jetty.stop();
  jetty.start();
  doWrapperModelPersistenceChecks(modelName, FS_NAME, resourceName);

  // clean up the test settings
  restTestHarness.delete(ManagedModelStore.REST_END_POINT + "/" + modelName);
  restTestHarness.delete(featureStoreEndPoint);
  assertJQ(ManagedModelStore.REST_END_POINT,
      "/models/==[]");
  assertJQ(featureStoreEndPoint,
      "/features/==[]");

  // NOTE: we don't test the persistence of the deletion here because it's tested in testFilePersistence
}
}

View File

@ -87,6 +87,7 @@ Feature selection and model training take place offline and outside Solr. The lt
|General form |Class |Specific examples
|Linear |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LinearModel.html[LinearModel] |RankSVM, Pranking
|Multiple Additive Trees |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html[MultipleAdditiveTreesModel] |LambdaMART, Gradient Boosted Regression Trees (GBRT)
|(wrapper) |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/DefaultWrapperModel.html[DefaultWrapperModel] |(not applicable)
|(custom) |(custom class extending {solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel]) |(not applicable)
|===
@ -509,6 +510,54 @@ To delete the `currentFeatureStore` feature store:
curl -XDELETE 'http://localhost:8983/solr/techproducts/schema/feature-store/currentFeatureStore'
----
==== Using large models
With SolrCloud, large models may fail to upload because of the size limit of ZooKeeper's buffer. In this case, `DefaultWrapperModel` can help by separating the model definition from the uploaded file.
Assume that you want to use a large model located at `/path/to/models/myModel.json` through `DefaultWrapperModel`.
[source,json]
----
{
"store" : "largeModelsFeatureStore",
"name" : "myModel",
"class" : ...,
"features" : [
...
],
"params" : {
...
}
}
----
First, add the directory to Solr's resource paths by using <<lib-directives-in-solrconfig.adoc#lib-directives-in-solrconfig,Lib Directives>>:
[source,xml]
----
<lib dir="/path/to" regex="models" />
----
Then, configure `DefaultWrapperModel` to wrap `myModel.json`:
[source,json]
----
{
"store" : "largeModelsFeatureStore",
"name" : "myWrapperModel",
"class" : "org.apache.solr.ltr.model.DefaultWrapperModel",
"params" : {
"resource" : "myModel.json"
}
}
----
`myModel.json` will be loaded during initialization and can then be used by specifying `model=myWrapperModel`.
NOTE: No `"features"` are configured in `myWrapperModel` because the features of the wrapped model (`myModel`) will be used; also note that the `"store"` configured for the wrapper model must match that of the wrapped model i.e. in this example the feature store called `largeModelsFeatureStore` is used.
CAUTION: `<lib dir="/path/to/models" regex=".*\.json" />` doesn't work as expected in this case, because `SolrResourceLoader` treats the given resources as JAR files when `<lib />` points at files.
=== Applying Changes
The feature store and the model store are both <<managed-resources.adoc#managed-resources,Managed Resources>>. Changes made to managed resources are not applied to the active Solr components until the Solr collection (or Solr core in single server mode) is reloaded.