From 64b3a5bb4b67d4bfee9fa022ce4acd71f3923c4c Mon Sep 17 00:00:00 2001 From: Christine Poerschke Date: Tue, 28 Nov 2017 14:55:57 +0000 Subject: [PATCH] SOLR-11250: A new DefaultWrapperModel class for loading of large and/or externally stored LTRScoringModel definitions. (Yuki Yano, shalin, Christine Poerschke) --- solr/CHANGES.txt | 3 + solr/contrib/ltr/build.xml | 5 + solr/contrib/ltr/ivy.xml | 5 +- .../solr/ltr/model/DefaultWrapperModel.java | 105 +++++++ .../apache/solr/ltr/model/WrapperModel.java | 169 ++++++++++ .../ltr/store/rest/ManagedModelStore.java | 35 ++- .../solr/collection1/conf/solrconfig-ltr.xml | 3 + .../ltr/model/TestDefaultWrapperModel.java | 145 +++++++++ .../solr/ltr/model/TestWrapperModel.java | 290 ++++++++++++++++++ .../rest/TestModelManagerPersistence.java | 76 +++++ solr/solr-ref-guide/src/learning-to-rank.adoc | 49 +++ 11 files changed, 878 insertions(+), 7 deletions(-) create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/DefaultWrapperModel.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/WrapperModel.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestDefaultWrapperModel.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestWrapperModel.java diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index f384bbec235..ac5b605f639 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -99,6 +99,9 @@ New Features * SOLR-11202: Implement a set-property command for AutoScaling API. (ab, shalin) +* SOLR-11250: A new DefaultWrapperModel class for loading of large and/or externally stored + LTRScoringModel definitions. 
(Yuki Yano, shalin, Christine Poerschke) + Bug Fixes ---------------------- diff --git a/solr/contrib/ltr/build.xml b/solr/contrib/ltr/build.xml index bbd5cf3d9b1..a5778c44147 100644 --- a/solr/contrib/ltr/build.xml +++ b/solr/contrib/ltr/build.xml @@ -25,6 +25,11 @@ + + + + + diff --git a/solr/contrib/ltr/ivy.xml b/solr/contrib/ltr/ivy.xml index 68e9797bb09..3b7e1c70b3c 100644 --- a/solr/contrib/ltr/ivy.xml +++ b/solr/contrib/ltr/ivy.xml @@ -24,9 +24,10 @@ - - + + + diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/DefaultWrapperModel.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/DefaultWrapperModel.java new file mode 100644 index 00000000000..b21b6c3a150 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/DefaultWrapperModel.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.ltr.model; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.norm.Normalizer; +import org.noggit.JSONParser; +import org.noggit.ObjectBuilder; + +/** + * A scoring model that fetches the wrapped model from {@link SolrResourceLoader}. + * + *

This model uses {@link SolrResourceLoader#openResource(String)} for fetching the wrapped model. + * If you give a relative path for {@code params/resource}, this model will try to load the wrapped model from + * the instance directory (i.e. ${solr.solr.home}). Otherwise, it searches the classpath. + * + *

Example configuration: + *

{
+  "class": "org.apache.solr.ltr.model.DefaultWrapperModel",
+  "name": "myWrapperModelName",
+  "params": {
+    "resource": "models/myModel.json"
+  }
+}
+ * + * @see SolrResourceLoader#openResource(String) + */ +public class DefaultWrapperModel extends WrapperModel { + + private String resource; + + public DefaultWrapperModel(String name, List features, List norms, String featureStoreName, + List allFeatures, Map params) { + super(name, features, norms, featureStoreName, allFeatures, params); + } + + public void setResource(String resource) { + this.resource = resource; + } + + @Override + protected void validate() throws ModelException { + super.validate(); + if (resource == null) { + throw new ModelException("no resource configured for model "+name); + } + } + + @Override + public Map fetchModelMap() throws ModelException { + Map modelMapObj; + try (InputStream in = openInputStream()) { + modelMapObj = parseInputStream(in); + } catch (IOException e) { + throw new ModelException("Failed to fetch the wrapper model from given resource (" + resource + ")", e); + } + return modelMapObj; + } + + protected InputStream openInputStream() throws IOException { + return solrResourceLoader.openResource(resource); + } + + @SuppressWarnings("unchecked") + protected Map parseInputStream(InputStream in) throws IOException { + try (Reader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) { + return (Map) new ObjectBuilder(new JSONParser(reader)).getVal(); + } + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(getClass().getSimpleName()); + sb.append("(name=").append(getName()); + sb.append(",resource=").append(resource); + sb.append(",model=(").append(model.toString()).append(")"); + + return sb.toString(); + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/WrapperModel.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/WrapperModel.java new file mode 100644 index 00000000000..cf66135ba98 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/WrapperModel.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.ltr.model; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.norm.Normalizer; + +/** + * A scoring model that wraps the other model. + * + *

This model loads a model from an external resource during initialization. + * The way of fetching the wrapped model depends on + * the implementation of {@link WrapperModel#fetchModelMap()}. + * + *

This model doesn't hold the actual parameters of the wrapped model, + * thus it can manage large models which are difficult to upload to ZooKeeper. + * + *

Example configuration: + *

{
+    "class": "...",
+    "name": "myModelName",
+    "params": {
+        ...
+    }
+ }
+ * + *

NOTE: no "features" are configured in the wrapper model + * because the wrapped model's features will be used instead. + * Also note that if a "store" is configured for the wrapper + * model then it must match the "store" of the wrapped model. + */ +public abstract class WrapperModel extends LTRScoringModel { + + protected SolrResourceLoader solrResourceLoader; + protected LTRScoringModel model; + + @Override + public int hashCode() { + final int prime = 31; + int result = super.hashCode(); + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((solrResourceLoader == null) ? 0 : solrResourceLoader.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!super.equals(obj)) return false; + if (getClass() != obj.getClass()) return false; + WrapperModel other = (WrapperModel) obj; + if (model == null) { + if (other.model != null) return false; + } else if (!model.equals(other.model)) return false; + if (solrResourceLoader == null) { + if (other.solrResourceLoader != null) return false; + } else if (!solrResourceLoader.equals(other.solrResourceLoader)) return false; + return true; + } + + public WrapperModel(String name, List features, List norms, String featureStoreName, + List allFeatures, Map params) { + super(name, features, norms, featureStoreName, allFeatures, params); + } + + @Override + protected void validate() throws ModelException { + if (!features.isEmpty()) { + throw new ModelException("features must be empty for the wrapper model " + name); + } + if (!norms.isEmpty()) { + throw new ModelException("norms must be empty for the wrapper model " + name); + } + + if (model != null) { + super.validate(); + model.validate(); + // check feature store names match + final String wrappedFeatureStoreName = model.getFeatureStoreName(); + if (wrappedFeatureStoreName == null || !wrappedFeatureStoreName.equals(this.getFeatureStoreName())) { + throw new 
ModelException("wrapper feature store name ("+this.getFeatureStoreName() +")" + + " must match the " + + "wrapped feature store name ("+wrappedFeatureStoreName+")"); + } + } + } + + public void setSolrResourceLoader(SolrResourceLoader solrResourceLoader) { + this.solrResourceLoader = solrResourceLoader; + } + + public void updateModel(LTRScoringModel model) { + this.model = model; + validate(); + } + + /* + * The child classes must implement how to fetch the definition of the wrapped model. + */ + public abstract Map fetchModelMap() throws ModelException; + + @Override + public List getNorms() { + return model.getNorms(); + } + + @Override + public List getFeatures() { + return model.getFeatures(); + } + + @Override + public Collection getAllFeatures() { + return model.getAllFeatures(); + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + return model.score(modelFeatureValuesNormalized); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc, float finalScore, + List featureExplanations) { + return model.explain(context, doc, finalScore, featureExplanations); + } + + @Override + public void normalizeFeaturesInPlace(float[] modelFeatureValues) { + model.normalizeFeaturesInPlace(modelFeatureValues); + } + + @Override + public Explanation getNormalizerExplanation(Explanation e, int idx) { + return model.getNormalizerExplanation(e, idx); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(getClass().getSimpleName()); + sb.append("(name=").append(getName()); + sb.append(",model=(").append(model.toString()).append(")"); + + return sb.toString(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java index 79640c13617..342a14067c6 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java +++ 
b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java @@ -27,6 +27,7 @@ import org.apache.solr.common.util.NamedList; import org.apache.solr.core.SolrCore; import org.apache.solr.core.SolrResourceLoader; import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.WrapperModel; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.model.ModelException; import org.apache.solr.ltr.norm.IdentityNormalizer; @@ -231,7 +232,7 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc } } - return LTRScoringModel.getInstance(solrResourceLoader, + final LTRScoringModel ltrScoringModel = LTRScoringModel.getInstance(solrResourceLoader, (String) modelMap.get(CLASS_KEY), // modelClassName (String) modelMap.get(NAME_KEY), // modelName features, @@ -239,6 +240,28 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc featureStore.getName(), featureStore.getFeatures(), (Map) modelMap.get(PARAMS_KEY)); + + if (ltrScoringModel instanceof WrapperModel) { + initWrapperModel(solrResourceLoader, (WrapperModel)ltrScoringModel, managedFeatureStore); + } + + return ltrScoringModel; + } + + private static void initWrapperModel(SolrResourceLoader solrResourceLoader, + WrapperModel wrapperModel, ManagedFeatureStore managedFeatureStore) { + wrapperModel.setSolrResourceLoader(solrResourceLoader); + final LTRScoringModel model = fromLTRScoringModelMap( + solrResourceLoader, + wrapperModel.fetchModelMap(), + managedFeatureStore); + if (model instanceof WrapperModel) { + log.warn("It is unusual for one WrapperModel ({}) to wrap another WrapperModel ({})", + wrapperModel.getName(), + model.getName()); + initWrapperModel(solrResourceLoader, (WrapperModel)model, managedFeatureStore); + } + wrapperModel.updateModel(model); } private static LinkedHashMap toLTRScoringModelMap(LTRScoringModel model) { @@ -249,10 +272,12 @@ public class ManagedModelStore extends ManagedResource 
implements ManagedResourc modelMap.put(STORE_KEY, model.getFeatureStoreName()); final List> features = new ArrayList<>(); - final List featuresList = model.getFeatures(); - final List normsList = model.getNorms(); - for (int ii=0; ii featuresList = model.getFeatures(); + final List normsList = model.getNorms(); + for (int ii = 0; ii < featuresList.size(); ++ii) { + features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii))); + } } modelMap.put(FEATURES_KEY, features); diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml index 4f1f5ca6db2..c5519384057 100644 --- a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -16,6 +16,9 @@ + + + diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestDefaultWrapperModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestDefaultWrapperModel.java new file mode 100644 index 00000000000..a930c23f2b1 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestDefaultWrapperModel.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.ltr.model; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.FieldValueFeature; +import org.apache.solr.ltr.feature.ValueFeature; +import org.apache.solr.ltr.store.rest.ManagedModelStore; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestDefaultWrapperModel extends TestRerankBase { + + final private static String featureStoreName = "test"; + private static String baseModelJson = null; + private static File baseModelFile = null; + + static List features = null; + + @BeforeClass + public static void setupBeforeClass() throws Exception { + setuptest(false); + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", "1")); + assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", "5")); + assertU(commit()); + + loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), "test", "{\"field\":\"popularity\"}"); + loadFeature("const", ValueFeature.class.getCanonicalName(), "test", "{\"value\":5}"); + features = new ArrayList<>(); + features.add(getManagedFeatureStore().getFeatureStore("test").get("popularity")); + features.add(getManagedFeatureStore().getFeatureStore("test").get("const")); + + baseModelJson = getModelInJson("linear", LinearModel.class.getCanonicalName(), + new 
String[] {"popularity", "const"}, + featureStoreName, + "{\"weights\":{\"popularity\":-1.0, \"const\":1.0}}"); + // prepare the base model as a file resource + baseModelFile = new File(tmpConfDir, "baseModel.json"); + try (BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(new FileOutputStream(baseModelFile), StandardCharsets.UTF_8))) { + writer.write(baseModelJson); + } + baseModelFile.deleteOnExit(); + } + + private static String getDefaultWrapperModelInJson(String wrapperModelName, String[] features, String params) { + return getModelInJson(wrapperModelName, DefaultWrapperModel.class.getCanonicalName(), + features, featureStoreName, params); + } + + @Test + public void testLoadModelFromResource() throws Exception { + String wrapperModelJson = getDefaultWrapperModelInJson("fileWrapper", + new String[0], + "{\"resource\":\"" + baseModelFile.getName() + "\"}"); + assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("{!func}pow(popularity,2)"); + query.add("rows", "3"); + query.add("fl", "*,score"); + query.add("rq", "{!ltr reRankDocs=3 model=fileWrapper}"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/id==\"3\"", "/response/docs/[0]/score==2.0", + "/response/docs/[1]/id==\"4\"", "/response/docs/[1]/score==1.0", + "/response/docs/[2]/id==\"5\"", "/response/docs/[2]/score==0.0"); + } + + @Test + public void testLoadNestedWrapperModel() throws Exception { + String otherWrapperModelJson = getDefaultWrapperModelInJson("otherNestedWrapper", + new String[0], + "{\"resource\":\"" + baseModelFile.getName() + "\"}"); + File otherWrapperModelFile = new File(tmpConfDir, "nestedWrapperModel.json"); + try (BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(new FileOutputStream(otherWrapperModelFile), StandardCharsets.UTF_8))) { + writer.write(otherWrapperModelJson); + } + + String wrapperModelJson = 
getDefaultWrapperModelInJson("nestedWrapper", + new String[0], + "{\"resource\":\"" + otherWrapperModelFile.getName() + "\"}"); + assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0"); + final SolrQuery query = new SolrQuery(); + query.setQuery("{!func}pow(popularity,2)"); + query.add("rows", "3"); + query.add("fl", "*,score"); + query.add("rq", "{!ltr reRankDocs=3 model=nestedWrapper}"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/id==\"3\"", "/response/docs/[0]/score==2.0", + "/response/docs/[1]/id==\"4\"", "/response/docs/[1]/score==1.0", + "/response/docs/[2]/id==\"5\"", "/response/docs/[2]/score==0.0"); + } + + @Test + public void testLoadModelFromUnknownResource() throws Exception { + String wrapperModelJson = getDefaultWrapperModelInJson("unknownWrapper", + new String[0], + "{\"resource\":\"unknownModel.json\"}"); + assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, + "/responseHeader/status==400", + "/error/msg==\"org.apache.solr.ltr.model.ModelException: " + + "Failed to fetch the wrapper model from given resource (unknownModel.json)\""); + } + + @Test + public void testLoadModelWithEmptyParams() throws Exception { + String wrapperModelJson = getDefaultWrapperModelInJson("invalidWrapper", + new String[0], + "{}"); + assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, + "/responseHeader/status==400", + "/error/msg==\"org.apache.solr.ltr.model.ModelException: " + + "no resource configured for model invalidWrapper\""); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestWrapperModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestWrapperModel.java new file mode 100644 index 00000000000..78818ff5e67 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestWrapperModel.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.ltr.model; + +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.ValueFeature; +import org.apache.solr.ltr.norm.IdentityNormalizer; +import org.apache.solr.ltr.norm.Normalizer; +import org.apache.solr.ltr.store.FeatureStore; +import org.junit.Test; +import org.mockito.Mockito; + +public class TestWrapperModel extends TestRerankBase { + + private static class StubWrapperModel extends WrapperModel { + + private StubWrapperModel(String name) { + this(name, Collections.emptyList(), Collections.emptyList()); + } + + private StubWrapperModel(String name, List features, List norms) { + super(name, features, norms, FeatureStore.DEFAULT_FEATURE_STORE_NAME, features, Collections.emptyMap()); + } + + @Override + public Map fetchModelMap() throws ModelException { + return null; + } + } + + private static LTRScoringModel createMockWrappedModel(String featureStoreName, + List features, List norms) { + LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class); + 
Mockito.doReturn(featureStoreName).when(wrappedModel).getFeatureStoreName(); + Mockito.doReturn(features).when(wrappedModel).getFeatures(); + Mockito.doReturn(norms).when(wrappedModel).getNorms(); + return wrappedModel; + } + + @Test + public void testValidate() throws Exception { + WrapperModel wrapperModel = new StubWrapperModel("testModel"); + try { + wrapperModel.validate(); + } catch (ModelException e) { + fail("Validation must succeed if no wrapped model is set"); + } + + // wrapper model with features + WrapperModel wrapperModelWithFeatures = new StubWrapperModel("testModel", + Collections.singletonList(new ValueFeature("val", Collections.emptyMap())), Collections.emptyList()); + try { + wrapperModelWithFeatures.validate(); + fail("Validation must fail if features of the wrapper model isn't empty"); + } catch (ModelException e) { + assertEquals("features must be empty for the wrapper model testModel", e.getMessage()); + } + + // wrapper model with norms + WrapperModel wrapperModelWithNorms = new StubWrapperModel("testModel", + Collections.emptyList(), Collections.singletonList(IdentityNormalizer.INSTANCE)); + try { + wrapperModelWithNorms.validate(); + fail("Validation must fail if norms of the wrapper model isn't empty"); + } catch (ModelException e) { + assertEquals("norms must be empty for the wrapper model testModel", e.getMessage()); + } + + assumeWorkingMockito(); + + // update valid model + { + LTRScoringModel wrappedModel = + createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME, + Arrays.asList( + new ValueFeature("v1", Collections.emptyMap()), + new ValueFeature("v2", Collections.emptyMap())), + Arrays.asList( + IdentityNormalizer.INSTANCE, + IdentityNormalizer.INSTANCE) + ); + try { + wrapperModel.updateModel(wrappedModel); + } catch (ModelException e) { + fail("Validation must succeed if the wrapped model is valid"); + } + } + + // update invalid model (feature store mismatch) + { + LTRScoringModel wrappedModel = + 
createMockWrappedModel("wrappedFeatureStore", + Arrays.asList( + new ValueFeature("v1", Collections.emptyMap()), + new ValueFeature("v2", Collections.emptyMap())), + Arrays.asList( + IdentityNormalizer.INSTANCE, + IdentityNormalizer.INSTANCE) + ); + try { + wrapperModel.updateModel(wrappedModel); + fail("Validation must fail if wrapped model feature store differs from wrapper model feature store"); + } catch (ModelException e) { + assertEquals("wrapper feature store name (_DEFAULT_) must match the wrapped feature store name (wrappedFeatureStore)", e.getMessage()); + } + } + + // update invalid model (no features) + { + LTRScoringModel wrappedModel = + createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME, + Collections.emptyList(), + Arrays.asList( + IdentityNormalizer.INSTANCE, + IdentityNormalizer.INSTANCE) + ); + try { + wrapperModel.updateModel(wrappedModel); + fail("Validation must fail if the wrapped model is invalid"); + } catch (ModelException e) { + assertEquals("no features declared for model testModel", e.getMessage()); + } + } + + // update invalid model (no norms) + { + LTRScoringModel wrappedModel = + createMockWrappedModel(FeatureStore.DEFAULT_FEATURE_STORE_NAME, + Arrays.asList( + new ValueFeature("v1", Collections.emptyMap()), + new ValueFeature("v2", Collections.emptyMap())), + Collections.emptyList() + ); + try { + wrapperModel.updateModel(wrappedModel); + fail("Validation must fail if the wrapped model is invalid"); + } catch (ModelException e) { + assertEquals("counted 2 features and 0 norms in model testModel", e.getMessage()); + } + } + } + + @Test + public void testMethodOverridesAndDelegation() throws Exception { + assumeWorkingMockito(); + final int overridableMethodCount = testOverwrittenMethods(); + final int methodCount = testDelegateMethods(); + assertEquals("method count mismatch", overridableMethodCount, methodCount); + } + + private int testOverwrittenMethods() throws Exception { + int overridableMethodCount = 0; + for 
(final Method superClassMethod : LTRScoringModel.class.getDeclaredMethods()) { + final int modifiers = superClassMethod.getModifiers(); + if (Modifier.isFinal(modifiers)) continue; + if (Modifier.isStatic(modifiers)) continue; + + ++overridableMethodCount; + if (Arrays.asList( + "getName", // the wrapper model's name is its own name i.e. _not_ the name of the wrapped model + "getFeatureStoreName", // wrapper and wrapped model feature store should match, so need not override + "getParams" // the wrapper model's params are its own params i.e. _not_ the params of the wrapped model + ).contains(superClassMethod.getName())) { + try { + final Method subClassMethod = WrapperModel.class.getDeclaredMethod( + superClassMethod.getName(), + superClassMethod.getParameterTypes()); + fail(WrapperModel.class + " need not override\n'" + superClassMethod + "'" + + " but it does override\n'" + subClassMethod + "'"); + } catch (NoSuchMethodException e) { + // ok + } + } else { + try { + final Method subClassMethod = WrapperModel.class.getDeclaredMethod( + superClassMethod.getName(), + superClassMethod.getParameterTypes()); + assertEquals("getReturnType() difference", + superClassMethod.getReturnType(), + subClassMethod.getReturnType()); + } catch (NoSuchMethodException e) { + fail(WrapperModel.class + " needs to override '" + superClassMethod + "'"); + } + } + } + return overridableMethodCount; + } + + private int testDelegateMethods() throws Exception { + int methodCount = 0; + WrapperModel wrapperModel = Mockito.spy(new StubWrapperModel("testModel")); + + // ignore validate in this test case + Mockito.doNothing().when(wrapperModel).validate(); + ++methodCount; + + LTRScoringModel wrappedModel = Mockito.mock(LTRScoringModel.class); + wrapperModel.updateModel(wrappedModel); + + // cannot be stubbed or verified + ++methodCount; // toString + ++methodCount; // hashCode + ++methodCount; // equals + + // getFeatureStoreName : not delegate + Mockito.reset(wrappedModel); + 
wrapperModel.getFeatureStoreName(); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(0)).getFeatureStoreName(); + + // getName : not delegate + Mockito.reset(wrappedModel); + wrapperModel.getName(); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(0)).getName(); + + // getParams : not delegate + Mockito.reset(wrappedModel); + wrapperModel.getParams(); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(0)).getParams(); + + // getNorms : delegate + Mockito.reset(wrappedModel); + wrapperModel.getNorms(); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).getNorms(); + + // getFeatures : delegate + Mockito.reset(wrappedModel); + wrapperModel.getFeatures(); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).getFeatures(); + + // getAllFeatures : delegate + Mockito.reset(wrappedModel); + wrapperModel.getAllFeatures(); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).getAllFeatures(); + + // score : delegate + Mockito.reset(wrappedModel); + wrapperModel.score(null); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).score(null); + + // normalizeFeaturesInPlace : delegate + Mockito.reset(wrappedModel); + wrapperModel.normalizeFeaturesInPlace(null); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).normalizeFeaturesInPlace(null); + + // getNormalizerExplanation : delegate + Mockito.reset(wrappedModel); + wrapperModel.getNormalizerExplanation(null, 0); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).getNormalizerExplanation(null, 0); + + // explain : delegate + Mockito.reset(wrappedModel); + wrapperModel.explain(null, 0, 0.0f, null); + ++methodCount; + Mockito.verify(wrappedModel, Mockito.times(1)).explain(null, 0, 0.0f, null); + + return methodCount; + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java 
b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java index 9dc28e6fbc2..a056cf7e340 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java @@ -16,12 +16,19 @@ */ package org.apache.solr.ltr.store.rest; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Map; import org.apache.commons.io.FileUtils; import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.FieldValueFeature; import org.apache.solr.ltr.feature.ValueFeature; +import org.apache.solr.ltr.model.DefaultWrapperModel; import org.apache.solr.ltr.model.LinearModel; import org.apache.solr.ltr.store.FeatureStore; import org.junit.BeforeClass; @@ -185,4 +192,73 @@ public class TestModelManagerPersistence extends TestRerankBase { assertJQ(ManagedFeatureStore.REST_END_POINT + "/" + FeatureStore.DEFAULT_FEATURE_STORE_NAME, "/features/==[]"); } + + private static void doWrapperModelPersistenceChecks(String modelName, + String featureStoreName, String baseModelFileName) throws Exception { + // verifies how the wrapper model is listed in the model store; note that the wrapper and the wrapped model always have the same name + assertJQ(ManagedModelStore.REST_END_POINT, + // the wrapped model shouldn't be registered, i.e. there must be no second entry with this name + "!/models/[1]/name=='"+modelName+"'", + // but the wrapper model should be registered, as the first entry with its own class and feature store + "/models/[0]/name=='"+modelName+"'", + "/models/[0]/class=='" + DefaultWrapperModel.class.getCanonicalName() + "'", + "/models/[0]/store=='" + featureStoreName + "'", + // the wrapper model shouldn't contain the definitions of the wrapped model, so its feature list is empty + "/models/[0]/features/==[]", + // but only its own parameters, i.e. the resource name of the base model file + "/models/[0]/params=={resource:'"+baseModelFileName+"'}"); + } + + @Test + public void testWrapperModelPersistence() 
throws Exception { + final String modelName = "linear"; + final String FS_NAME = "testWrapper"; + + // precondition: check whether models and features are empty + assertJQ(ManagedModelStore.REST_END_POINT, + "/models/==[]"); + assertJQ(ManagedFeatureStore.REST_END_POINT + "/" + FS_NAME, + "/features/==[]"); + + // setup features in the FS_NAME feature store, which both model definitions below refer to + loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), FS_NAME, "{\"field\":\"popularity\"}"); + loadFeature("const", ValueFeature.class.getCanonicalName(), FS_NAME, "{\"value\":5}"); + + // setup base model: its JSON is written to a file in tmpConfDir rather than being PUT to the model store + String baseModelJson = getModelInJson(modelName, LinearModel.class.getCanonicalName(), + new String[] {"popularity", "const"}, FS_NAME, + "{\"weights\":{\"popularity\":-1.0, \"const\":1.0}}"); + File baseModelFile = new File(tmpConfDir, "baseModelForPersistence.json"); + try (BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(new FileOutputStream(baseModelFile), StandardCharsets.UTF_8))) { + writer.write(baseModelJson); + } + baseModelFile.deleteOnExit(); + + // setup wrapper model: same name and store, no features, and only a resource param naming the base model file + String wrapperModelJson = getModelInJson(modelName, DefaultWrapperModel.class.getCanonicalName(), + new String[0], FS_NAME, + "{\"resource\":\"" + baseModelFile.getName() + "\"}"); + assertJPut(ManagedModelStore.REST_END_POINT, wrapperModelJson, "/responseHeader/status==0"); + doWrapperModelPersistenceChecks(modelName, FS_NAME, baseModelFile.getName()); + + // check persistence after reload + restTestHarness.reload(); + doWrapperModelPersistenceChecks(modelName, FS_NAME, baseModelFile.getName()); + + // check persistence after restart (jetty is stopped and started again) + jetty.stop(); + jetty.start(); + doWrapperModelPersistenceChecks(modelName, FS_NAME, baseModelFile.getName()); + + // delete test settings and check that the model store and the feature store are empty again + restTestHarness.delete(ManagedModelStore.REST_END_POINT + "/" + modelName); + restTestHarness.delete(ManagedFeatureStore.REST_END_POINT + "/" + FS_NAME); + assertJQ(ManagedModelStore.REST_END_POINT, + "/models/==[]"); + assertJQ(ManagedFeatureStore.REST_END_POINT 
+ "/" + FS_NAME, + "/features/==[]"); + + // NOTE: we don't test the persistence of the deletion here because it's tested in testFilePersistence + } } diff --git a/solr/solr-ref-guide/src/learning-to-rank.adoc b/solr/solr-ref-guide/src/learning-to-rank.adoc index 02475c6b4fd..3bbc34da7ac 100644 --- a/solr/solr-ref-guide/src/learning-to-rank.adoc +++ b/solr/solr-ref-guide/src/learning-to-rank.adoc @@ -87,6 +87,7 @@ Feature selection and model training take place offline and outside Solr. The lt |General form |Class |Specific examples |Linear |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LinearModel.html[LinearModel] |RankSVM, Pranking |Multiple Additive Trees |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html[MultipleAdditiveTreesModel] |LambdaMART, Gradient Boosted Regression Trees (GBRT) +|(wrapper) |{solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/DefaultWrapperModel.html[DefaultWrapperModel] |(not applicable) |(custom) |(custom class extending {solr-javadocs}/solr-ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel]) |(not applicable) |=== @@ -509,6 +510,54 @@ To delete the `currentFeatureStore` feature store: curl -XDELETE 'http://localhost:8983/solr/techproducts/schema/feature-store/currentFeatureStore' ---- +==== Using large models + +With SolrCloud, large models may fail to upload because of the size limit of ZooKeeper's buffer. In this case, `DefaultWrapperModel` lets you keep the model definition in a file that is separate from the uploaded one. + +Suppose that you want to use a large model stored at `/path/to/models/myModel.json` through `DefaultWrapperModel`: + +[source,json] +---- +{ + "store" : "largeModelsFeatureStore", + "name" : "myModel", + "class" : ..., + "features" : [ + ... + ], + "params" : { + ... 
+ } +} +---- + +First, add the directory to Solr's resource paths using <<lib-directives-in-solrconfig.adoc#lib-directives-in-solrconfig,Lib Directives>>: + +[source,xml] +---- + <lib dir="/path/to" regex="models" /> +---- + +Then, configure `DefaultWrapperModel` to wrap `myModel.json`: + +[source,json] +---- +{ + "store" : "largeModelsFeatureStore", + "name" : "myWrapperModel", + "class" : "org.apache.solr.ltr.model.DefaultWrapperModel", + "params" : { + "resource" : "myModel.json" + } +} +---- + +`myModel.json` will be loaded during initialization and can then be used by specifying `model=myWrapperModel`. + +NOTE: No `"features"` are configured in `myWrapperModel` because the features of the wrapped model (`myModel`) will be used; also note that the `"store"` configured for the wrapper model must match that of the wrapped model, i.e. in this example the feature store called `largeModelsFeatureStore` is used. + +CAUTION: `<lib dir="/path/to/models" regex=".*\.json" />` doesn't work as expected in this case, because `SolrResourceLoader` treats the matched resources as JARs when `<lib />` points to files. + === Applying Changes The feature store and the model store are both <<managed-resources.adoc#managed-resources,Managed Resources>>. Changes made to managed resources are not applied to the active Solr components until the Solr collection (or Solr core in single server mode) is reloaded.