SOLR-12699: Make contrib/ltr LTRScoringModel immutable and cache its hashCode.

(Stanislav Livotov, Edward Ribeiro, Christine Poerschke)
This commit is contained in:
Christine Poerschke 2018-11-05 18:56:40 +00:00
parent 01808eee93
commit be65b95e80
7 changed files with 61 additions and 11 deletions

View File

@ -178,6 +178,9 @@ Improvements
* SOLR-12892: MapWriter to use CharSequence instead of String (noble) * SOLR-12892: MapWriter to use CharSequence instead of String (noble)
* SOLR-12699: Make contrib/ltr LTRScoringModel immutable and cache its hashCode.
(Stanislav Livotov, Edward Ribeiro, Christine Poerschke)
================== 7.5.0 ================== ================== 7.5.0 ==================
Consult the LUCENE_CHANGES.txt file for additional, low level, changes in this release. Consult the LUCENE_CHANGES.txt file for additional, low level, changes in this release.

View File

@ -52,6 +52,11 @@ import org.noggit.ObjectBuilder;
*/ */
public class DefaultWrapperModel extends WrapperModel { public class DefaultWrapperModel extends WrapperModel {
/**
* resource is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private String resource; private String resource;
public DefaultWrapperModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, public DefaultWrapperModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,

View File

@ -16,11 +16,14 @@
*/ */
package org.apache.solr.ltr.model; package org.apache.solr.ltr.model;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
@ -81,6 +84,7 @@ public abstract class LTRScoringModel {
private final List<Feature> allFeatures; private final List<Feature> allFeatures;
private final Map<String,Object> params; private final Map<String,Object> params;
protected final List<Normalizer> norms; protected final List<Normalizer> norms;
private Integer hashCode; // cached since it shouldn't actually change after construction
public static LTRScoringModel getInstance(SolrResourceLoader solrResourceLoader, public static LTRScoringModel getInstance(SolrResourceLoader solrResourceLoader,
String className, String name, List<Feature> features, String className, String name, List<Feature> features,
@ -111,11 +115,11 @@ public abstract class LTRScoringModel {
String featureStoreName, List<Feature> allFeatures, String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) { Map<String,Object> params) {
this.name = name; this.name = name;
this.features = features; this.features = features != null ? Collections.unmodifiableList(new ArrayList<>(features)) : null;
this.featureStoreName = featureStoreName; this.featureStoreName = featureStoreName;
this.allFeatures = allFeatures; this.allFeatures = allFeatures != null ? Collections.unmodifiableList(new ArrayList<>(allFeatures)) : null;
this.params = params; this.params = params != null ? Collections.unmodifiableMap(new LinkedHashMap<>(params)) : null;
this.norms = norms; this.norms = norms != null ? Collections.unmodifiableList(new ArrayList<>(norms)) : null;
} }
/** /**
@ -144,7 +148,7 @@ public abstract class LTRScoringModel {
* @return the norms * @return the norms
*/ */
public List<Normalizer> getNorms() { public List<Normalizer> getNorms() {
return Collections.unmodifiableList(norms); return norms;
} }
/** /**
@ -158,7 +162,7 @@ public abstract class LTRScoringModel {
* @return the features * @return the features
*/ */
public List<Feature> getFeatures() { public List<Feature> getFeatures() {
return Collections.unmodifiableList(features); return features;
} }
public Map<String,Object> getParams() { public Map<String,Object> getParams() {
@ -167,13 +171,20 @@ public abstract class LTRScoringModel {
@Override @Override
public int hashCode() { public int hashCode() {
if(hashCode == null) {
hashCode = calculateHashCode();
}
return hashCode;
}
final private int calculateHashCode() {
final int prime = 31; final int prime = 31;
int result = 1; int result = 1;
result = (prime * result) + ((features == null) ? 0 : features.hashCode()); result = (prime * result) + Objects.hashCode(features);
result = (prime * result) + ((name == null) ? 0 : name.hashCode()); result = (prime * result) + Objects.hashCode(name);
result = (prime * result) + ((params == null) ? 0 : params.hashCode()); result = (prime * result) + Objects.hashCode(params);
result = (prime * result) + ((norms == null) ? 0 : norms.hashCode()); result = (prime * result) + Objects.hashCode(norms);
result = (prime * result) + ((featureStoreName == null) ? 0 : featureStoreName.hashCode()); result = (prime * result) + Objects.hashCode(featureStoreName);
return result; return result;
} }

View File

@ -71,6 +71,11 @@ import org.apache.solr.ltr.norm.Normalizer;
*/ */
public class LinearModel extends LTRScoringModel { public class LinearModel extends LTRScoringModel {
/**
* featureToWeight is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
protected Float[] featureToWeight; protected Float[] featureToWeight;
public void setWeights(Object weights) { public void setWeights(Object weights) {

View File

@ -90,7 +90,18 @@ import org.apache.solr.util.SolrPluginUtils;
*/ */
public class MultipleAdditiveTreesModel extends LTRScoringModel { public class MultipleAdditiveTreesModel extends LTRScoringModel {
/**
* fname2index is filled from constructor arguments
* (that are already part of the base class hashCode)
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private final HashMap<String,Integer> fname2index; private final HashMap<String,Integer> fname2index;
/**
* trees is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private List<RegressionTree> trees; private List<RegressionTree> trees;
private RegressionTree createRegressionTree(Map<String,Object> map) { private RegressionTree createRegressionTree(Map<String,Object> map) {

View File

@ -95,6 +95,11 @@ import org.apache.solr.util.SolrPluginUtils;
*/ */
public class NeuralNetworkModel extends LTRScoringModel { public class NeuralNetworkModel extends LTRScoringModel {
/**
* layers is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private List<Layer> layers; private List<Layer> layers;
protected interface Activation { protected interface Activation {

View File

@ -93,7 +93,17 @@ public class TestAdapterModel extends TestRerankBase {
public static class CustomModel extends AdapterModel { public static class CustomModel extends AdapterModel {
/**
* answerFileName is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private String answerFileName; private String answerFileName;
/**
* answerValue is obtained from answerFileName
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private float answerValue; private float answerValue;
public CustomModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, public CustomModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,