mirror of https://github.com/apache/lucene.git
SOLR-8542: couple of tweaks (Michael Nilsson, Diego Ceccarelli, Christine Poerschke)
* removed code triplication in ManagedModelStore * LTRScoringQuery.java tweaks * FeatureLogger.makeFeatureVector(...) can now safely be called repeatedly (though that doesn't happen at present) * make Feature.FeatureWeight.extractTerms a no-op; (OriginalScore|SolrFeature)Weight now implement extractTerms * LTRThreadModule javadocs and README.md tweaks * add TestFieldValueFeature.testBooleanValue test; replace "T"/"F" magic string use in FieldValueFeature * add TestOriginalScoreScorer test; add OriginalScoreScorer.freq() method * in TestMultipleAdditiveTreesModel revive dead explain test
This commit is contained in:
parent
d2ed42b847
commit
bfc3690d52
|
@ -390,17 +390,17 @@ About half the time for ranking is spent in the creation of weights for each fea
|
||||||
<!-- Query parser used to rerank top docs with a provided model -->
|
<!-- Query parser used to rerank top docs with a provided model -->
|
||||||
<queryParser name="ltr" class="org.apache.solr.ltr.search.LTRQParserPlugin">
|
<queryParser name="ltr" class="org.apache.solr.ltr.search.LTRQParserPlugin">
|
||||||
<int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to share for all requests -->
|
<int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to share for all requests -->
|
||||||
<int name="threadModule.numThreadsPerRequest">5</int> <!-- Maximum threads to use for a single requests-->
|
<int name="threadModule.numThreadsPerRequest">5</int> <!-- Maximum threads to use for a single request -->
|
||||||
</queryParser>
|
</queryParser>
|
||||||
|
|
||||||
<!-- Transformer for extracting features -->
|
<!-- Transformer for extracting features -->
|
||||||
<transformer name="features" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
|
<transformer name="features" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
|
||||||
<int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to share for all requests -->
|
<int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to share for all requests -->
|
||||||
<int name="threadModule.numThreadsPerRequest">5</int> <!-- Maximum threads to use for a single requests-->
|
<int name="threadModule.numThreadsPerRequest">5</int> <!-- Maximum threads to use for a single request -->
|
||||||
</transformer>
|
</transformer>
|
||||||
</config>
|
</config>
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The threadModule.totalPoolThreads option limits the total number of threads to be used across all query instances at any given time. threadModule.numThreadsPerRequest limits the number of threads used to process a single query. In the above example, 10 threads will be used to services all queries and a maximum of 5 threads to service a single query. If the solr instances is expected to receive no more than one query at a time, it is best to set both these numbers to the same value. If multiple queries need to serviced simultaneously, the numbers can be adjusted based on the expected response times. If the value of threadModule.numThreadsPerRequest is higher, the reponse time for a single query will be improved upto a point. If multiple queries are serviced simultaneously, the threadModule.totalPoolThreads imposes a contention between the queries if (threadModule.numThreadsPerRequest*total parallel queries > threadModule.totalPoolThreads).
|
The threadModule.totalPoolThreads option limits the total number of threads to be used across all query instances at any given time. threadModule.numThreadsPerRequest limits the number of threads used to process a single query. In the above example, 10 threads will be used to services all queries and a maximum of 5 threads to service a single query. If the solr instance is expected to receive no more than one query at a time, it is best to set both these numbers to the same value. If multiple queries need to be serviced simultaneously, the numbers can be adjusted based on the expected response times. If the value of threadModule.numThreadsPerRequest is higher, the response time for a single query will be improved upto a point. If multiple queries are serviced simultaneously, the threadModule.totalPoolThreads imposes a contention between the queries if (threadModule.numThreadsPerRequest*total parallel queries > threadModule.totalPoolThreads).
|
||||||
|
|
||||||
|
|
|
@ -151,7 +151,6 @@ public abstract class FeatureLogger<FV_TYPE> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class CSVFeatureLogger extends FeatureLogger<String> {
|
public static class CSVFeatureLogger extends FeatureLogger<String> {
|
||||||
StringBuilder sb = new StringBuilder(500);
|
|
||||||
char keyValueSep = ':';
|
char keyValueSep = ':';
|
||||||
char featureSep = ';';
|
char featureSep = ';';
|
||||||
|
|
||||||
|
@ -171,6 +170,10 @@ public abstract class FeatureLogger<FV_TYPE> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
|
public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
|
||||||
|
// Allocate the buffer to a size based on the number of features instead of the
|
||||||
|
// default 16. You need space for the name, value, and two separators per feature,
|
||||||
|
// but not all the features are expected to fire, so this is just a naive estimate.
|
||||||
|
StringBuilder sb = new StringBuilder(featuresInfo.length * 3);
|
||||||
boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
|
boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
|
||||||
for (LTRScoringQuery.FeatureInfo featInfo:featuresInfo) {
|
for (LTRScoringQuery.FeatureInfo featInfo:featuresInfo) {
|
||||||
if (featInfo.isUsed() || isDense){
|
if (featInfo.isUsed() || isDense){
|
||||||
|
@ -181,9 +184,8 @@ public abstract class FeatureLogger<FV_TYPE> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
final String features = (sb.length() > 0 ? sb.substring(0,
|
final String features = (sb.length() > 0 ?
|
||||||
sb.length() - 1) : "");
|
sb.substring(0, sb.length() - 1) : "");
|
||||||
sb.setLength(0);
|
|
||||||
|
|
||||||
return features;
|
return features;
|
||||||
}
|
}
|
||||||
|
|
|
@ -205,10 +205,10 @@ public class LTRScoringQuery extends Query {
|
||||||
List<Feature.FeatureWeight > featureWeights = new ArrayList<>(features.size());
|
List<Feature.FeatureWeight > featureWeights = new ArrayList<>(features.size());
|
||||||
|
|
||||||
if (querySemaphore == null) {
|
if (querySemaphore == null) {
|
||||||
createWeights(searcher, needsScores, boost, featureWeights, features);
|
createWeights(searcher, needsScores, featureWeights, features);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
createWeightsParallel(searcher, needsScores, boost, featureWeights, features);
|
createWeightsParallel(searcher, needsScores, featureWeights, features);
|
||||||
}
|
}
|
||||||
int i=0, j = 0;
|
int i=0, j = 0;
|
||||||
if (this.extractAllFeatures) {
|
if (this.extractAllFeatures) {
|
||||||
|
@ -228,7 +228,7 @@ public class LTRScoringQuery extends Query {
|
||||||
return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size());
|
return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void createWeights(IndexSearcher searcher, boolean needsScores, float boost,
|
private void createWeights(IndexSearcher searcher, boolean needsScores,
|
||||||
List<Feature.FeatureWeight > featureWeights, Collection<Feature> features) throws IOException {
|
List<Feature.FeatureWeight > featureWeights, Collection<Feature> features) throws IOException {
|
||||||
final SolrQueryRequest req = getRequest();
|
final SolrQueryRequest req = getRequest();
|
||||||
// since the feature store is a linkedhashmap order is preserved
|
// since the feature store is a linkedhashmap order is preserved
|
||||||
|
@ -271,7 +271,7 @@ public class LTRScoringQuery extends Query {
|
||||||
}
|
}
|
||||||
} // end of call CreateWeightCallable
|
} // end of call CreateWeightCallable
|
||||||
|
|
||||||
private void createWeightsParallel(IndexSearcher searcher, boolean needsScores, float boost,
|
private void createWeightsParallel(IndexSearcher searcher, boolean needsScores,
|
||||||
List<Feature.FeatureWeight > featureWeights, Collection<Feature> features) throws RuntimeException {
|
List<Feature.FeatureWeight > featureWeights, Collection<Feature> features) throws RuntimeException {
|
||||||
|
|
||||||
final SolrQueryRequest req = getRequest();
|
final SolrQueryRequest req = getRequest();
|
||||||
|
@ -401,8 +401,9 @@ public class LTRScoringQuery extends Query {
|
||||||
/**
|
/**
|
||||||
* Goes through all the stored feature values, and calculates the normalized
|
* Goes through all the stored feature values, and calculates the normalized
|
||||||
* values for all the features that will be used for scoring.
|
* values for all the features that will be used for scoring.
|
||||||
|
* Then calculate and return the model's score.
|
||||||
*/
|
*/
|
||||||
private void makeNormalizedFeatures() {
|
private float makeNormalizedFeaturesAndScore() {
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
for (final Feature.FeatureWeight feature : modelFeatureWeights) {
|
for (final Feature.FeatureWeight feature : modelFeatureWeights) {
|
||||||
final int featureId = feature.getIndex();
|
final int featureId = feature.getIndex();
|
||||||
|
@ -415,6 +416,7 @@ public class LTRScoringQuery extends Query {
|
||||||
pos++;
|
pos++;
|
||||||
}
|
}
|
||||||
ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
|
ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
|
||||||
|
return ltrScoringModel.score(modelFeatureValuesNormalized);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -491,8 +493,8 @@ public class LTRScoringQuery extends Query {
|
||||||
for (final Feature.FeatureWeight.FeatureScorer subSocer : featureScorers) {
|
for (final Feature.FeatureWeight.FeatureScorer subSocer : featureScorers) {
|
||||||
subSocer.setDocInfo(docInfo);
|
subSocer.setDocInfo(docInfo);
|
||||||
}
|
}
|
||||||
if (featureScorers.size() <= 1) { // TODO: Allow the use of dense
|
if (featureScorers.size() <= 1) {
|
||||||
// features in other cases
|
// future enhancement: allow the use of dense features in other cases
|
||||||
featureTraversalScorer = new DenseModelScorer(weight, featureScorers);
|
featureTraversalScorer = new DenseModelScorer(weight, featureScorers);
|
||||||
} else {
|
} else {
|
||||||
featureTraversalScorer = new SparseModelScorer(weight, featureScorers);
|
featureTraversalScorer = new SparseModelScorer(weight, featureScorers);
|
||||||
|
@ -570,8 +572,7 @@ public class LTRScoringQuery extends Query {
|
||||||
featuresInfo[featureId].setUsed(true);
|
featuresInfo[featureId].setUsed(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
makeNormalizedFeatures();
|
return makeNormalizedFeaturesAndScore();
|
||||||
return ltrScoringModel.score(modelFeatureValuesNormalized);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -663,8 +664,7 @@ public class LTRScoringQuery extends Query {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
makeNormalizedFeatures();
|
return makeNormalizedFeaturesAndScore();
|
||||||
return ltrScoringModel.score(modelFeatureValuesNormalized);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -29,6 +29,35 @@ import org.apache.solr.util.DefaultSolrThreadFactory;
|
||||||
import org.apache.solr.util.SolrPluginUtils;
|
import org.apache.solr.util.SolrPluginUtils;
|
||||||
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
|
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The LTRThreadModule is optionally used by the {@link org.apache.solr.ltr.search.LTRQParserPlugin} and
|
||||||
|
* {@link org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory LTRFeatureLoggerTransformerFactory}
|
||||||
|
* classes to parallelize the creation of {@link org.apache.solr.ltr.feature.Feature.FeatureWeight Feature.FeatureWeight}
|
||||||
|
* objects.
|
||||||
|
* <p>
|
||||||
|
* Example configuration:
|
||||||
|
* <pre>
|
||||||
|
<queryParser name="ltr" class="org.apache.solr.ltr.search.LTRQParserPlugin">
|
||||||
|
<int name="threadModule.totalPoolThreads">10</int>
|
||||||
|
<int name="threadModule.numThreadsPerRequest">5</int>
|
||||||
|
</queryParser>
|
||||||
|
|
||||||
|
<transformer name="features" class="org.apache.solr.ltr.response.transform.LTRFeatureLoggerTransformerFactory">
|
||||||
|
<int name="threadModule.totalPoolThreads">10</int>
|
||||||
|
<int name="threadModule.numThreadsPerRequest">5</int>
|
||||||
|
</transformer>
|
||||||
|
</pre>
|
||||||
|
* If an individual solr instance is expected to receive no more than one query at a time, it is best
|
||||||
|
* to set <code>totalPoolThreads</code> and <code>numThreadsPerRequest</code> to the same value.
|
||||||
|
*
|
||||||
|
* If multiple queries need to be serviced simultaneously then <code>totalPoolThreads</code> and
|
||||||
|
* <code>numThreadsPerRequest</code> can be adjusted based on the expected response times.
|
||||||
|
*
|
||||||
|
* If the value of <code>numThreadsPerRequest</code> is higher, the response time for a single query
|
||||||
|
* will be improved up to a point. If multiple queries are serviced simultaneously, the value of
|
||||||
|
* <code>totalPoolThreads</code> imposes a contention between the queries if
|
||||||
|
* <code>(totalPoolThreads < numThreadsPerRequest * total parallel queries)</code>.
|
||||||
|
*/
|
||||||
final public class LTRThreadModule implements NamedListInitializedPlugin {
|
final public class LTRThreadModule implements NamedListInitializedPlugin {
|
||||||
|
|
||||||
public static LTRThreadModule getInstance(NamedList args) {
|
public static LTRThreadModule getInstance(NamedList args) {
|
||||||
|
|
|
@ -258,8 +258,7 @@ public abstract class Feature extends Query {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void extractTerms(Set<Term> terms) {
|
public void extractTerms(Set<Term> terms) {
|
||||||
// needs to be implemented by query subclasses
|
// no-op
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.solr.request.SolrQueryRequest;
|
import org.apache.solr.request.SolrQueryRequest;
|
||||||
|
import org.apache.solr.schema.BoolField;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This feature returns the value of a field in the current document
|
* This feature returns the value of a field in the current document
|
||||||
|
@ -119,13 +120,16 @@ public class FieldValueFeature extends Feature {
|
||||||
return number.floatValue();
|
return number.floatValue();
|
||||||
} else {
|
} else {
|
||||||
final String string = indexableField.stringValue();
|
final String string = indexableField.stringValue();
|
||||||
// boolean values in the index are encoded with the
|
if (string.length() == 1) {
|
||||||
// chars T/F
|
// boolean values in the index are encoded with the
|
||||||
if (string.equals("T")) {
|
// a single char contained in TRUE_TOKEN or FALSE_TOKEN
|
||||||
return 1;
|
// (see BoolField)
|
||||||
}
|
if (string.charAt(0) == BoolField.TRUE_TOKEN[0]) {
|
||||||
if (string.equals("F")) {
|
return 1;
|
||||||
return 0;
|
}
|
||||||
|
if (string.charAt(0) == BoolField.FALSE_TOKEN[0]) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (final IOException e) {
|
} catch (final IOException e) {
|
||||||
|
|
|
@ -19,8 +19,10 @@ package org.apache.solr.ltr.feature;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
|
@ -76,7 +78,10 @@ public class OriginalScoreFeature extends Feature {
|
||||||
return "OriginalScoreFeature [query:" + originalQuery.toString() + "]";
|
return "OriginalScoreFeature [query:" + originalQuery.toString() + "]";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void extractTerms(Set<Term> terms) {
|
||||||
|
w.extractTerms(terms);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
|
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
|
||||||
|
@ -102,6 +107,11 @@ public class OriginalScoreFeature extends Feature {
|
||||||
return (docInfo.hasOriginalDocScore() ? docInfo.getOriginalDocScore() : originalScorer.score());
|
return (docInfo.hasOriginalDocScore() ? docInfo.getOriginalDocScore() : originalScorer.score());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int freq() throws IOException {
|
||||||
|
return originalScorer.freq();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int docID() {
|
public int docID() {
|
||||||
return originalScorer.docID();
|
return originalScorer.docID();
|
||||||
|
|
|
@ -21,8 +21,10 @@ import java.util.ArrayList;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.search.DocIdSet;
|
import org.apache.lucene.search.DocIdSet;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
|
@ -123,9 +125,9 @@ public class SolrFeature extends Feature {
|
||||||
* Weight for a SolrFeature
|
* Weight for a SolrFeature
|
||||||
**/
|
**/
|
||||||
public class SolrFeatureWeight extends FeatureWeight {
|
public class SolrFeatureWeight extends FeatureWeight {
|
||||||
Weight solrQueryWeight;
|
final private Weight solrQueryWeight;
|
||||||
Query query;
|
final private Query query;
|
||||||
List<Query> queryAndFilters;
|
final private List<Query> queryAndFilters;
|
||||||
|
|
||||||
public SolrFeatureWeight(IndexSearcher searcher,
|
public SolrFeatureWeight(IndexSearcher searcher,
|
||||||
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
|
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
|
||||||
|
@ -174,6 +176,8 @@ public class SolrFeature extends Feature {
|
||||||
if (query != null) {
|
if (query != null) {
|
||||||
queryAndFilters.add(query);
|
queryAndFilters.add(query);
|
||||||
solrQueryWeight = searcher.createNormalizedWeight(query, true);
|
solrQueryWeight = searcher.createNormalizedWeight(query, true);
|
||||||
|
} else {
|
||||||
|
solrQueryWeight = null;
|
||||||
}
|
}
|
||||||
} catch (final SyntaxError e) {
|
} catch (final SyntaxError e) {
|
||||||
throw new FeatureException("Failed to parse feature query.", e);
|
throw new FeatureException("Failed to parse feature query.", e);
|
||||||
|
@ -201,6 +205,13 @@ public class SolrFeature extends Feature {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void extractTerms(Set<Term> terms) {
|
||||||
|
if (solrQueryWeight != null) {
|
||||||
|
solrQueryWeight.extractTerms(terms);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
|
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
|
||||||
Scorer solrScorer = null;
|
Scorer solrScorer = null;
|
||||||
|
|
|
@ -57,7 +57,6 @@ public class ManagedFeatureStore extends ManagedResource implements ManagedResou
|
||||||
|
|
||||||
/** the feature store rest endpoint **/
|
/** the feature store rest endpoint **/
|
||||||
public static final String REST_END_POINT = "/schema/feature-store";
|
public static final String REST_END_POINT = "/schema/feature-store";
|
||||||
// TODO: reduce from public to package visibility (once tests no longer need public access)
|
|
||||||
|
|
||||||
/** name of the attribute containing the feature class **/
|
/** name of the attribute containing the feature class **/
|
||||||
static final String CLASS_KEY = "class";
|
static final String CLASS_KEY = "class";
|
||||||
|
|
|
@ -61,7 +61,6 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
||||||
|
|
||||||
/** the model store rest endpoint **/
|
/** the model store rest endpoint **/
|
||||||
public static final String REST_END_POINT = "/schema/model-store";
|
public static final String REST_END_POINT = "/schema/model-store";
|
||||||
// TODO: reduce from public to package visibility (once tests no longer need public access)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Managed model store: the name of the attribute containing all the models of
|
* Managed model store: the name of the attribute containing all the models of
|
||||||
|
@ -124,16 +123,20 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
||||||
if ((managedData != null) && (managedData instanceof List)) {
|
if ((managedData != null) && (managedData instanceof List)) {
|
||||||
final List<Map<String,Object>> up = (List<Map<String,Object>>) managedData;
|
final List<Map<String,Object>> up = (List<Map<String,Object>>) managedData;
|
||||||
for (final Map<String,Object> u : up) {
|
for (final Map<String,Object> u : up) {
|
||||||
try {
|
addModelFromMap(u);
|
||||||
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, u, managedFeatureStore);
|
|
||||||
addModel(algo);
|
|
||||||
} catch (final ModelException e) {
|
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void addModelFromMap(Map<String,Object> modelMap) {
|
||||||
|
try {
|
||||||
|
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, modelMap, managedFeatureStore);
|
||||||
|
addModel(algo);
|
||||||
|
} catch (final ModelException e) {
|
||||||
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public synchronized void addModel(LTRScoringModel ltrScoringModel) throws ModelException {
|
public synchronized void addModel(LTRScoringModel ltrScoringModel) throws ModelException {
|
||||||
try {
|
try {
|
||||||
log.info("adding model {}", ltrScoringModel.getName());
|
log.info("adding model {}", ltrScoringModel.getName());
|
||||||
|
@ -146,26 +149,17 @@ public class ManagedModelStore extends ManagedResource implements ManagedResourc
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
@Override
|
@Override
|
||||||
protected Object applyUpdatesToManagedData(Object updates) {
|
protected Object applyUpdatesToManagedData(Object updates) {
|
||||||
|
|
||||||
if (updates instanceof List) {
|
if (updates instanceof List) {
|
||||||
final List<Map<String,Object>> up = (List<Map<String,Object>>) updates;
|
final List<Map<String,Object>> up = (List<Map<String,Object>>) updates;
|
||||||
for (final Map<String,Object> u : up) {
|
for (final Map<String,Object> u : up) {
|
||||||
try {
|
addModelFromMap(u);
|
||||||
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, u, managedFeatureStore);
|
|
||||||
addModel(algo);
|
|
||||||
} catch (final ModelException e) {
|
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (updates instanceof Map) {
|
if (updates instanceof Map) {
|
||||||
final Map<String,Object> map = (Map<String,Object>) updates;
|
final Map<String,Object> map = (Map<String,Object>) updates;
|
||||||
try {
|
addModelFromMap(map);
|
||||||
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, map, managedFeatureStore);
|
|
||||||
addModel(algo);
|
|
||||||
} catch (final ModelException e) {
|
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return modelsAsManagedResources(store.getModels());
|
return modelsAsManagedResources(store.getModels());
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
<field name="keywords" type="text_general" indexed="true" stored="true" multiValued="true"/>
|
<field name="keywords" type="text_general" indexed="true" stored="true" multiValued="true"/>
|
||||||
<field name="popularity" type="int" indexed="true" stored="true" />
|
<field name="popularity" type="int" indexed="true" stored="true" />
|
||||||
<field name="normHits" type="float" indexed="true" stored="true" />
|
<field name="normHits" type="float" indexed="true" stored="true" />
|
||||||
|
<field name="isTrendy" type="boolean" indexed="true" stored="true" />
|
||||||
|
|
||||||
<field name="text" type="text_general" indexed="true" stored="false" multiValued="true"/>
|
<field name="text" type="text_general" indexed="true" stored="false" multiValued="true"/>
|
||||||
<field name="_version_" type="long" indexed="true" stored="true"/>
|
<field name="_version_" type="long" indexed="true" stored="true"/>
|
||||||
|
|
||||||
|
|
|
@ -32,21 +32,21 @@ public class TestFieldValueFeature extends TestRerankBase {
|
||||||
setuptest("solrconfig-ltr.xml", "schema.xml");
|
setuptest("solrconfig-ltr.xml", "schema.xml");
|
||||||
|
|
||||||
assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity",
|
assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity",
|
||||||
"1"));
|
"1","isTrendy","true"));
|
||||||
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
|
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
|
||||||
"w2 2asd asdd didid", "popularity", "2"));
|
"w2 2asd asdd didid", "popularity", "2"));
|
||||||
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
|
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
|
||||||
"3"));
|
"3","isTrendy","true"));
|
||||||
assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity",
|
assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity",
|
||||||
"4"));
|
"4","isTrendy","false"));
|
||||||
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
|
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
|
||||||
"5"));
|
"5","isTrendy","true"));
|
||||||
assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2",
|
assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2",
|
||||||
"popularity", "6"));
|
"popularity", "6","isTrendy","false"));
|
||||||
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
|
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
|
||||||
"w1 w2 w3 w4 w5 w8", "popularity", "7"));
|
"w1 w2 w3 w4 w5 w8", "popularity", "7","isTrendy","true"));
|
||||||
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
|
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
|
||||||
"w1 w1 w1 w2 w2", "popularity", "8"));
|
"w1 w1 w1 w2 w2", "popularity", "8","isTrendy","false"));
|
||||||
|
|
||||||
// a document without the popularity field
|
// a document without the popularity field
|
||||||
assertU(adoc("id", "42", "title", "NO popularity", "description", "NO popularity"));
|
assertU(adoc("id", "42", "title", "NO popularity", "description", "NO popularity"));
|
||||||
|
@ -169,5 +169,39 @@ public class TestFieldValueFeature extends TestRerankBase {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBooleanValue() throws Exception {
|
||||||
|
final String fstore = "test_boolean_store";
|
||||||
|
loadFeature("trendy", FieldValueFeature.class.getCanonicalName(), fstore,
|
||||||
|
"{\"field\":\"isTrendy\"}");
|
||||||
|
|
||||||
|
loadModel("trendy-model", LinearModel.class.getCanonicalName(),
|
||||||
|
new String[] {"trendy"}, fstore, "{\"weights\":{\"trendy\":1.0}}");
|
||||||
|
|
||||||
|
SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("id:4");
|
||||||
|
query.add("rq", "{!ltr model=trendy-model reRankDocs=4}");
|
||||||
|
query.add("fl", "[fv]");
|
||||||
|
assertJQ("/query" + query.toQueryString(),
|
||||||
|
"/response/docs/[0]/=={'[fv]':'trendy:0.0'}");
|
||||||
|
|
||||||
|
|
||||||
|
query = new SolrQuery();
|
||||||
|
query.setQuery("id:5");
|
||||||
|
query.add("rq", "{!ltr model=trendy-model reRankDocs=4}");
|
||||||
|
query.add("fl", "[fv]");
|
||||||
|
assertJQ("/query" + query.toQueryString(),
|
||||||
|
"/response/docs/[0]/=={'[fv]':'trendy:1.0'}");
|
||||||
|
|
||||||
|
// check default value is false
|
||||||
|
query = new SolrQuery();
|
||||||
|
query.setQuery("id:2");
|
||||||
|
query.add("rq", "{!ltr model=trendy-model reRankDocs=4}");
|
||||||
|
query.add("fl", "[fv]");
|
||||||
|
assertJQ("/query" + query.toQueryString(),
|
||||||
|
"/response/docs/[0]/=={'[fv]':'trendy:0.0'}");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.feature;
|
||||||
|
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.lang.reflect.Modifier;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.Scorer;
|
||||||
|
import org.apache.lucene.util.LuceneTestCase;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class TestOriginalScoreScorer extends LuceneTestCase {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testOverridesAbstractScorerMethods() {
|
||||||
|
final Class<?> ossClass = OriginalScoreFeature.OriginalScoreWeight.OriginalScoreScorer.class;
|
||||||
|
for (final Method scorerClassMethod : Scorer.class.getDeclaredMethods()) {
|
||||||
|
final int modifiers = scorerClassMethod.getModifiers();
|
||||||
|
if (!Modifier.isAbstract(modifiers)) continue;
|
||||||
|
|
||||||
|
try {
|
||||||
|
final Method ossClassMethod = ossClass.getDeclaredMethod(
|
||||||
|
scorerClassMethod.getName(),
|
||||||
|
scorerClassMethod.getParameterTypes());
|
||||||
|
assertEquals("getReturnType() difference",
|
||||||
|
scorerClassMethod.getReturnType(),
|
||||||
|
ossClassMethod.getReturnType());
|
||||||
|
} catch (NoSuchMethodException e) {
|
||||||
|
fail(ossClass + " needs to override '" + scorerClassMethod + "'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,7 +16,7 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.solr.ltr.model;
|
package org.apache.solr.ltr.model;
|
||||||
|
|
||||||
//import static org.junit.internal.matchers.StringContains.containsString;
|
import static org.junit.internal.matchers.StringContains.containsString;
|
||||||
|
|
||||||
import org.apache.solr.client.solrj.SolrQuery;
|
import org.apache.solr.client.solrj.SolrQuery;
|
||||||
import org.apache.solr.ltr.TestRerankBase;
|
import org.apache.solr.ltr.TestRerankBase;
|
||||||
|
@ -93,30 +93,28 @@ public class TestMultipleAdditiveTreesModel extends TestRerankBase {
|
||||||
|
|
||||||
// test out the explain feature, make sure it returns something
|
// test out the explain feature, make sure it returns something
|
||||||
query.setParam("debugQuery", "on");
|
query.setParam("debugQuery", "on");
|
||||||
String qryResult = JQ("/query" + query.toQueryString());
|
|
||||||
|
|
||||||
|
String qryResult = JQ("/query" + query.toQueryString());
|
||||||
qryResult = qryResult.replaceAll("\n", " ");
|
qryResult = qryResult.replaceAll("\n", " ");
|
||||||
// FIXME containsString doesn't exist.
|
|
||||||
// assertThat(qryResult, containsString("\"debug\":{"));
|
assertThat(qryResult, containsString("\"debug\":{"));
|
||||||
// qryResult = qryResult.substring(qryResult.indexOf("debug"));
|
qryResult = qryResult.substring(qryResult.indexOf("debug"));
|
||||||
//
|
|
||||||
// assertThat(qryResult, containsString("\"explain\":{"));
|
assertThat(qryResult, containsString("\"explain\":{"));
|
||||||
// qryResult = qryResult.substring(qryResult.indexOf("explain"));
|
qryResult = qryResult.substring(qryResult.indexOf("explain"));
|
||||||
//
|
|
||||||
// assertThat(qryResult, containsString("multipleadditivetreesmodel"));
|
assertThat(qryResult, containsString("multipleadditivetreesmodel"));
|
||||||
// assertThat(qryResult,
|
assertThat(qryResult, containsString(MultipleAdditiveTreesModel.class.getCanonicalName()));
|
||||||
// containsString(MultipleAdditiveTreesModel.class.getCanonicalName()));
|
|
||||||
//
|
assertThat(qryResult, containsString("-100.0 = tree 0"));
|
||||||
// assertThat(qryResult, containsString("-100.0 = tree 0"));
|
assertThat(qryResult, containsString("50.0 = tree 0"));
|
||||||
// assertThat(qryResult, containsString("50.0 = tree 0"));
|
assertThat(qryResult, containsString("-20.0 = tree 1"));
|
||||||
// assertThat(qryResult, containsString("-20.0 = tree 1"));
|
assertThat(qryResult, containsString("'matchedTitle':1.0 > 0.5"));
|
||||||
// assertThat(qryResult, containsString("'matchedTitle':1.0 > 0.5"));
|
assertThat(qryResult, containsString("'matchedTitle':0.0 <= 0.5"));
|
||||||
// assertThat(qryResult, containsString("'matchedTitle':0.0 <= 0.5"));
|
|
||||||
//
|
assertThat(qryResult, containsString(" Go Right "));
|
||||||
// assertThat(qryResult, containsString(" Go Right "));
|
assertThat(qryResult, containsString(" Go Left "));
|
||||||
// assertThat(qryResult, containsString(" Go Left "));
|
assertThat(qryResult, containsString("'this_feature_doesnt_exist' does not exist in FV"));
|
||||||
// assertThat(qryResult,
|
|
||||||
// containsString("'this_feature_doesnt_exist' does not exist in FV"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -71,8 +71,8 @@ public class BoolField extends PrimitiveFieldType {
|
||||||
}
|
}
|
||||||
|
|
||||||
// avoid instantiating every time...
|
// avoid instantiating every time...
|
||||||
protected final static char[] TRUE_TOKEN = {'T'};
|
public final static char[] TRUE_TOKEN = {'T'};
|
||||||
protected final static char[] FALSE_TOKEN = {'F'};
|
public final static char[] FALSE_TOKEN = {'F'};
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// TODO: look into creating my own queryParser that can more efficiently
|
// TODO: look into creating my own queryParser that can more efficiently
|
||||||
|
|
Loading…
Reference in New Issue