SOLR-8542: Adds Solr Learning to Rank (LTR) plugin for reranking results with machine learning models. (Michael Nilsson, Diego Ceccarelli, Joshua Pantony, Jon Dorando, Naveen Santhapuri, Alessandro Benedetti, David Grohmann, Christine Poerschke)

This commit is contained in:
Christine Poerschke 2016-11-01 17:50:14 +00:00
parent b6ff3fdace
commit 5a66b3bc08
117 changed files with 14167 additions and 0 deletions


@@ -60,6 +60,7 @@
      <module group="Solr/Contrib" filepath="$PROJECT_DIR$/solr/contrib/uima/uima.iml" />
      <module group="Solr/Contrib" filepath="$PROJECT_DIR$/solr/contrib/velocity/velocity.iml" />
      <module group="Solr/Contrib" filepath="$PROJECT_DIR$/solr/contrib/analytics/analytics.iml" />
      <module group="Solr/Contrib" filepath="$PROJECT_DIR$/solr/contrib/ltr/ltr.iml" />
    </modules>
  </component>
</project>


@@ -0,0 +1,37 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="JAVA_MODULE" version="4">
  <component name="NewModuleRootManager" inherit-compiler-output="false">
    <output url="file://$MODULE_DIR$/../../../idea-build/solr/contrib/ltr/classes/java" />
    <output-test url="file://$MODULE_DIR$/../../../idea-build/solr/contrib/ltr/classes/test" />
    <exclude-output />
    <content url="file://$MODULE_DIR$">
      <sourceFolder url="file://$MODULE_DIR$/src/test" isTestSource="true" />
      <sourceFolder url="file://$MODULE_DIR$/src/test-files" type="java-test-resource" />
      <sourceFolder url="file://$MODULE_DIR$/src/java" isTestSource="false" />
      <sourceFolder url="file://$MODULE_DIR$/src/resources" type="java-resource" />
    </content>
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
    <orderEntry type="library" scope="TEST" name="JUnit" level="project" />
    <orderEntry type="library" name="Solr core library" level="project" />
    <orderEntry type="library" name="Solrj library" level="project" />
    <orderEntry type="module-library">
      <library>
        <CLASSES>
          <root url="file://$MODULE_DIR$/lib" />
        </CLASSES>
        <JAVADOC />
        <SOURCES />
        <jarDirectory url="file://$MODULE_DIR$/lib" recursive="false" />
      </library>
    </orderEntry>
    <orderEntry type="library" scope="TEST" name="Solr example library" level="project" />
    <orderEntry type="library" scope="TEST" name="Solr core test library" level="project" />
    <orderEntry type="module" scope="TEST" module-name="lucene-test-framework" />
    <orderEntry type="module" scope="TEST" module-name="solr-test-framework" />
    <orderEntry type="module" module-name="solr-core" />
    <orderEntry type="module" module-name="solrj" />
    <orderEntry type="module" module-name="lucene-core" />
    <orderEntry type="module" module-name="analysis-common" />
  </component>
</module>


@@ -93,6 +93,9 @@ New Features
SOLR_HOME on every node. Editing config through API is supported but affects only that one node.
(janhoy)
* SOLR-8542: Adds Solr Learning to Rank (LTR) plugin for reranking results with machine learning models.
(Michael Nilsson, Diego Ceccarelli, Joshua Pantony, Jon Dorando, Naveen Santhapuri, Alessandro Benedetti, David Grohmann, Christine Poerschke)
Optimizations
----------------------
* SOLR-9704: Facet Module / JSON Facet API: Optimize blockChildren facets that have

solr/contrib/ltr/README.md Normal file

@@ -0,0 +1,406 @@
Apache Solr Learning to Rank
========
This is the main repository for
[learning to rank integrated into Solr](http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp).
[Read up on learning to rank](https://en.wikipedia.org/wiki/Learning_to_rank)
Apache Solr Learning to Rank (LTR) provides a way for you to extract features
directly inside Solr for use in training a machine learned model. You can then
deploy that model to Solr and use it to rerank your top X search results.
# Test the plugin with solr/example/techproducts in a few easy steps!
Solr provides some simple example indexes. To test the plugin with
the techproducts example, please follow these steps.
1. Compile solr and the examples
`cd solr`
`ant dist`
`ant server`
2. Run the example to setup the index
`./bin/solr -e techproducts`
3. Stop solr and install the plugin:
1. Stop solr
`./bin/solr stop`
2. Create the lib folder
`mkdir example/techproducts/solr/techproducts/lib`
3. Install the plugin in the lib folder
`cp build/contrib/ltr/solr-ltr-7.0.0-SNAPSHOT.jar example/techproducts/solr/techproducts/lib/`
4. Replace the original solrconfig with one importing all the ltr components
`cp contrib/ltr/example/solrconfig.xml example/techproducts/solr/techproducts/conf/`
4. Run the example again
`./bin/solr -e techproducts`
Note that you could also have just restarted your collection using the admin page.
You can find more detailed instructions [here](https://wiki.apache.org/solr/SolrPlugins).
5. Deploy features and a model
`curl -XPUT 'http://localhost:8983/solr/techproducts/schema/feature-store' --data-binary "@./contrib/ltr/example/techproducts-features.json" -H 'Content-type:application/json'`
`curl -XPUT 'http://localhost:8983/solr/techproducts/schema/model-store' --data-binary "@./contrib/ltr/example/techproducts-model.json" -H 'Content-type:application/json'`
6. Have fun!
* Access the default feature store:
http://localhost:8983/solr/techproducts/schema/feature-store/\_DEFAULT\_
* Access the model store:
http://localhost:8983/solr/techproducts/schema/model-store
* Perform a reranking query using the model, and retrieve the features
http://localhost:8983/solr/techproducts/query?indent=on&q=test&wt=json&rq={!ltr%20model=linear%20reRankDocs=25%20efi.user_query=%27test%27}&fl=[features],price,score,name
BONUS: Train an actual machine learning model
1. Download and install [liblinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/)
2. Change `contrib/ltr/example/config.json` "trainingLibraryLocation" to point to the `train` program of your liblinear installation.
3. Extract features, train a reranking model, and deploy it to Solr.
`cd contrib/ltr/example`
`python train_and_upload_demo_model.py -c config.json`
This script deploys your features from `config.json` "featuresFile" to Solr. Then it takes the relevance-judged query-document
pairs from "userQueriesFile" and merges them with the features extracted from Solr into a training
file. That file is used to train a linear model, which is then deployed to Solr for you to rerank results.
4. Search and rerank the results using the trained model
http://localhost:8983/solr/techproducts/query?indent=on&q=test&wt=json&rq={!ltr%20model=ExampleModel%20reRankDocs=25%20efi.user_query=%27test%27}&fl=price,score,name
# Changes to solrconfig.xml
```xml
<config>
  ...
  <!-- Query parser used to rerank top docs with a provided model -->
  <queryParser name="ltr" class="org.apache.solr.search.LTRQParserPlugin" />

  <!-- Transformer that will encode the document features in the response.
       For each document the transformer will add the features as an extra field
       in the response. The name of the field will be the name of the
       transformer enclosed between brackets (in this case [features]).
       In order to get the feature vector you will have to
       specify that you want the field (e.g., fl="*,[features]") -->
  <transformer name="features" class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory" />

  <query>
    ...
    <!-- Cache for storing and fetching feature vectors -->
    <cache name="QUERY_DOC_FV"
           class="solr.search.LRUCache"
           size="4096"
           initialSize="2048"
           autowarmCount="4096"
           regenerator="solr.search.NoOpRegenerator" />
  </query>
</config>
```
# Defining Features
In the learning to rank plugin, you can define features in a feature space
using standard Solr queries. As an example:
###### features.json
```json
[
  {
    "name": "isBook",
    "class": "org.apache.solr.ltr.feature.SolrFeature",
    "params": { "fq": ["{!terms f=category}book"] }
  },
  {
    "name": "documentRecency",
    "class": "org.apache.solr.ltr.feature.SolrFeature",
    "params": {
      "q": "{!func}recip( ms(NOW,publish_date), 3.16e-11, 1, 1)"
    }
  },
  {
    "name": "originalScore",
    "class": "org.apache.solr.ltr.feature.OriginalScoreFeature",
    "params": {}
  },
  {
    "name": "userTextTitleMatch",
    "class": "org.apache.solr.ltr.feature.SolrFeature",
    "params": { "q": "{!field f=title}${user_text}" }
  },
  {
    "name": "userFromMobile",
    "class": "org.apache.solr.ltr.feature.ValueFeature",
    "params": { "value": "${userFromMobile}", "required": true }
  }
]
```
This defines five features. Anything that is a valid Solr query can be used to define
a feature.
### Filter Query Features
The first feature, isBook, fires if the term 'book' matches the category field
of the document being examined. Since no q was specified for this feature,
the feature score will be either 1 (in case of a match) or 0 (in case of no
match).
### Query Features
In the second feature (documentRecency), q was specified using a function query.
In this case the score for the feature on a given document is whatever the query
returns (1 for docs dated now, 1/2 for docs dated 1 year ago, 1/3 for docs dated
2 years ago, etc.). If both an fq and a q are used, documents that don't match
the fq will receive a score of 0 for the documentRecency feature; all other
documents will receive the score specified by the query for this feature.
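Solr's recip(x,m,a,b) function is simply a/(m*x + b), and ms(NOW,publish_date) is the document age in milliseconds. A minimal sketch (plain Python, not Solr code) shows why m=3.16e-11, roughly 1 divided by the milliseconds in a year, produces the 1, 1/2, 1/3 decay above:

```python
def recip(x, m, a, b):
    # Solr's recip function: a / (m*x + b)
    return a / (m * x + b)

# Milliseconds in a year; 3.16e-11 is approximately 1 / MS_PER_YEAR.
MS_PER_YEAR = 365.25 * 24 * 60 * 60 * 1000

for years in (0, 1, 2):
    age_ms = years * MS_PER_YEAR
    print(round(recip(age_ms, 3.16e-11, 1, 1), 2))  # ~1.0, ~0.5, ~0.33
```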
### Original Score Feature
The third feature (originalScore) has no parameters, and uses the
OriginalScoreFeature class instead of the SolrFeature class. Its purpose is
to simply return the score for the original search request against the current
matching document.
### External Features
Users can specify external information that can be passed in as
part of the query to the ltr ranking framework. In this case, the
fourth feature (userTextTitleMatch) will look for an external field
called 'user_text' passed in through the request, and will fire if there is
a term match for the document field 'title' against the value of the external
field 'user_text'. You can provide default values for external features as
well by specifying ${myField:myDefault}, similar to how you would in a Solr config.
The fifth feature (userFromMobile) will look for an external parameter
called 'userFromMobile' passed in through the request. If the ValueFeature is declared with:
required=true, it will throw an exception if the external parameter is not passed;
required=false, it will silently ignore the feature and skip its scoring (at document scoring time, the model will use 0 as the feature value).
The advantage of defining a feature as not required, where possible, is to avoid wasting cache space and time calculating the feature score.
See the [Run a Rerank Query](#run-a-rerank-query) section for how to pass in external information.
### Custom Features
Custom features can be created by extending
org.apache.solr.ltr.feature.Feature; however, this is generally not recommended.
The majority of features should be possible to create using the methods described
above.
# Defining Models
Currently the Learning to Rank plugin supports 2 generalized forms of
models: 1. Linear Model i.e. [RankSVM](http://www.cs.cornell.edu/people/tj/publications/joachims_02c.pdf), [Pranking](https://papers.nips.cc/paper/2023-pranking-with-ranking.pdf)
and 2. Multiple Additive Trees i.e. [LambdaMART](http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf), [Gradient Boosted Regression Trees (GBRT)](https://papers.nips.cc/paper/3305-a-general-boosting-method-and-its-application-to-learning-ranking-functions-for-web-search.pdf)
### Linear
If you'd like to introduce a bias, set a constant feature
to the bias value you'd like and give that feature a weight of 1.0.
###### model.json
```json
{
  "class": "org.apache.solr.ltr.model.LinearModel",
  "name": "myModelName",
  "features": [
    { "name": "userTextTitleMatch" },
    { "name": "originalScore" },
    { "name": "isBook" }
  ],
  "params": {
    "weights": {
      "userTextTitleMatch": 1.0,
      "originalScore": 0.5,
      "isBook": 0.1
    }
  }
}
```
This is an example of a toy Linear model. Class specifies the class to be
used to interpret the model. Name is the model identifier you will use
when making requests to the ltr framework. Features specifies the feature
space that you want extracted when using this model. All features that
appear in the model params will be used for scoring and must appear in
the features list. You can add extra features to the features list that
will be computed but not used in the model for scoring, which can be useful
for logging. Params are the Linear parameters.
Good libraries for training an SVM, an example of a Linear model, are
liblinear (https://www.csie.ntu.edu.tw/~cjlin/liblinear/) and libSVM (https://www.csie.ntu.edu.tw/~cjlin/libsvm/).
You will need to convert the libSVM model format to the format specified above.
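To make the scoring concrete, here is a minimal sketch (plain Python, with made-up feature values) of the weighted sum a linear model computes, using the weights from the myModelName example above:

```python
# Weights from the myModelName example above; the feature values are made up.
weights = {"userTextTitleMatch": 1.0, "originalScore": 0.5, "isBook": 0.1}
features = {"userTextTitleMatch": 1.0, "originalScore": 10.5, "isBook": 1.0}

def linear_score(weights, features):
    # A linear model scores a document as the weighted sum of its feature
    # values; features missing from the extracted vector contribute 0.
    return sum(w * features.get(name, 0.0) for name, w in weights.items())

linear_score(weights, features)  # 1.0*1.0 + 0.5*10.5 + 0.1*1.0
```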
### Multiple Additive Trees
###### model2.json
```json
{
  "class": "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
  "name": "multipleadditivetreesmodel",
  "features": [
    { "name": "userTextTitleMatch" },
    { "name": "originalScore" }
  ],
  "params": {
    "trees": [
      {
        "weight": 1,
        "root": {
          "feature": "userTextTitleMatch",
          "threshold": 0.5,
          "left": {
            "value": -100
          },
          "right": {
            "feature": "originalScore",
            "threshold": 10.0,
            "left": {
              "value": 50
            },
            "right": {
              "value": 75
            }
          }
        }
      },
      {
        "weight": 2,
        "root": {
          "value": -10
        }
      }
    ]
  }
}
```
This is an example of a toy Multiple Additive Trees model. Class specifies the class
to be used to interpret the model. Name is the
model identifier you will use when making requests to the ltr framework.
Features specifies the feature space that you want extracted when using this
model. All features that appear in the model params will be used for scoring and
must appear in the features list. You can add extra features to the features
list that will be computed but not used in the model for scoring, which can
be useful for logging. Params are the Multiple Additive Trees specific parameters. In this
case we have 2 trees, one with 3 leaf nodes and one with 1 leaf node.
A good library for training LambdaMART, an example of Multiple Additive Trees, is RankLib (http://sourceforge.net/p/lemur/wiki/RankLib/).
You will need to convert the RankLib model format to the format specified above.
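A sketch of how such an ensemble could be evaluated (plain Python, assuming the left branch is taken when the feature value is at or below the threshold; the extracted feature values are made up):

```python
# Trees copied from the model2.json example above.
trees = [
    {"weight": 1,
     "root": {"feature": "userTextTitleMatch", "threshold": 0.5,
              "left": {"value": -100},
              "right": {"feature": "originalScore", "threshold": 10.0,
                        "left": {"value": 50},
                        "right": {"value": 75}}}},
    {"weight": 2,
     "root": {"value": -10}},
]

def score_tree(node, features):
    # Leaf nodes hold a value; interior nodes route on the feature's value,
    # taking the left branch when value <= threshold (an assumption here).
    if "value" in node:
        return node["value"]
    branch = "left" if features.get(node["feature"], 0.0) <= node["threshold"] else "right"
    return score_tree(node[branch], features)

def score_trees(trees, features):
    # The ensemble score is the weighted sum of the individual tree scores.
    return sum(t["weight"] * score_tree(t["root"], features) for t in trees)

# Hypothetical extracted values: title matched (1.0), original score 12.0.
score_trees(trees, {"userTextTitleMatch": 1.0, "originalScore": 12.0})  # 1*75 + 2*(-10) = 55
```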
# Deploy Models and Features
To send features run
`curl -XPUT 'http://localhost:8983/solr/collection1/schema/feature-store' --data-binary @/path/features.json -H 'Content-type:application/json'`
To send models run
`curl -XPUT 'http://localhost:8983/solr/collection1/schema/model-store' --data-binary @/path/model.json -H 'Content-type:application/json'`
# View Models and Features
`curl -XGET 'http://localhost:8983/solr/collection1/schema/feature-store'`
`curl -XGET 'http://localhost:8983/solr/collection1/schema/model-store'`
# Run a Rerank Query
Add to your original solr query
`rq={!ltr model=myModelName reRankDocs=25}`
The model name is the name of the model you sent to Solr earlier.
reRankDocs is the number of documents you want reranked, which can be
larger than the number you display.
### Pass in external information for external features
Add to your original solr query
`rq={!ltr reRankDocs=3 model=externalmodel efi.field1='text1' efi.field2='text2'}`
Where "field1" specifies the name of the customized field to be used by one
or more of your features, and 'text1' is the information to be passed in. As an
example that matches the earlier userTextTitleMatch feature, one could do:
`rq={!ltr reRankDocs=3 model=externalmodel efi.user_text='Casablanca' efi.user_intent='movie'}`
# Extract features
To extract features you need to use the feature vector transformer `features`
`fl=*,score,[features]&rq={!ltr model=yourModel reRankDocs=25}`
If you use `[features]` together with your reranking model, it will return
the array of features used by your model. Otherwise you can just ask Solr to
produce the features without doing the reranking:
`fl=*,score,[features store=yourFeatureStore format=[dense|sparse] ]`
This will return the values of the features in the given store. The format of the
extracted features will be based on the format parameter. The default is sparse.
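For a rough idea of what to do with the extracted values client-side, this sketch parses a sparse feature string the way the bundled example/train_and_upload_demo_model.py does (splitting on ';' and then ':'); the input string here is hypothetical:

```python
def parse_feature_vector(fv):
    # Mirrors the parsing in example/train_and_upload_demo_model.py,
    # which splits the [features] string on ';' and then ':'.
    features = {}
    for pair in fv.split(";"):
        name, value = pair.split(":")
        features[name] = float(value)
    return features

# A hypothetical sparse feature string returned for one document.
parse_feature_vector("isBook:1.0;documentRecency:0.5;originalScore:10.5")
```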
# Assemble training data
In order to train a learning to rank model you need training data. Training data is
what "teaches" the model what the appropriate weight for each feature is. In general
training data is a collection of queries with associated documents and what their ranking/score
should be. As an example:
```
secretary of state|John Kerry|0.66|CROWDSOURCE
secretary of state|Cesar A. Perales|0.33|CROWDSOURCE
secretary of state|New York State|0.0|CROWDSOURCE
secretary of state|Colorado State University Secretary|0.0|CROWDSOURCE
microsoft ceo|Satya Nadella|1.0|CLICK_LOG
microsoft ceo|Microsoft|0.0|CLICK_LOG
microsoft ceo|State|0.0|CLICK_LOG
microsoft ceo|Secretary|0.0|CLICK_LOG
```
In this example the first column indicates the query, the second column indicates a unique id for that doc,
the third column indicates the relative importance or relevance of that doc, and the fourth column indicates the source.
There are 2 primary ways you might collect data for use with your machine learning algorithm. The first
is to collect the clicks of your users given a specific query. There are many ways of preparing this data
to train a model (http://www.cs.cornell.edu/people/tj/publications/joachims_etal_05a.pdf). The general idea
is that if a user sees multiple documents and clicks the one lower down, that document should be scored higher
than the one above it. The second way is explicitly through a crowdsourcing platform like Mechanical Turk or
CrowdFlower. These platforms allow you to show human workers documents associated with a query and have them
tell you what the correct ranking should be.
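The idea of using relative judgments can be sketched as turning graded rows like the ones above into pairwise preferences, which is what a pairwise learner such as RankSVM trains on (plain Python, using the sample rows above):

```python
# Sample rows (query, docId, relevance) taken from the example above.
judgments = [
    ("secretary of state", "John Kerry", 0.66),
    ("secretary of state", "Cesar A. Perales", 0.33),
    ("secretary of state", "New York State", 0.0),
]

def pairwise_preferences(rows):
    # For every pair of docs under the same query, emit (better, worse);
    # a pairwise learner trains on exactly such preferences.
    prefs = []
    for i in range(len(rows)):
        for j in range(len(rows)):
            if i != j and rows[i][0] == rows[j][0] and rows[i][2] > rows[j][2]:
                prefs.append((rows[i][1], rows[j][1]))
    return prefs

pairwise_preferences(judgments)
```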
At this point you'll need to collect feature vectors for each query document pair. You can use the information
from the Extract features section above to do this. An example script has been included in example/train_and_upload_demo_model.py.
# Explanation of the core reranking logic
An LTR model is plugged into the ranking through the [LTRQParserPlugin](/solr/contrib/ltr/src/java/org/apache/solr/search/LTRQParserPlugin.java). The plugin will
read from the request the model, an instance of [LTRScoringModel](/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LTRScoringModel.java),
plus other parameters. The plugin will generate an LTRQuery, a particular [ReRankQuery](/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java).
It wraps the original solr query for the first pass ranking, and uses the provided model in an
[LTRScoringQuery](/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java) to
rescore and rerank the top documents. The LTRScoringQuery will take care of computing the values of all the
[features](/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java) and then will delegate the final score
generation to the LTRScoringModel.
# Speeding up the weight creation with threads
About half the time for ranking is spent in the creation of weights for each feature used in ranking. If the number of features is significantly high (say, 500 or more), this increases the ranking overhead proportionally. To alleviate this problem, parallel weight creation is provided as a configurable option. To use this feature, add the following lines to solrconfig.xml:
```xml
<config>
  <!-- Query parser used to rerank top docs with a provided model -->
  <queryParser name="ltr" class="org.apache.solr.search.LTRQParserPlugin">
    <int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to share for all requests -->
    <int name="threadModule.numThreadsPerRequest">5</int> <!-- Maximum threads to use for a single request -->
  </queryParser>

  <!-- Transformer for extracting features -->
  <transformer name="features" class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory">
    <int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to share for all requests -->
    <int name="threadModule.numThreadsPerRequest">5</int> <!-- Maximum threads to use for a single request -->
  </transformer>
</config>
```
The threadModule.totalPoolThreads option limits the total number of threads to be used across all query instances at any given time. threadModule.numThreadsPerRequest limits the number of threads used to process a single query. In the above example, 10 threads will be used to service all queries and a maximum of 5 threads to service a single query. If the Solr instance is expected to receive no more than one query at a time, it is best to set both these numbers to the same value. If multiple queries need to be serviced simultaneously, the numbers can be adjusted based on the expected response times. If the value of threadModule.numThreadsPerRequest is higher, the response time for a single query will be improved up to a point. If multiple queries are serviced simultaneously, threadModule.totalPoolThreads imposes contention between the queries if (threadModule.numThreadsPerRequest * total parallel queries > threadModule.totalPoolThreads).
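The two caps can be pictured as two semaphores: one shared across all requests and a smaller one per request. This is only an illustrative sketch, not Solr's implementation:

```python
import threading

# Illustrative sketch (not Solr's code) of the two-level limit described above:
# all requests share one pool cap; each request also has its own smaller cap.
TOTAL_POOL = threading.BoundedSemaphore(10)        # threadModule.totalPoolThreads

class RequestThreadLimit:
    def __init__(self, per_request=5):             # threadModule.numThreadsPerRequest
        self.per_request = threading.BoundedSemaphore(per_request)

    def run(self, task):
        # A worker needs a slot in both its request's cap and the shared pool.
        with self.per_request, TOTAL_POOL:
            return task()

req = RequestThreadLimit()
req.run(lambda: "scored")
```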

solr/contrib/ltr/README.txt Symbolic link

@@ -0,0 +1 @@
README.md


@@ -0,0 +1,30 @@
<?xml version="1.0"?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<project name="solr-ltr" default="default">

  <description>
    Learning to Rank Package
  </description>

  <import file="../contrib-build.xml"/>

  <target name="compile-core" depends="solr-contrib-build.compile-core"/>
</project>


@@ -0,0 +1,14 @@
{
  "host": "localhost",
  "port": 8983,
  "collection": "techproducts",
  "requestHandler": "query",
  "q": "*:*",
  "otherParams": "fl=id,score,[features efi.user_query='$USERQUERY']",
  "userQueriesFile": "user_queries.txt",
  "trainingFile": "ClickData",
  "featuresFile": "techproducts-features.json",
  "trainingLibraryLocation": "liblinear/train",
  "solrModelFile": "solrModel.json",
  "solrModelName": "ExampleModel"
}


@@ -0,0 +1,124 @@
from subprocess import call
import os

PAIRWISE_THRESHOLD = 1.e-1
FEATURE_DIFF_THRESHOLD = 1.e-6

class LibSvmFormatter:
    def processQueryDocFeatureVector(self,docClickInfo,trainingFile):
        '''Expects as input a sorted by queries list or generator that provides the context
        for each query in a tuple composed of: (query , docId , relevance , source , featureVector).
        The list of documents that are part of the same query will generate comparisons
        against each other for training. '''
        curQueryAndSource = ""
        with open(trainingFile,"w") as output:
            self.featureNameToId = {}
            self.featureIdToName = {}
            self.curFeatIndex = 1
            curListOfFv = []
            for query,docId,relevance,source,featureVector in docClickInfo:
                if curQueryAndSource != query + source:
                    #Time to flush out all the pairs
                    _writeRankSVMPairs(curListOfFv,output)
                    curListOfFv = []
                    curQueryAndSource = query + source
                curListOfFv.append((relevance,self._makeFeaturesMap(featureVector)))
            _writeRankSVMPairs(curListOfFv,output) #This catches the last list of comparisons

    def _makeFeaturesMap(self,featureVector):
        '''expects a list of strings with "feature name":"feature value" pairs. Outputs a map of map[key] = value.
        Where key is now an integer. libSVM requires the key to be an integer but not all libraries have
        this requirement.'''
        features = {}
        for keyValuePairStr in featureVector:
            featName,featValue = keyValuePairStr.split(":")
            features[self._getFeatureId(featName)] = float(featValue)
        return features

    def _getFeatureId(self,key):
        if key not in self.featureNameToId:
            self.featureNameToId[key] = self.curFeatIndex
            self.featureIdToName[self.curFeatIndex] = key
            self.curFeatIndex += 1
        return self.featureNameToId[key]

    def convertLibSvmModelToLtrModel(self,libSvmModelLocation, outputFile, modelName):
        with open(libSvmModelLocation, 'r') as inFile:
            with open(outputFile,'w') as convertedOutFile:
                convertedOutFile.write('{\n\t"class":"org.apache.solr.ltr.model.LinearModel",\n')
                convertedOutFile.write('\t"name": "' + str(modelName) + '",\n')
                convertedOutFile.write('\t"features": [\n')
                isFirst = True
                for featKey in self.featureNameToId.keys():
                    convertedOutFile.write('\t\t{ "name":"' + featKey + '"}' if isFirst else ',\n\t\t{ "name":"' + featKey + '"}' )
                    isFirst = False
                convertedOutFile.write("\n\t],\n")
                convertedOutFile.write('\t"params": {\n\t\t"weights": {\n')

                startReading = False
                isFirst = True
                counter = 1
                for line in inFile:
                    if startReading:
                        newParamVal = float(line.strip())
                        if not isFirst:
                            convertedOutFile.write(',\n\t\t\t"' + self.featureIdToName[counter] + '":' + str(newParamVal))
                        else:
                            convertedOutFile.write('\t\t\t"' + self.featureIdToName[counter] + '":' + str(newParamVal))
                        isFirst = False
                        counter += 1
                    elif line.strip() == 'w':
                        startReading = True
                convertedOutFile.write('\n\t\t}\n\t}\n}')

def _writeRankSVMPairs(listOfFeatures,output):
    '''Given a list of (relevance, {Features Map}) where the list represents
    a set of documents to be compared, this calculates all pairs and
    writes the Feature Vectors in a format compatible with libSVM.
    Ex: listOfFeatures = [
        #(relevance, {feature1:value, featureN:value})
        (4, {1:0.9, 2:0.9, 3:0.1})
        (3, {1:0.7, 2:0.9, 3:0.2})
        (1, {1:0.1, 2:0.9, 6:0.1})
    ]
    '''
    for d1 in range(0,len(listOfFeatures)):
        for d2 in range(d1+1,len(listOfFeatures)):
            doc1,doc2 = listOfFeatures[d1], listOfFeatures[d2]
            fv1,fv2 = doc1[1],doc2[1]
            d1Relevance, d2Relevance = float(doc1[0]),float(doc2[0])
            if d1Relevance - d2Relevance > PAIRWISE_THRESHOLD: #d1Relevance > d2Relevance
                outputLibSvmLine("+1",subtractFvMap(fv1,fv2),output)
                outputLibSvmLine("-1",subtractFvMap(fv2,fv1),output)
            elif d1Relevance - d2Relevance < -PAIRWISE_THRESHOLD: #d1Relevance < d2Relevance
                outputLibSvmLine("+1",subtractFvMap(fv2,fv1),output)
                outputLibSvmLine("-1",subtractFvMap(fv1,fv2),output)
            else: #Must be approximately equal relevance, in which case this is a useless signal and we should skip
                continue

def subtractFvMap(fv1,fv2):
    '''returns the fv from fv1 - fv2'''
    retFv = fv1.copy()
    for featInd in fv2.keys():
        subVal = 0.0
        if featInd in fv1:
            subVal = fv1[featInd] - fv2[featInd]
        else:
            subVal = -fv2[featInd]
        if abs(subVal) > FEATURE_DIFF_THRESHOLD: #This ensures everything is in sparse format, and removes useless signals
            retFv[featInd] = subVal
        else:
            retFv.pop(featInd, None)
    return retFv

def outputLibSvmLine(sign,fvMap,outputFile):
    outputFile.write(sign)
    for feat in fvMap.keys():
        outputFile.write(" " + str(feat) + ":" + str(fvMap[feat]))
    outputFile.write("\n")

def trainLibSvm(libraryLocation,trainingFileName):
    if os.path.isfile(libraryLocation):
        call([libraryLocation, trainingFileName])
    else:
        raise Exception("NO LIBRARY FOUND: " + libraryLocation)

File diff suppressed because it is too large


@@ -0,0 +1,26 @@
[
  {
    "name": "isInStock",
    "class": "org.apache.solr.ltr.feature.FieldValueFeature",
    "params": {
      "field": "inStock"
    }
  },
  {
    "name": "price",
    "class": "org.apache.solr.ltr.feature.FieldValueFeature",
    "params": {
      "field": "price"
    }
  },
  {
    "name": "originalScore",
    "class": "org.apache.solr.ltr.feature.OriginalScoreFeature",
    "params": {}
  },
  {
    "name": "productNameMatchQuery",
    "class": "org.apache.solr.ltr.feature.SolrFeature",
    "params": { "q": "{!field f=name}${user_query}" }
  }
]


@@ -0,0 +1,18 @@
{
  "class": "org.apache.solr.ltr.model.LinearModel",
  "name": "linear",
  "features": [
    { "name": "isInStock" },
    { "name": "price" },
    { "name": "originalScore" },
    { "name": "productNameMatchQuery" }
  ],
  "params": {
    "weights": {
      "isInStock": 15.0,
      "price": 1.0,
      "originalScore": 5.0,
      "productNameMatchQuery": 1.0
    }
  }
}


@@ -0,0 +1,163 @@
#!/usr/bin/env python
import sys
import json
import httplib
import urllib
import libsvm_formatter
from optparse import OptionParser
solrQueryUrl = ""
def generateQueries(config):
with open(config["userQueriesFile"]) as input:
solrQueryUrls = [] #A list of tuples with solrQueryUrl,solrQuery,docId,scoreForPQ,source
for line in input:
line = line.strip();
searchText,docId,score,source = line.split("|");
solrQuery = generateHttpRequest(config,searchText,docId)
solrQueryUrls.append((solrQuery,searchText,docId,score,source))
return solrQueryUrls;
def generateHttpRequest(config,searchText,docId):
global solrQueryUrl
if len(solrQueryUrl) < 1:
solrQueryUrl = "/solr/%(collection)s/%(requestHandler)s?%(otherParams)s&q=" % config
solrQueryUrl = solrQueryUrl.replace(" ","+")
solrQueryUrl += urllib.quote_plus("id:")
userQuery = urllib.quote_plus(searchText.strip().replace("'","\\'").replace("/","\\\\/"))
solrQuery = solrQueryUrl + '"' + urllib.quote_plus(docId) + '"' #+ solrQueryUrlEnd
solrQuery = solrQuery.replace("%24USERQUERY", userQuery).replace('$USERQUERY', urllib.quote_plus("\\'" + userQuery + "\\'"))
return solrQuery
def generateTrainingData(solrQueries, config):
'''Given a list of solr queries, yields a tuple of query , docId , score , source , feature vector for each query.
Feature Vector is a list of strings of form "key:value"'''
conn = httplib.HTTPConnection(config["host"], config["port"])
headers = {"Connection":" keep-alive"}
try:
for queryUrl,query,docId,score,source in solrQueries:
conn.request("GET", queryUrl, headers=headers)
r = conn.getresponse()
msg = r.read()
msgDict = json.loads(msg)
fv = ""
docs = msgDict['response']['docs']
if len(docs) > 0 and "[features]" in docs[0]:
if not msgDict['response']['docs'][0]["[features]"] == None:
fv = msgDict['response']['docs'][0]["[features]"];
else:
print "ERROR NULL FV FOR: " + docId;
print msg
continue;
else:
print "ERROR FOR: " + docId;
print msg
continue;
if r.status == httplib.OK:
#print "http connection was ok for: " + queryUrl
yield(query,docId,score,source,fv.split(";"));
else:
raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg))
except Exception as e:
print msg
print e
conn.close()
def setupSolr(config):
'''Sets up solr with the proper features for the test'''
conn = httplib.HTTPConnection(config["host"], config["port"])
baseUrl = "/solr/" + config["collection"]
featureUrl = baseUrl + "/schema/feature-store"
# CAUTION! This will delete all feature stores. This is just for demo purposes
conn.request("DELETE", featureUrl+"/*")
r = conn.getresponse()
msg = r.read()
if (r.status != httplib.OK and
r.status != httplib.CREATED and
r.status != httplib.ACCEPTED and
r.status != httplib.NOT_FOUND):
raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg))
# Add features
headers = {'Content-type': 'application/json'}
featuresBody = open(config["featuresFile"])
conn.request("POST", featureUrl, featuresBody, headers)
r = conn.getresponse()
msg = r.read()
if (r.status != httplib.OK and
r.status != httplib.ACCEPTED):
print r.status
print ""
print r.reason;
raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg))
conn.close()
def main(argv=None):
if argv is None:
argv = sys.argv
parser = OptionParser(usage="usage: %prog [options] ", version="%prog 1.0")
parser.add_option('-c', '--config',
dest='configFile',
help='File of configuration for the test')
(options, args) = parser.parse_args()
if options.configFile == None:
parser.print_help()
return 1
with open(options.configFile) as configFile:
config = json.load(configFile)
print "Uploading feature space to Solr"
setupSolr(config)
print "Generating feature extraction Solr queries"
reRankQueries = generateQueries(config)
print "Extracting features"
fvGenerator = generateTrainingData(reRankQueries, config);
formatter = libsvm_formatter.LibSvmFormatter();
formatter.processQueryDocFeatureVector(fvGenerator,config["trainingFile"]);
print "Training ranksvm model"
libsvm_formatter.trainLibSvm(config["trainingLibraryLocation"],config["trainingFile"])
print "Converting ranksvm model to solr model"
formatter.convertLibSvmModelToLtrModel(config["trainingFile"] + ".model", config["solrModelFile"], config["solrModelName"])
print "Uploading model to solr"
uploadModel(config["collection"], config["host"], config["port"], config["solrModelFile"])
def uploadModel(collection, host, port, modelFile):
modelUrl = "/solr/" + collection + "/schema/model-store"
headers = {'Content-type': 'application/json'}
with open(modelFile) as modelBody:
conn = httplib.HTTPConnection(host, port)
conn.request("POST", modelUrl, modelBody, headers)
r = conn.getresponse()
msg = r.read()
if (r.status != httplib.OK and
r.status != httplib.CREATED and
r.status != httplib.ACCEPTED):
raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg))
if __name__ == '__main__':
sys.exit(main())
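The script above reads all of its connection and file-path settings from the JSON file passed via `--config`. A hypothetical config, inferred only from the keys this excerpt reads (functions such as `generateQueries` may require additional keys), might look like this:

```python
# Hypothetical example config for the training script above; every value is an
# illustrative assumption -- only the key names come from the script itself.
import json

example_config = {
    "host": "localhost",
    "port": 8983,
    "collection": "techproducts",
    "featuresFile": "features.json",
    "trainingFile": "training.txt",
    "trainingLibraryLocation": "lib/liblinear",
    "solrModelFile": "model.json",
    "solrModelName": "exampleModel",
}

# The script loads its config with json.load(open(path)); round-trip here to
# verify the sketch serializes cleanly.
config = json.loads(json.dumps(example_config))
```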


@ -0,0 +1,8 @@
hard drive|SP2514N|0.6666666|CLICK_LOGS
hard drive|6H500F0|0.330082034|CLICK_LOGS
hard drive|F8V7067-APL-KIT|0.0|CLICK_LOGS
hard drive|IW-02|0.0|CLICK_LOGS
ipod|MA147LL/A|1.0|EXPLICIT
ipod|F8V7067-APL-KIT|0.25|EXPLICIT
ipod|IW-02|0.25|EXPLICIT
ipod|6H500F0|0.0|EXPLICIT
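Each line of the judgment file above is pipe-delimited: query, document id, relevance score, and the source of the judgment. A minimal parsing sketch (the field names are my own labels, inferred from the data):

```python
# Parse one line of the pipe-delimited relevance-judgment file shown above.
# The four fields (query, doc id, score, source) are inferred from the data.
def parse_judgment(line):
    query, doc_id, score, source = line.strip().split("|")
    return query, doc_id, float(score), source
```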

solr/contrib/ltr/ivy.xml Normal file

@ -0,0 +1,32 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
<ivy-module version="2.0">
<info organisation="org.apache.solr" module="ltr"/>
<configurations defaultconfmapping="compile->master;test->master">
<conf name="compile" transitive="false"/> <!-- keep unused 'compile' configuration to allow build to succeed -->
<conf name="test" transitive="false"/>
</configurations>
<dependencies>
<dependency org="org.slf4j" name="jcl-over-slf4j" rev="${/org.slf4j/jcl-over-slf4j}" conf="test"/>
<exclude org="*" ext="*" matcher="regexp" type="${ivy.exclude.types}"/>
</dependencies>
</ivy-module>


@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.util.HashMap;
public class DocInfo extends HashMap<String,Object> {
// Name of key used to store the original score of a doc
private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
public DocInfo() {
super();
}
public void setOriginalDocScore(Float score) {
put(ORIGINAL_DOC_SCORE, score);
}
public Float getOriginalDocScore() {
return (Float)get(ORIGINAL_DOC_SCORE);
}
public boolean hasOriginalDocScore() {
return containsKey(ORIGINAL_DOC_SCORE);
}
}


@ -0,0 +1,193 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.lang.invoke.MethodHandles;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.solr.search.SolrIndexSearcher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* FeatureLogger can be registered in a model and provide a strategy for logging
* the feature values.
*/
public abstract class FeatureLogger<FV_TYPE> {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
/** the name of the cache used for storing the feature values **/
private static final String QUERY_FV_CACHE_NAME = "QUERY_DOC_FV";
protected enum FeatureFormat {DENSE, SPARSE};
protected final FeatureFormat featureFormat;
protected FeatureLogger(FeatureFormat f) {
this.featureFormat = f;
}
/**
* Log will be called every time that the model generates the feature values
* for a document and a query.
*
* @param docid
* Solr document id whose features we are saving
* @param featuresInfo
* List of all the {@link LTRScoringQuery.FeatureInfo} objects which contain name and value
* for all the features triggered by the result set
* @return true if the logger successfully logged the features, false
* otherwise.
*/
public boolean log(int docid, LTRScoringQuery scoringQuery,
SolrIndexSearcher searcher, LTRScoringQuery.FeatureInfo[] featuresInfo) {
final FV_TYPE featureVector = makeFeatureVector(featuresInfo);
if (featureVector == null) {
return false;
}
return searcher.cacheInsert(QUERY_FV_CACHE_NAME,
fvCacheKey(scoringQuery, docid), featureVector) != null;
}
/**
* Returns a FeatureLogger that logs the features, using the format
* specified in the 'stringFormat' param: 'csv' logs the features as a
* single string in CSV format, 'json' logs the features as a Map of
* featureName keys to featureValue values. If the format is null or
* empty, CSV format is selected.
* The 'featureFormat' param controls the layout: 'dense' writes the
* features in dense format, 'sparse' writes them in sparse format; null
* or empty defaults to 'sparse'.
*
* @return a feature logger for the format specified.
*/
public static FeatureLogger<?> createFeatureLogger(String stringFormat, String featureFormat) {
final FeatureFormat f;
if (featureFormat == null || featureFormat.isEmpty() ||
featureFormat.equals("sparse")) {
f = FeatureFormat.SPARSE;
}
else if (featureFormat.equals("dense")) {
f = FeatureFormat.DENSE;
}
else {
f = FeatureFormat.SPARSE;
log.warn("unknown feature logger feature format {} | {}", stringFormat, featureFormat);
}
if ((stringFormat == null) || stringFormat.isEmpty()) {
return new CSVFeatureLogger(f);
}
if (stringFormat.equals("csv")) {
return new CSVFeatureLogger(f);
}
if (stringFormat.equals("json")) {
return new MapFeatureLogger(f);
}
log.warn("unknown feature logger string format {} | {}", stringFormat, featureFormat);
return null;
}
public abstract FV_TYPE makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo);
private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) {
return scoringQuery.hashCode() + (31 * docid);
}
/**
* Returns the feature vector calculated for a document, if it is present
* in the cache.
*
* @param docid
* Solr document id
* @return the cached feature vector for docid, or null if it was not cached
*/
public FV_TYPE getFeatureVector(int docid, LTRScoringQuery scoringQuery,
SolrIndexSearcher searcher) {
return (FV_TYPE) searcher.cacheLookup(QUERY_FV_CACHE_NAME, fvCacheKey(scoringQuery, docid));
}
public static class MapFeatureLogger extends FeatureLogger<Map<String,Float>> {
public MapFeatureLogger(FeatureFormat f) {
super(f);
}
@Override
public Map<String,Float> makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
Map<String,Float> hashmap = Collections.emptyMap();
if (featuresInfo.length > 0) {
hashmap = new HashMap<String,Float>(featuresInfo.length);
for (LTRScoringQuery.FeatureInfo featInfo:featuresInfo){
if (featInfo.isUsed() || isDense){
hashmap.put(featInfo.getName(), featInfo.getValue());
}
}
}
return hashmap;
}
}
public static class CSVFeatureLogger extends FeatureLogger<String> {
StringBuilder sb = new StringBuilder(500);
char keyValueSep = ':';
char featureSep = ';';
public CSVFeatureLogger(FeatureFormat f) {
super(f);
}
public CSVFeatureLogger setKeyValueSep(char keyValueSep) {
this.keyValueSep = keyValueSep;
return this;
}
public CSVFeatureLogger setFeatureSep(char featureSep) {
this.featureSep = featureSep;
return this;
}
@Override
public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
for (LTRScoringQuery.FeatureInfo featInfo:featuresInfo) {
if (featInfo.isUsed() || isDense){
sb.append(featInfo.getName())
.append(keyValueSep)
.append(featInfo.getValue())
.append(featureSep);
}
}
final String features = (sb.length() > 0 ? sb.substring(0,
sb.length() - 1) : "");
sb.setLength(0);
return features;
}
}
}
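The CSVFeatureLogger output format above can be sketched in Python: name:value pairs joined by ';', skipping unused features in sparse mode and keeping them (with their default values) in dense mode. This is an illustrative re-implementation, not the shipped code:

```python
# Sketch of the CSVFeatureLogger string format: each triggered feature is
# emitted as name:value, with pairs separated by ';'. In dense mode, unused
# features are emitted too (carrying their default values).
def make_feature_vector(features_info, dense=False):
    # features_info: list of (name, value, used) triples
    parts = []
    for name, value, used in features_info:
        if used or dense:
            parts.append("%s:%s" % (name, value))
    return ";".join(parts)
```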


@ -0,0 +1,249 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.solr.search.SolrIndexSearcher;
/**
* Implements the rescoring logic. The top documents returned by Solr, with
* their original scores, are processed by a {@link LTRScoringQuery} that
* assigns a new score to each document. The top documents are then re-sorted
* based on the new score.
* */
public class LTRRescorer extends Rescorer {
LTRScoringQuery scoringQuery;
public LTRRescorer(LTRScoringQuery scoringQuery) {
this.scoringQuery = scoringQuery;
}
private void heapAdjust(ScoreDoc[] hits, int size, int root) {
final ScoreDoc doc = hits[root];
final float score = doc.score;
int i = root;
while (i <= ((size >> 1) - 1)) {
final int lchild = (i << 1) + 1;
final ScoreDoc ldoc = hits[lchild];
final float lscore = ldoc.score;
float rscore = Float.MAX_VALUE;
final int rchild = (i << 1) + 2;
ScoreDoc rdoc = null;
if (rchild < size) {
rdoc = hits[rchild];
rscore = rdoc.score;
}
if (lscore < score) {
if (rscore < lscore) {
hits[i] = rdoc;
hits[rchild] = doc;
i = rchild;
} else {
hits[i] = ldoc;
hits[lchild] = doc;
i = lchild;
}
} else if (rscore < score) {
hits[i] = rdoc;
hits[rchild] = doc;
i = rchild;
} else {
return;
}
}
}
private void heapify(ScoreDoc[] hits, int size) {
for (int i = (size >> 1) - 1; i >= 0; i--) {
heapAdjust(hits, size, i);
}
}
/**
* Rescores the documents.
*
* @param searcher
* current IndexSearcher
* @param firstPassTopDocs
* documents to rerank
* @param topN
* number of documents to return
*/
@Override
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
int topN) throws IOException {
if ((topN == 0) || (firstPassTopDocs.totalHits == 0)) {
return firstPassTopDocs;
}
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
Arrays.sort(hits, new Comparator<ScoreDoc>() {
@Override
public int compare(ScoreDoc a, ScoreDoc b) {
return a.doc - b.doc;
}
});
topN = Math.min(topN, firstPassTopDocs.totalHits);
final ScoreDoc[] reranked = new ScoreDoc[topN];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
.createNormalizedWeight(scoringQuery, true);
final SolrIndexSearcher solrIndexSearch = (SolrIndexSearcher) searcher;
scoreFeatures(solrIndexSearch, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked);
// Must sort all documents that we reranked, and then select the top
Arrays.sort(reranked, new Comparator<ScoreDoc>() {
@Override
public int compare(ScoreDoc a, ScoreDoc b) {
// Sort by score descending, then docID ascending:
if (a.score > b.score) {
return -1;
} else if (a.score < b.score) {
return 1;
} else {
// This subtraction can't overflow int
// because docIDs are >= 0:
return a.doc - b.doc;
}
}
});
return new TopDocs(firstPassTopDocs.totalHits, reranked, reranked[0].score);
}
public void scoreFeatures(SolrIndexSearcher solrIndexSearch, TopDocs firstPassTopDocs,
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
ScoreDoc[] reranked) throws IOException {
int readerUpto = -1;
int endDoc = 0;
int docBase = 0;
LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
int hitUpto = 0;
final FeatureLogger<?> featureLogger = scoringQuery.getFeatureLogger();
while (hitUpto < hits.length) {
final ScoreDoc hit = hits[hitUpto];
final int docID = hit.doc;
LeafReaderContext readerContext = null;
while (docID >= endDoc) {
readerUpto++;
readerContext = leaves.get(readerUpto);
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
}
// We advanced to another segment
if (readerContext != null) {
docBase = readerContext.docBase;
scorer = modelWeight.scorer(readerContext);
}
// The scorer for a LTRScoringQuery.ModelWeight should never be null: score
// must always be called, even if no feature scorers match, since a model
// might use that info to return a non-zero score. The same applies when
// advancing a LTRScoringQuery.ModelWeight.ModelScorer past the target doc,
// since the model algorithm still needs to compute a potentially non-zero
// score from blank features.
assert (scorer != null);
final int targetDoc = docID - docBase;
scorer.docID();
scorer.iterator().advance(targetDoc);
scorer.getDocInfo().setOriginalDocScore(new Float(hit.score));
hit.score = scorer.score();
if (hitUpto < topN) {
reranked[hitUpto] = hit;
// the heap is not full yet, so log the features for this document
if (featureLogger != null) {
featureLogger.log(hit.doc, scoringQuery, solrIndexSearch,
modelWeight.getFeaturesInfo());
}
} else if (hitUpto == topN) {
// collected topN documents, so build the heap
heapify(reranked, topN);
}
if (hitUpto >= topN) {
// Once the heap is ready, a document scoring lower than the current
// minimum is not logged. Otherwise it replaces the minimum and the
// heap is fixed up.
if (hit.score > reranked[0].score) {
reranked[0] = hit;
heapAdjust(reranked, topN, 0);
if (featureLogger != null) {
featureLogger.log(hit.doc, scoringQuery, solrIndexSearch,
modelWeight.getFeaturesInfo());
}
}
}
hitUpto++;
}
}
@Override
public Explanation explain(IndexSearcher searcher,
Explanation firstPassExplanation, int docID) throws IOException {
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
.leaves();
final int n = ReaderUtil.subIndex(docID, leafContexts);
final LeafReaderContext context = leafContexts.get(n);
final int deBasedDoc = docID - context.docBase;
final Weight modelWeight = searcher.createNormalizedWeight(scoringQuery,
true);
return modelWeight.explain(context, deBasedDoc);
}
public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight,
int docid,
Float originalDocScore,
List<LeafReaderContext> leafContexts)
throws IOException {
final int n = ReaderUtil.subIndex(docid, leafContexts);
final LeafReaderContext atomicContext = leafContexts.get(n);
final int deBasedDoc = docid - atomicContext.docBase;
final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.scorer(atomicContext);
if ( (r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc) ) {
return new LTRScoringQuery.FeatureInfo[0];
} else {
if (originalDocScore != null) {
// If results have not been reranked, the score passed in is the original query's
// score, which some features can use instead of recalculating it
r.getDocInfo().setOriginalDocScore(originalDocScore);
}
r.score();
return modelWeight.getFeaturesInfo();
}
}
}
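The scoreFeatures loop above maintains a size-topN min-heap so the lowest rescored hit can be evicted in O(log N). The same selection can be sketched with Python's heapq; this illustrates the technique rather than mirroring the Java code line for line:

```python
# Keep the N highest-scoring hits while streaming over rescored documents,
# mirroring the heapify/heapAdjust logic in LTRRescorer.
import heapq

def top_n(scored_hits, n):
    # scored_hits: iterable of (score, doc_id) pairs
    heap = []  # min-heap on score: heap[0] is the current minimum
    for score, doc in scored_hits:
        if len(heap) < n:
            heapq.heappush(heap, (score, doc))
        elif score > heap[0][0]:
            # better than the current minimum: replace it and re-heapify
            heapq.heapreplace(heap, (score, doc))
    # final order: score descending, then doc id ascending (as in rescore())
    return sorted(heap, key=lambda sd: (-sd[0], sd[1]))
```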


@ -0,0 +1,738 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.Semaphore;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.request.SolrQueryRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The ranking query that is run, reranking results using the
* LTRScoringModel algorithm
*/
public class LTRScoringQuery extends Query {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
// contains a description of the model
final private LTRScoringModel ltrScoringModel;
final private boolean extractAllFeatures;
final private LTRThreadModule ltrThreadMgr;
final private Semaphore querySemaphore; // limits the number of threads per query, so that multiple requests can be serviced simultaneously
// feature logger to output the features.
private FeatureLogger<?> fl;
// Map of external parameters, such as query intent, that can be used by
// features
final private Map<String,String[]> efi;
// Original solr query used to fetch matching documents
private Query originalQuery;
// Original solr request
private SolrQueryRequest request;
public LTRScoringQuery(LTRScoringModel ltrScoringModel) {
this(ltrScoringModel, Collections.<String,String[]>emptyMap(), false, null);
}
public LTRScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
this(ltrScoringModel, Collections.<String, String[]>emptyMap(), extractAllFeatures, null);
}
public LTRScoringQuery(LTRScoringModel ltrScoringModel,
Map<String, String[]> externalFeatureInfo,
boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
this.ltrScoringModel = ltrScoringModel;
this.efi = externalFeatureInfo;
this.extractAllFeatures = extractAllFeatures;
this.ltrThreadMgr = ltrThreadMgr;
if (this.ltrThreadMgr != null) {
this.querySemaphore = this.ltrThreadMgr.createQuerySemaphore();
} else{
this.querySemaphore = null;
}
}
public LTRScoringModel getScoringModel() {
return ltrScoringModel;
}
public void setFeatureLogger(FeatureLogger fl) {
this.fl = fl;
}
public FeatureLogger getFeatureLogger() {
return fl;
}
public void setOriginalQuery(Query originalQuery) {
this.originalQuery = originalQuery;
}
public Query getOriginalQuery() {
return originalQuery;
}
public Map<String,String[]> getExternalFeatureInfo() {
return efi;
}
public void setRequest(SolrQueryRequest request) {
this.request = request;
}
public SolrQueryRequest getRequest() {
return request;
}
@Override
public int hashCode() {
final int prime = 31;
int result = classHash();
result = (prime * result) + ((ltrScoringModel == null) ? 0 : ltrScoringModel.hashCode());
result = (prime * result)
+ ((originalQuery == null) ? 0 : originalQuery.hashCode());
if (efi == null) {
result = (prime * result) + 0;
}
else {
for (final Map.Entry<String,String[]> entry : efi.entrySet()) {
final String key = entry.getKey();
final String[] values = entry.getValue();
result = (prime * result) + key.hashCode();
result = (prime * result) + Arrays.hashCode(values);
}
}
result = (prime * result) + this.toString().hashCode();
return result;
}
@Override
public boolean equals(Object o) {
return sameClassAs(o) && equalsTo(getClass().cast(o));
}
private boolean equalsTo(LTRScoringQuery other) {
if (ltrScoringModel == null) {
if (other.ltrScoringModel != null) {
return false;
}
} else if (!ltrScoringModel.equals(other.ltrScoringModel)) {
return false;
}
if (originalQuery == null) {
if (other.originalQuery != null) {
return false;
}
} else if (!originalQuery.equals(other.originalQuery)) {
return false;
}
if (efi == null) {
if (other.efi != null) {
return false;
}
} else {
if (other.efi == null || efi.size() != other.efi.size()) {
return false;
}
for(final Map.Entry<String,String[]> entry : efi.entrySet()) {
final String key = entry.getKey();
final String[] otherValues = other.efi.get(key);
if (otherValues == null || !Arrays.equals(otherValues,entry.getValue())) {
return false;
}
}
}
return true;
}
@Override
public ModelWeight createWeight(IndexSearcher searcher, boolean needsScores, float boost)
throws IOException {
final Collection<Feature> modelFeatures = ltrScoringModel.getFeatures();
final Collection<Feature> allFeatures = ltrScoringModel.getAllFeatures();
int modelFeatSize = modelFeatures.size();
Collection<Feature> features = null;
if (this.extractAllFeatures) {
features = allFeatures;
}
else{
features = modelFeatures;
}
final Feature.FeatureWeight[] extractedFeatureWeights = new Feature.FeatureWeight[features.size()];
final Feature.FeatureWeight[] modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize];
List<Feature.FeatureWeight > featureWeights = new ArrayList<>(features.size());
if (querySemaphore == null) {
createWeights(searcher, needsScores, boost, featureWeights, features);
}
else{
createWeightsParallel(searcher, needsScores, boost, featureWeights, features);
}
int i=0, j = 0;
if (this.extractAllFeatures) {
for (final Feature.FeatureWeight fw : featureWeights) {
extractedFeatureWeights[i++] = fw;
}
for (final Feature f : modelFeatures){
modelFeaturesWeights[j++] = extractedFeatureWeights[f.getIndex()]; // we can lookup by featureid because all features will be extracted when this.extractAllFeatures is set
}
}
else{
for (final Feature.FeatureWeight fw: featureWeights){
extractedFeatureWeights[i++] = fw;
modelFeaturesWeights[j++] = fw;
}
}
return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size());
}
private void createWeights(IndexSearcher searcher, boolean needsScores, float boost,
List<Feature.FeatureWeight > featureWeights, Collection<Feature> features) throws IOException {
final SolrQueryRequest req = getRequest();
// since the feature store is a LinkedHashMap, order is preserved
for (final Feature f : features) {
try{
Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi);
featureWeights.add(fw);
} catch (final Exception e) {
throw new RuntimeException("Exception from createWeight for " + f.toString() + " "
+ e.getMessage(), e);
}
}
}
private class CreateWeightCallable implements Callable<Feature.FeatureWeight>{
final private Feature f;
final private IndexSearcher searcher;
final private boolean needsScores;
final private SolrQueryRequest req;
public CreateWeightCallable(Feature f, IndexSearcher searcher, boolean needsScores, SolrQueryRequest req){
this.f = f;
this.searcher = searcher;
this.needsScores = needsScores;
this.req = req;
}
@Override
public Feature.FeatureWeight call() throws Exception{
try {
Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi);
return fw;
} catch (final Exception e) {
throw new RuntimeException("Exception from createWeight for " + f.toString() + " "
+ e.getMessage(), e);
} finally {
querySemaphore.release();
ltrThreadMgr.releaseLTRSemaphore();
}
}
} // end of call CreateWeightCallable
private void createWeightsParallel(IndexSearcher searcher, boolean needsScores, float boost,
List<Feature.FeatureWeight > featureWeights, Collection<Feature> features) throws RuntimeException {
final SolrQueryRequest req = getRequest();
List<Future<Feature.FeatureWeight> > futures = new ArrayList<>(features.size());
try{
for (final Feature f : features) {
CreateWeightCallable callable = new CreateWeightCallable(f, searcher, needsScores, req);
RunnableFuture<Feature.FeatureWeight> runnableFuture = new FutureTask<>(callable);
querySemaphore.acquire(); // always acquire before the ltrSemaphore is acquired, to guarantee that the current query stays within the limit for max threads
ltrThreadMgr.acquireLTRSemaphore();//may block and/or interrupt
ltrThreadMgr.execute(runnableFuture);//releases semaphore when done
futures.add(runnableFuture);
}
//Loop over futures to get the feature weight objects
for (final Future<Feature.FeatureWeight> future : futures) {
featureWeights.add(future.get()); // future.get() will block if the job is still running
}
} catch (Exception e) { // To catch InterruptedException and ExecutionException
log.info("Error while creating weights in LTR", e);
throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e);
}
}
@Override
public String toString(String field) {
return field;
}
public class FeatureInfo {
final private String name;
private float value;
private boolean used;
FeatureInfo(String n, float v, boolean u){
name = n; value = v; used = u;
}
public void setValue(float value){
this.value = value;
}
public String getName(){
return name;
}
public float getValue(){
return value;
}
public boolean isUsed(){
return used;
}
public void setUsed(boolean used){
this.used = used;
}
}
public class ModelWeight extends Weight {
// List of the model's features used for scoring. This is a subset of the
// features used for logging.
final private Feature.FeatureWeight[] modelFeatureWeights;
final private float[] modelFeatureValuesNormalized;
final private Feature.FeatureWeight[] extractedFeatureWeights;
// List of all the feature names, values - used for both scoring and logging
/*
* Why use a map here instead of an array of objects?
* A set of arrays was used earlier, and the elements were accessed using
* the featureId. With the updated logic to create weights selectively,
* the number of elements in the array can be fewer than the total number
* of features: when [features] are not requested, only the model features
* are extracted, and indexing by featureId fails. For this reason we need
* a map which holds just the features that were triggered by the documents
* in the result set.
*
*/
final private FeatureInfo[] featuresInfo;
/*
* @param modelFeatureWeights
* - should be the same size as the number of features used by the model
* @param extractedFeatureWeights
* - if features are requested from the same store as model feature store,
* this will be the size of total number of features in the model feature store
* else, this will be the size of the modelFeatureWeights
* @param allFeaturesSize
* - total number of feature in the feature store used by this model
*/
public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights,
Feature.FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) {
super(LTRScoringQuery.this);
this.extractedFeatureWeights = extractedFeatureWeights;
this.modelFeatureWeights = modelFeatureWeights;
this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length];
this.featuresInfo = new FeatureInfo[allFeaturesSize];
setFeaturesInfo();
}
private void setFeaturesInfo(){
for (int i = 0; i < extractedFeatureWeights.length;++i){
String featName = extractedFeatureWeights[i].getName();
int featId = extractedFeatureWeights[i].getIndex();
float value = extractedFeatureWeights[i].getDefaultValue();
featuresInfo[featId] = new FeatureInfo(featName,value,false);
}
}
public FeatureInfo[] getFeaturesInfo(){
return featuresInfo;
}
// for test use
Feature.FeatureWeight[] getModelFeatureWeights() {
return modelFeatureWeights;
}
// for test use
float[] getModelFeatureValuesNormalized() {
return modelFeatureValuesNormalized;
}
// for test use
Feature.FeatureWeight[] getExtractedFeatureWeights() {
return extractedFeatureWeights;
}
/**
* Goes through all the stored feature values, and calculates the normalized
* values for all the features that will be used for scoring.
*/
private void makeNormalizedFeatures() {
int pos = 0;
for (final Feature.FeatureWeight feature : modelFeatureWeights) {
final int featureId = feature.getIndex();
FeatureInfo fInfo = featuresInfo[featureId];
if (fInfo.isUsed()) { // not checking for finfo == null as that would be a bug we should catch
modelFeatureValuesNormalized[pos] = fInfo.getValue();
} else {
modelFeatureValuesNormalized[pos] = feature.getDefaultValue();
}
pos++;
}
ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
}
@Override
public Explanation explain(LeafReaderContext context, int doc)
throws IOException {
final Explanation[] explanations = new Explanation[this.featuresInfo.length];
for (final Feature.FeatureWeight feature : extractedFeatureWeights) {
explanations[feature.getIndex()] = feature.explain(context, doc);
}
final List<Explanation> featureExplanations = new ArrayList<>();
for (int idx = 0 ;idx < modelFeatureWeights.length; ++idx) {
final Feature.FeatureWeight f = modelFeatureWeights[idx];
Explanation e = ltrScoringModel.getNormalizerExplanation(explanations[f.getIndex()], idx);
featureExplanations.add(e);
}
final ModelScorer bs = scorer(context);
bs.iterator().advance(doc);
final float finalScore = bs.score();
return ltrScoringModel.explain(context, doc, finalScore, featureExplanations);
}
@Override
public void extractTerms(Set<Term> terms) {
for (final Feature.FeatureWeight feature : extractedFeatureWeights) {
feature.extractTerms(terms);
}
}
protected void reset() {
for (int i = 0; i < extractedFeatureWeights.length; ++i) {
final int featId = extractedFeatureWeights[i].getIndex();
final float value = extractedFeatureWeights[i].getDefaultValue();
featuresInfo[featId].setValue(value); // the default value must be set every time, since it is used in 'dense' mode even when used=false
featuresInfo[featId].setUsed(false);
}
}
@Override
public ModelScorer scorer(LeafReaderContext context) throws IOException {
final List<Feature.FeatureWeight.FeatureScorer> featureScorers = new ArrayList<Feature.FeatureWeight.FeatureScorer>(
extractedFeatureWeights.length);
for (final Feature.FeatureWeight featureWeight : extractedFeatureWeights) {
final Feature.FeatureWeight.FeatureScorer scorer = featureWeight.scorer(context);
if (scorer != null) {
featureScorers.add(scorer);
}
}
// Always return a ModelScorer, even if no features match, because we
// always need to call score on the model for every document: zero
// matching features can still produce a non-zero score for a given model.
ModelScorer mscorer = new ModelScorer(this, featureScorers);
return mscorer;
}
public class ModelScorer extends Scorer {
final private DocInfo docInfo;
final private Scorer featureTraversalScorer;
public DocInfo getDocInfo() {
return docInfo;
}
public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
docInfo = new DocInfo();
for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) {
subScorer.setDocInfo(docInfo);
}
if (featureScorers.size() <= 1) { // TODO: Allow the use of dense
// features in other cases
featureTraversalScorer = new DenseModelScorer(weight, featureScorers);
} else {
featureTraversalScorer = new SparseModelScorer(weight, featureScorers);
}
}
@Override
public Collection<ChildScorer> getChildren() {
return featureTraversalScorer.getChildren();
}
@Override
public int docID() {
return featureTraversalScorer.docID();
}
@Override
public float score() throws IOException {
return featureTraversalScorer.score();
}
@Override
public int freq() throws IOException {
return featureTraversalScorer.freq();
}
@Override
public DocIdSetIterator iterator() {
return featureTraversalScorer.iterator();
}
private class SparseModelScorer extends Scorer {
final private DisiPriorityQueue subScorers;
final private ScoringQuerySparseIterator itr;
private int targetDoc = -1;
private int activeDoc = -1;
private SparseModelScorer(Weight weight,
List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
if (featureScorers.size() <= 1) {
throw new IllegalArgumentException(
"There must be at least 2 subScorers");
}
subScorers = new DisiPriorityQueue(featureScorers.size());
for (final Scorer scorer : featureScorers) {
final DisiWrapper w = new DisiWrapper(scorer);
subScorers.add(w);
}
itr = new ScoringQuerySparseIterator(subScorers);
}
@Override
public int docID() {
return itr.docID();
}
@Override
public float score() throws IOException {
final DisiWrapper topList = subScorers.topList();
// If the target doc we wanted to advance to matches the doc the
// underlying features actually advanced to, perform the feature
// calculations; otherwise continue the model's scoring with empty features.
reset();
if (activeDoc == targetDoc) {
for (DisiWrapper w = topList; w != null; w = w.next) {
final Scorer subScorer = w.scorer;
Feature.FeatureWeight scFW = (Feature.FeatureWeight) subScorer.getWeight();
final int featureId = scFW.getIndex();
featuresInfo[featureId].setValue(subScorer.score());
featuresInfo[featureId].setUsed(true);
}
}
makeNormalizedFeatures();
return ltrScoringModel.score(modelFeatureValuesNormalized);
}
@Override
public int freq() throws IOException {
final DisiWrapper subMatches = subScorers.topList();
int freq = 1;
for (DisiWrapper w = subMatches.next; w != null; w = w.next) {
freq += 1;
}
return freq;
}
@Override
public DocIdSetIterator iterator() {
return itr;
}
@Override
public final Collection<ChildScorer> getChildren() {
final ArrayList<ChildScorer> children = new ArrayList<>();
for (final DisiWrapper scorer : subScorers) {
children.add(new ChildScorer(scorer.scorer, "SHOULD"));
}
return children;
}
private class ScoringQuerySparseIterator extends DisjunctionDISIApproximation {
public ScoringQuerySparseIterator(DisiPriorityQueue subIterators) {
super(subIterators);
}
@Override
public final int nextDoc() throws IOException {
if (activeDoc == targetDoc) {
activeDoc = super.nextDoc();
} else if (activeDoc < targetDoc) {
activeDoc = super.advance(targetDoc + 1);
}
return ++targetDoc;
}
@Override
public final int advance(int target) throws IOException {
// Advance the underlying features only if they are behind the target;
// score() will perform the feature calculations when activeDoc == targetDoc,
// and fall back to the model's scoring with empty features otherwise.
if (activeDoc < target) {
activeDoc = super.advance(target);
}
targetDoc = target;
return targetDoc;
}
}
}
private class DenseModelScorer extends Scorer {
private int activeDoc = -1; // the doc our sub-scorers have actually advanced to
private int targetDoc = -1; // the doc we were most recently asked to advance to
private int freq = -1;
final private List<Feature.FeatureWeight.FeatureScorer> featureScorers;
private DenseModelScorer(Weight weight,
List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
this.featureScorers = featureScorers;
}
@Override
public int docID() {
return targetDoc;
}
@Override
public float score() throws IOException {
reset();
freq = 0;
if (targetDoc == activeDoc) {
for (final Scorer scorer : featureScorers) {
if (scorer.docID() == activeDoc) {
freq++;
Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight();
final int featureId = scFW.getIndex();
featuresInfo[featureId].setValue(scorer.score());
featuresInfo[featureId].setUsed(true);
}
}
}
makeNormalizedFeatures();
return ltrScoringModel.score(modelFeatureValuesNormalized);
}
@Override
public final Collection<ChildScorer> getChildren() {
final ArrayList<ChildScorer> children = new ArrayList<>();
for (final Scorer scorer : featureScorers) {
children.add(new ChildScorer(scorer, "SHOULD"));
}
return children;
}
@Override
public int freq() throws IOException {
return freq;
}
@Override
public DocIdSetIterator iterator() {
return new DenseIterator();
}
private class DenseIterator extends DocIdSetIterator {
@Override
public int docID() {
return targetDoc;
}
@Override
public int nextDoc() throws IOException {
if (activeDoc <= targetDoc) {
activeDoc = NO_MORE_DOCS;
for (final Scorer scorer : featureScorers) {
if (scorer.docID() != NO_MORE_DOCS) {
activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc());
}
}
}
return ++targetDoc;
}
@Override
public int advance(int target) throws IOException {
if (activeDoc < target) {
activeDoc = NO_MORE_DOCS;
for (final Scorer scorer : featureScorers) {
if (scorer.docID() != NO_MORE_DOCS) {
activeDoc = Math.min(activeDoc,
scorer.iterator().advance(target));
}
}
}
targetDoc = target;
return target;
}
@Override
public long cost() {
long sum = 0;
for (int i = 0; i < featureScorers.size(); i++) {
sum += featureScorers.get(i).iterator().cost();
}
return sum;
}
}
}
}
}
}
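The DenseIterator above unions its feature scorers by advancing each one past the target and taking the minimum of the docs they land on. A minimal standalone sketch of that advance logic (class and method names are hypothetical, using sorted int arrays in place of Lucene's DocIdSetIterator):

```java
import java.util.Arrays;
import java.util.List;

// Sketch of the DenseIterator idea: every sub-iterator is advanced past
// the target, and the union's active doc is the minimum doc reached.
public class DenseUnionSketch {
  static final int NO_MORE_DOCS = Integer.MAX_VALUE;

  private final List<int[]> subs;    // each int[] is a sorted list of matching doc ids
  private final int[] positions;     // cursor per sub-iterator, -1 = not started

  public DenseUnionSketch(List<int[]> subs) {
    this.subs = subs;
    this.positions = new int[subs.size()];
    Arrays.fill(positions, -1);
  }

  private int docOf(int i) {
    final int pos = positions[i];
    final int[] docs = subs.get(i);
    return (pos >= 0 && pos < docs.length) ? docs[pos] : (pos < 0 ? -1 : NO_MORE_DOCS);
  }

  /** Advance every sub-iterator to at least 'target'; return the min doc reached. */
  public int advance(int target) {
    int active = NO_MORE_DOCS;
    for (int i = 0; i < subs.size(); i++) {
      while (docOf(i) < target) {
        positions[i]++;
      }
      active = Math.min(active, docOf(i));
    }
    return active;
  }
}
```

As in DenseIterator, the caller tracks its own targetDoc separately, so the union can sit at a doc none of the features match while the model still gets scored there.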


@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.util.DefaultSolrThreadFactory;
import org.apache.solr.util.SolrPluginUtils;
import org.apache.solr.util.plugin.NamedListInitializedPlugin;
final public class LTRThreadModule implements NamedListInitializedPlugin {
public static LTRThreadModule getInstance(NamedList args) {
final LTRThreadModule threadManager;
final NamedList threadManagerArgs = extractThreadModuleParams(args);
// if and only if there are thread module args then we want a thread module!
if (threadManagerArgs.size() > 0) {
// create and initialize the new instance
threadManager = new LTRThreadModule();
threadManager.init(threadManagerArgs);
} else {
threadManager = null;
}
return threadManager;
}
private static final String CONFIG_PREFIX = "threadModule.";
private static NamedList extractThreadModuleParams(NamedList args) {
// gather the thread module args from amongst the general args
final NamedList extractedArgs = new NamedList();
for (Iterator<Map.Entry<String,Object>> it = args.iterator();
it.hasNext(); ) {
final Map.Entry<String,Object> entry = it.next();
final String key = entry.getKey();
if (key.startsWith(CONFIG_PREFIX)) {
extractedArgs.add(key.substring(CONFIG_PREFIX.length()), entry.getValue());
}
}
// remove consumed keys only once iteration is complete
// since NamedList iterator does not support 'remove'
for (Object key : extractedArgs.asShallowMap().keySet()) {
args.remove(CONFIG_PREFIX+key);
}
return extractedArgs;
}
// settings
private int totalPoolThreads = 1;
private int numThreadsPerRequest = 1;
private int maxPoolSize = Integer.MAX_VALUE;
private long keepAliveTimeSeconds = 10;
private String threadNamePrefix = "ltrExecutor";
// implementation
private Semaphore ltrSemaphore;
private Executor createWeightScoreExecutor;
public LTRThreadModule() {
}
// For test use only.
LTRThreadModule(int totalPoolThreads, int numThreadsPerRequest) {
this.totalPoolThreads = totalPoolThreads;
this.numThreadsPerRequest = numThreadsPerRequest;
init(null);
}
@Override
public void init(NamedList args) {
if (args != null) {
SolrPluginUtils.invokeSetters(this, args);
}
validate();
if (this.totalPoolThreads > 1) {
ltrSemaphore = new Semaphore(totalPoolThreads);
} else {
ltrSemaphore = null;
}
createWeightScoreExecutor = new ExecutorUtil.MDCAwareThreadPoolExecutor(
0,
maxPoolSize,
keepAliveTimeSeconds, TimeUnit.SECONDS, // terminate idle threads after the configured keep-alive time
new SynchronousQueue<Runnable>(), // directly hand off tasks
new DefaultSolrThreadFactory(threadNamePrefix)
);
}
private void validate() {
if (totalPoolThreads <= 0){
throw new IllegalArgumentException("totalPoolThreads cannot be less than 1");
}
if (numThreadsPerRequest <= 0){
throw new IllegalArgumentException("numThreadsPerRequest cannot be less than 1");
}
if (totalPoolThreads < numThreadsPerRequest){
throw new IllegalArgumentException("numThreadsPerRequest cannot be greater than totalPoolThreads");
}
}
public void setTotalPoolThreads(int totalPoolThreads) {
this.totalPoolThreads = totalPoolThreads;
}
public void setNumThreadsPerRequest(int numThreadsPerRequest) {
this.numThreadsPerRequest = numThreadsPerRequest;
}
public void setMaxPoolSize(int maxPoolSize) {
this.maxPoolSize = maxPoolSize;
}
public void setKeepAliveTimeSeconds(long keepAliveTimeSeconds) {
this.keepAliveTimeSeconds = keepAliveTimeSeconds;
}
public void setThreadNamePrefix(String threadNamePrefix) {
this.threadNamePrefix = threadNamePrefix;
}
public Semaphore createQuerySemaphore() {
return (numThreadsPerRequest > 1 ? new Semaphore(numThreadsPerRequest) : null);
}
public void acquireLTRSemaphore() throws InterruptedException {
ltrSemaphore.acquire();
}
public void releaseLTRSemaphore() {
ltrSemaphore.release();
}
public void execute(Runnable command) {
createWeightScoreExecutor.execute(command);
}
}
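extractThreadModuleParams above pulls the "threadModule."-prefixed entries out of the general args, stripping the prefix, and only removes the consumed keys after iteration completes (NamedList's iterator does not support remove). A standalone sketch of that pattern using a plain Map (class and method names are hypothetical):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch of the prefix-extraction pattern used by extractThreadModuleParams:
// keys starting with the prefix are copied out with the prefix stripped,
// then removed from the general args after iteration is complete.
public class PrefixParams {
  public static Map<String, Object> extract(Map<String, Object> args, String prefix) {
    final Map<String, Object> extracted = new LinkedHashMap<>();
    for (Map.Entry<String, Object> e : args.entrySet()) {
      if (e.getKey().startsWith(prefix)) {
        extracted.put(e.getKey().substring(prefix.length()), e.getValue());
      }
    }
    // remove consumed keys only once iteration is complete, as in the original
    for (String key : extracted.keySet()) {
      args.remove(prefix + key);
    }
    return extracted;
  }
}
```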


@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import org.apache.solr.request.SolrQueryRequest;
public class SolrQueryRequestContextUtils {
/** key prefix to reduce possibility of clash with other code's key choices **/
private static final String LTR_PREFIX = "ltr.";
/** key of the feature logger in the request context **/
private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger";
/** key of the scoring query in the request context **/
private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query";
/** key of the isExtractingFeatures flag in the request context **/
private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures";
/** key of the feature vector store name in the request context **/
private static final String STORE = LTR_PREFIX + "store";
/** feature logger accessors **/
public static void setFeatureLogger(SolrQueryRequest req, FeatureLogger<?> featureLogger) {
req.getContext().put(FEATURE_LOGGER, featureLogger);
}
public static FeatureLogger<?> getFeatureLogger(SolrQueryRequest req) {
return (FeatureLogger<?>) req.getContext().get(FEATURE_LOGGER);
}
/** scoring query accessors **/
public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) {
req.getContext().put(SCORING_QUERY, scoringQuery);
}
public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) {
return (LTRScoringQuery) req.getContext().get(SCORING_QUERY);
}
/** isExtractingFeatures flag accessors **/
public static void setIsExtractingFeatures(SolrQueryRequest req) {
req.getContext().put(IS_EXTRACTING_FEATURES, Boolean.TRUE);
}
public static void clearIsExtractingFeatures(SolrQueryRequest req) {
req.getContext().put(IS_EXTRACTING_FEATURES, Boolean.FALSE);
}
public static boolean isExtractingFeatures(SolrQueryRequest req) {
return Boolean.TRUE.equals(req.getContext().get(IS_EXTRACTING_FEATURES));
}
/** feature vector store name accessors **/
public static void setFvStoreName(SolrQueryRequest req, String fvStoreName) {
req.getContext().put(STORE, fvStoreName);
}
public static String getFvStoreName(SolrQueryRequest req) {
return (String) req.getContext().get(STORE);
}
}
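The isExtractingFeatures accessors rely on Boolean.TRUE.equals(...) so that an absent key and an explicit FALSE both read as false, with no null check needed. A small self-contained sketch of that idiom over a plain context map (names hypothetical):

```java
import java.util.Map;

// Sketch of the null-safe flag idiom used by isExtractingFeatures():
// Boolean.TRUE.equals(null) is false, so a missing entry and an explicit
// FALSE entry behave identically and no NullPointerException is possible.
public class ContextFlags {
  private static final String IS_EXTRACTING = "ltr.isExtractingFeatures";

  public static boolean isExtracting(Map<String, Object> context) {
    return Boolean.TRUE.equals(context.get(IS_EXTRACTING));
  }

  public static void setExtracting(Map<String, Object> context, boolean value) {
    context.put(IS_EXTRACTING, value ? Boolean.TRUE : Boolean.FALSE);
  }
}
```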


@ -0,0 +1,335 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.DocInfo;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.macro.MacroExpander;
import org.apache.solr.util.SolrPluginUtils;
/**
* A recipe for computing a feature. Subclass this for specialized feature calculations.
* <p>
* A feature consists of
* <ul>
* <li> a name as the identifier
* <li> parameters to represent the specific feature
* </ul>
* <p>
* Example configuration (snippet):
* <pre>{
"class" : "...",
"name" : "myFeature",
"params" : {
...
}
}</pre>
* <p>
* {@link Feature} is an abstract class and concrete classes should implement
* the {@link #validate()} function, and must implement the {@link #paramsToMap()}
* and createWeight() methods.
*/
public abstract class Feature extends Query {
final protected String name;
private int index = -1;
private float defaultValue = 0.0f;
final private Map<String,Object> params;
public static Feature getInstance(SolrResourceLoader solrResourceLoader,
String className, String name, Map<String,Object> params) {
final Feature f = solrResourceLoader.newInstance(
className,
Feature.class,
new String[0], // no sub packages
new Class[] { String.class, Map.class },
new Object[] { name, params });
if (params != null) {
SolrPluginUtils.invokeSetters(f, params.entrySet());
}
f.validate();
return f;
}
public Feature(String name, Map<String,Object> params) {
this.name = name;
this.params = params;
}
/**
* As part of creation of a feature instance, this function confirms
* that the feature parameters are valid.
*
* @throws FeatureException
* Feature Exception
*/
protected abstract void validate() throws FeatureException;
@Override
public String toString(String field) {
final StringBuilder sb = new StringBuilder(64); // default initialCapacity of 16 won't be enough
sb.append(getClass().getSimpleName());
sb.append(" [name=").append(name);
final LinkedHashMap<String,Object> params = paramsToMap();
if (params != null) {
sb.append(", params=").append(params);
}
sb.append(']');
return sb.toString();
}
public abstract FeatureWeight createWeight(IndexSearcher searcher,
boolean needsScores, SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException;
public float getDefaultValue() {
return defaultValue;
}
public void setDefaultValue(String value){
defaultValue = Float.parseFloat(value);
}
@Override
public int hashCode() {
final int prime = 31;
int result = classHash();
result = (prime * result) + index;
result = (prime * result) + ((name == null) ? 0 : name.hashCode());
result = (prime * result) + ((params == null) ? 0 : params.hashCode());
return result;
}
@Override
public boolean equals(Object o) {
return sameClassAs(o) && equalsTo(getClass().cast(o));
}
private boolean equalsTo(Feature other) {
if (index != other.index) {
return false;
}
if (name == null) {
if (other.name != null) {
return false;
}
} else if (!name.equals(other.name)) {
return false;
}
if (params == null) {
if (other.params != null) {
return false;
}
} else if (!params.equals(other.params)) {
return false;
}
return true;
}
/**
* @return the name
*/
public String getName() {
return name;
}
/**
* @return the id
*/
public int getIndex() {
return index;
}
/**
* @param index
* Unique ID for this feature. Similar to feature name, except it can
* be used to directly access the feature in the global list of
* features.
*/
public void setIndex(int index) {
this.index = index;
}
public abstract LinkedHashMap<String,Object> paramsToMap();
/**
* Weight for a feature
**/
public abstract class FeatureWeight extends Weight {
final protected IndexSearcher searcher;
final protected SolrQueryRequest request;
final protected Map<String,String[]> efi;
final protected MacroExpander macroExpander;
final protected Query originalQuery;
/**
* Initialize a feature without the normalizer from the feature file. This is
* called on initial construction since multiple models share the same
* features, but have different normalizers. A concrete model's feature is
* copied through featForNewModel().
*
* @param q
* Solr query associated with this FeatureWeight
* @param searcher
* Solr searcher available for features if they need them
*/
public FeatureWeight(Query q, IndexSearcher searcher,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
super(q);
this.searcher = searcher;
this.request = request;
this.originalQuery = originalQuery;
this.efi = efi;
macroExpander = new MacroExpander(efi,true);
}
public String getName() {
return Feature.this.getName();
}
public int getIndex() {
return Feature.this.getIndex();
}
public float getDefaultValue() {
return Feature.this.getDefaultValue();
}
@Override
public abstract FeatureScorer scorer(LeafReaderContext context)
throws IOException;
@Override
public Explanation explain(LeafReaderContext context, int doc)
throws IOException {
final FeatureScorer r = scorer(context);
float score = getDefaultValue();
if (r != null) {
r.iterator().advance(doc);
if (r.docID() == doc) {
score = r.score();
}
return Explanation.match(score, toString());
} else {
return Explanation.match(score, "The feature has no value");
}
}
/**
* Used in the FeatureWeight's explain. Each feature should implement this
* returning properties of the specific scorer useful for an explain. For
* example "MyCustomClassFeature [name=" + name + ", myVariable=" + myVariable +
* "]". If not provided, a default implementation will return basic feature
* properties, which might not include query time specific values.
*/
@Override
public String toString() {
return Feature.this.toString();
}
@Override
public void extractTerms(Set<Term> terms) {
// needs to be implemented by query subclasses
throw new UnsupportedOperationException();
}
/**
* A 'recipe' for computing a feature
*/
public abstract class FeatureScorer extends Scorer {
final protected String name;
private DocInfo docInfo;
final protected DocIdSetIterator itr;
public FeatureScorer(Feature.FeatureWeight weight,
DocIdSetIterator itr) {
super(weight);
this.itr = itr;
name = weight.getName();
docInfo = null;
}
@Override
public abstract float score() throws IOException;
/**
* Used to provide context from initial score steps to later reranking steps.
*/
public void setDocInfo(DocInfo docInfo) {
this.docInfo = docInfo;
}
public DocInfo getDocInfo() {
return docInfo;
}
@Override
public int freq() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return itr.docID();
}
@Override
public DocIdSetIterator iterator() {
return itr;
}
}
/**
* Default FeatureScorer class that returns the score passed in. Can be used
* as a simple ValueFeature, or to return a default scorer in case an
* underlying feature's scorer is null.
*/
public class ValueFeatureScorer extends FeatureScorer {
float constScore;
public ValueFeatureScorer(FeatureWeight weight, float constScore,
DocIdSetIterator itr) {
super(weight,itr);
this.constScore = constScore;
}
@Override
public float score() {
return constScore;
}
}
}
}
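The Feature javadoc above describes the contract a concrete subclass must honor: constructor takes a name and a params map, validate() rejects bad params, and paramsToMap() round-trips them for serialization. A minimal self-contained sketch of that contract (this is a hypothetical illustration, not a real Solr class, and it collapses validation into the constructor):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch of the concrete-Feature contract: params are validated up front
// and paramsToMap() reproduces them for serialization.
public class ConstantFeatureSketch {
  private final String name;
  private final float value;

  public ConstantFeatureSketch(String name, Map<String, Object> params) {
    this.name = name;
    final Object v = (params == null) ? null : params.get("value");
    if (!(v instanceof Number)) {
      throw new IllegalArgumentException(name + ": 'value' param must be provided");
    }
    this.value = ((Number) v).floatValue();
  }

  public LinkedHashMap<String, Object> paramsToMap() {
    final LinkedHashMap<String, Object> params = new LinkedHashMap<>(1, 1.0f);
    params.put("value", value);
    return params;
  }

  public float score() {
    return value;
  }
}
```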


@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
public class FeatureException extends RuntimeException {
private static final long serialVersionUID = 1L;
public FeatureException(String message) {
super(message);
}
public FeatureException(String message, Exception cause) {
super(message, cause);
}
}


@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SmallFloat;
import org.apache.solr.request.SolrQueryRequest;
/**
* This feature returns the length of a field (in terms) for the current document.
* Example configuration:
* <pre>{
"name": "titleLength",
"class": "org.apache.solr.ltr.feature.FieldLengthFeature",
"params": {
"field": "title"
}
}</pre>
* Note: since this feature relies on norm values that are stored in a single byte,
* the feature value may differ slightly from the actual field length.
* (see also {@link org.apache.lucene.search.similarities.ClassicSimilarity})
**/
public class FieldLengthFeature extends Feature {
private String field;
public String getField() {
return field;
}
public void setField(String field) {
this.field = field;
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
final LinkedHashMap<String,Object> params = new LinkedHashMap<>(1, 1.0f);
params.put("field", field);
return params;
}
@Override
protected void validate() throws FeatureException {
if (field == null || field.isEmpty()) {
throw new FeatureException(getClass().getSimpleName()+
": field must be provided");
}
}
/** Cache of decoded bytes. */
private static final float[] NORM_TABLE = new float[256];
static {
NORM_TABLE[0] = 0;
for (int i = 1; i < 256; i++) {
float norm = SmallFloat.byte315ToFloat((byte) i);
NORM_TABLE[i] = 1.0f / (norm * norm);
}
}
/**
* Decodes the norm value, assuming it is a single byte.
*
*/
private final float decodeNorm(long norm) {
return NORM_TABLE[(int) (norm & 0xFF)]; // & 0xFF maps negative bytes to
// positive above 127
}
public FieldLengthFeature(String name, Map<String,Object> params) {
super(name, params);
}
@Override
public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi)
throws IOException {
return new FieldLengthFeatureWeight(searcher, request, originalQuery, efi);
}
public class FieldLengthFeatureWeight extends FeatureWeight {
public FieldLengthFeatureWeight(IndexSearcher searcher,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
super(FieldLengthFeature.this, searcher, request, originalQuery, efi);
}
@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
NumericDocValues norms = context.reader().getNormValues(field);
if (norms == null){
return new ValueFeatureScorer(this, 0f,
DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS));
}
return new FieldLengthFeatureScorer(this, norms);
}
public class FieldLengthFeatureScorer extends FeatureScorer {
NumericDocValues norms = null;
public FieldLengthFeatureScorer(FeatureWeight weight,
NumericDocValues norms) throws IOException {
super(weight, norms);
this.norms = norms;
// In the constructor, docId is -1, so using 0 as default lookup
final IndexableField idxF = searcher.doc(0).getField(field);
if (idxF.fieldType().omitNorms()) {
throw new IOException(
"FieldLengthFeatures can't be used if omitNorms is enabled (field="
+ field + ")");
}
}
@Override
public float score() throws IOException {
final long l = norms.longValue();
final float numTerms = decodeNorm(l);
return numTerms;
}
}
}
}
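FieldLengthFeature recovers the length from the single-byte norm: ClassicSimilarity stores 1/sqrt(numTerms) in Lucene's byte315 small-float format (3-bit mantissa, 5-bit exponent, exponent bias 15), and the NORM_TABLE entry 1/(norm*norm) inverts that back to (an approximation of) numTerms. A self-contained sketch of the round trip, reimplementing the codec here purely for illustration:

```java
// Sketch of the single-byte norm round trip behind FieldLengthFeature.
// encode/decode mirror Lucene SmallFloat's byte315 format; lengthFromNorm
// applies the NORM_TABLE formula above: 1 / (norm * norm).
public class NormSketch {
  static byte floatToByte315(float f) {
    final int bits = Float.floatToRawIntBits(f);
    final int smallfloat = bits >> (24 - 3);
    if (smallfloat <= ((63 - 15) << 3)) {
      return (bits <= 0) ? (byte) 0 : (byte) 1; // underflow: zero or smallest non-zero
    }
    if (smallfloat >= ((63 - 15) << 3) + 0x100) {
      return -1; // overflow: largest representable value
    }
    return (byte) (smallfloat - ((63 - 15) << 3));
  }

  static float byte315ToFloat(byte b) {
    if (b == 0) {
      return 0.0f;
    }
    int bits = (b & 0xff) << (24 - 3);
    bits += (63 - 15) << 24;
    return Float.intBitsToFloat(bits);
  }

  /** Field length recovered from a stored norm byte (may be approximate). */
  static float lengthFromNorm(byte b) {
    final float norm = byte315ToFloat(b);
    return 1.0f / (norm * norm);
  }
}
```

A length of 16 round-trips exactly because 1/sqrt(16) = 0.25 fits in three mantissa bits; lengths like 5 come back only approximately, which is why the javadoc hedges about slightly different values.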


@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.solr.request.SolrQueryRequest;
/**
* This feature returns the value of a field in the current document.
* Example configuration:
* <pre>{
"name": "rawHits",
"class": "org.apache.solr.ltr.feature.FieldValueFeature",
"params": {
"field": "hits"
}
}</pre>
*/
public class FieldValueFeature extends Feature {
private String field;
private Set<String> fieldAsSet;
public String getField() {
return field;
}
public void setField(String field) {
this.field = field;
fieldAsSet = Collections.singleton(field);
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
final LinkedHashMap<String,Object> params = new LinkedHashMap<>(1, 1.0f);
params.put("field", field);
return params;
}
@Override
protected void validate() throws FeatureException {
if (field == null || field.isEmpty()) {
throw new FeatureException(getClass().getSimpleName()+
": field must be provided");
}
}
public FieldValueFeature(String name, Map<String,Object> params) {
super(name, params);
}
@Override
public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi)
throws IOException {
return new FieldValueFeatureWeight(searcher, request, originalQuery, efi);
}
public class FieldValueFeatureWeight extends FeatureWeight {
public FieldValueFeatureWeight(IndexSearcher searcher,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
super(FieldValueFeature.this, searcher, request, originalQuery, efi);
}
@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
return new FieldValueFeatureScorer(this, context,
DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS));
}
public class FieldValueFeatureScorer extends FeatureScorer {
LeafReaderContext context = null;
public FieldValueFeatureScorer(FeatureWeight weight,
LeafReaderContext context, DocIdSetIterator itr) {
super(weight, itr);
this.context = context;
}
@Override
public float score() throws IOException {
try {
final Document document = context.reader().document(itr.docID(),
fieldAsSet);
final IndexableField indexableField = document.getField(field);
if (indexableField == null) {
return getDefaultValue();
}
final Number number = indexableField.numericValue();
if (number != null) {
return number.floatValue();
} else {
final String string = indexableField.stringValue();
// boolean values in the index are encoded with the
// chars T/F
if (string.equals("T")) {
return 1;
}
if (string.equals("F")) {
return 0;
}
}
} catch (final IOException e) {
throw new FeatureException(
e.toString() + ": " +
"Unable to extract feature for "
+ name, e);
}
return getDefaultValue();
}
}
}
}
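The score() method above falls back from a numeric field value to Solr's stored boolean encoding (the chars T/F) and finally to the feature's default value. A minimal standalone sketch of that fallback, where decode() is an illustrative helper and not part of this patch:

```java
public class FieldValueDecodeDemo {

  // Mirrors FieldValueFeatureScorer.score(): numeric value first, then the
  // T/F boolean chars, then the configured default value.
  static float decode(Number numeric, String stored, float defaultValue) {
    if (numeric != null) {
      return numeric.floatValue();
    }
    if ("T".equals(stored)) {
      return 1f;
    }
    if ("F".equals(stored)) {
      return 0f;
    }
    return defaultValue;
  }
}
```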


@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.DocInfo;
import org.apache.solr.request.SolrQueryRequest;
/**
* This feature returns the original score that the document had before performing
* the reranking.
* Example configuration:
* <pre>{
"name": "originalScore",
"class": "org.apache.solr.ltr.feature.OriginalScoreFeature",
"params": { }
}</pre>
**/
public class OriginalScoreFeature extends Feature {
public OriginalScoreFeature(String name, Map<String,Object> params) {
super(name, params);
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
return null;
}
@Override
protected void validate() throws FeatureException {
}
@Override
public OriginalScoreWeight createWeight(IndexSearcher searcher,
boolean needsScores, SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
return new OriginalScoreWeight(searcher, request, originalQuery, efi);
}
public class OriginalScoreWeight extends FeatureWeight {
final Weight w;
public OriginalScoreWeight(IndexSearcher searcher,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
super(OriginalScoreFeature.this, searcher, request, originalQuery, efi);
w = searcher.createNormalizedWeight(originalQuery, true);
    }
@Override
public String toString() {
return "OriginalScoreFeature [query:" + originalQuery.toString() + "]";
}
@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
final Scorer originalScorer = w.scorer(context);
return new OriginalScoreScorer(this, originalScorer);
}
public class OriginalScoreScorer extends FeatureScorer {
final private Scorer originalScorer;
public OriginalScoreScorer(FeatureWeight weight, Scorer originalScorer) {
super(weight,null);
this.originalScorer = originalScorer;
}
@Override
public float score() throws IOException {
        // This is done to improve the speed of feature extraction. Since the
        // document was already scored in the first pass, we should not need to
        // calculate the original score again.
final DocInfo docInfo = getDocInfo();
return (docInfo.hasOriginalDocScore() ? docInfo.getOriginalDocScore() : originalScorer.score());
}
@Override
public int docID() {
return originalScorer.docID();
}
@Override
public DocIdSetIterator iterator() {
return originalScorer.iterator();
}
}
}
}


@@ -0,0 +1,320 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSet;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.search.QParser;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.search.SyntaxError;
/**
* This feature allows you to reuse any Solr query as a feature. The value
* of the feature will be the score of the given query for the current document.
 * See the <a href="https://cwiki.apache.org/confluence/display/solr/Other+Parsers">Solr documentation on other parsers</a> for query types you can use as a feature.
* Example configurations:
* <pre>[{ "name": "isBook",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params":{ "fq": ["{!terms f=category}book"] }
},
{
"name": "documentRecency",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {
"q": "{!func}recip( ms(NOW,publish_date), 3.16e-11, 1, 1)"
}
}]</pre>
**/
public class SolrFeature extends Feature {
private String df;
private String q;
private List<String> fq;
public String getDf() {
return df;
}
public void setDf(String df) {
this.df = df;
}
public String getQ() {
return q;
}
public void setQ(String q) {
this.q = q;
}
public List<String> getFq() {
return fq;
}
public void setFq(List<String> fq) {
this.fq = fq;
}
public SolrFeature(String name, Map<String,Object> params) {
super(name, params);
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
final LinkedHashMap<String,Object> params = new LinkedHashMap<>(3, 1.0f);
if (df != null) {
params.put("df", df);
}
if (q != null) {
params.put("q", q);
}
if (fq != null) {
params.put("fq", fq);
}
return params;
}
@Override
public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi)
throws IOException {
return new SolrFeatureWeight(searcher, request, originalQuery, efi);
}
@Override
protected void validate() throws FeatureException {
if ((q == null || q.isEmpty()) &&
((fq == null) || fq.isEmpty())) {
throw new FeatureException(getClass().getSimpleName()+
": Q or FQ must be provided");
}
}
/**
* Weight for a SolrFeature
**/
public class SolrFeatureWeight extends FeatureWeight {
Weight solrQueryWeight;
Query query;
List<Query> queryAndFilters;
public SolrFeatureWeight(IndexSearcher searcher,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
super(SolrFeature.this, searcher, request, originalQuery, efi);
try {
String solrQuery = q;
final List<String> fqs = fq;
if ((solrQuery == null) || solrQuery.isEmpty()) {
solrQuery = "*:*";
}
solrQuery = macroExpander.expand(solrQuery);
if (solrQuery == null) {
throw new FeatureException(this.getClass().getSimpleName()+" requires efi parameter that was not passed in request.");
}
final SolrQueryRequest req = makeRequest(request.getCore(), solrQuery,
fqs, df);
if (req == null) {
throw new IOException("ERROR: No parameters provided");
}
// Build the filter queries
queryAndFilters = new ArrayList<Query>(); // If there are no fqs we just want an empty list
if (fqs != null) {
for (String fq : fqs) {
if ((fq != null) && (fq.trim().length() != 0)) {
fq = macroExpander.expand(fq);
final QParser fqp = QParser.getParser(fq, req);
final Query filterQuery = fqp.getQuery();
if (filterQuery != null) {
queryAndFilters.add(filterQuery);
}
}
}
}
final QParser parser = QParser.getParser(solrQuery, req);
query = parser.parse();
        // Query can be null if there was no input to parse, for instance if you
        // make a phrase query with "to be", and the analyzer removes all the
        // words, leaving nothing for the phrase query to parse.
if (query != null) {
queryAndFilters.add(query);
solrQueryWeight = searcher.createNormalizedWeight(query, true);
}
} catch (final SyntaxError e) {
throw new FeatureException("Failed to parse feature query.", e);
}
}
private LocalSolrQueryRequest makeRequest(SolrCore core, String solrQuery,
List<String> fqs, String df) {
final NamedList<String> returnList = new NamedList<String>();
if ((solrQuery != null) && !solrQuery.isEmpty()) {
returnList.add(CommonParams.Q, solrQuery);
}
if (fqs != null) {
for (final String fq : fqs) {
returnList.add(CommonParams.FQ, fq);
}
}
if ((df != null) && !df.isEmpty()) {
returnList.add(CommonParams.DF, df);
}
if (returnList.size() > 0) {
return new LocalSolrQueryRequest(core, returnList);
} else {
return null;
}
}
@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
Scorer solrScorer = null;
if (solrQueryWeight != null) {
solrScorer = solrQueryWeight.scorer(context);
}
final DocIdSetIterator idItr = getDocIdSetIteratorFromQueries(
queryAndFilters, context);
if (idItr != null) {
return solrScorer == null ? new ValueFeatureScorer(this, 1f, idItr)
: new SolrFeatureScorer(this, solrScorer,
new SolrFeatureScorerIterator(idItr, solrScorer.iterator()));
} else {
return null;
}
}
/**
* Given a list of Solr filters/queries, return a doc iterator that
* traverses over the documents that matched all the criteria of the
* queries.
*
* @param queries
* Filtering criteria to match documents against
* @param context
* Index reader
* @return DocIdSetIterator to traverse documents that matched all filter
* criteria
*/
private DocIdSetIterator getDocIdSetIteratorFromQueries(List<Query> queries,
LeafReaderContext context) throws IOException {
final SolrIndexSearcher.ProcessedFilter pf = ((SolrIndexSearcher) searcher)
.getProcessedFilter(null, queries);
final Bits liveDocs = context.reader().getLiveDocs();
DocIdSetIterator idIter = null;
if (pf.filter != null) {
final DocIdSet idSet = pf.filter.getDocIdSet(context, liveDocs);
if (idSet != null) {
idIter = idSet.iterator();
}
}
return idIter;
}
/**
* Scorer for a SolrFeature
**/
public class SolrFeatureScorer extends FeatureScorer {
final private Scorer solrScorer;
public SolrFeatureScorer(FeatureWeight weight, Scorer solrScorer,
SolrFeatureScorerIterator itr) {
super(weight, itr);
this.solrScorer = solrScorer;
}
@Override
public float score() throws IOException {
try {
return solrScorer.score();
} catch (UnsupportedOperationException e) {
throw new FeatureException(
e.toString() + ": " +
"Unable to extract feature for "
+ name, e);
}
}
}
    /**
     * An iterator that visits only the documents for which a feature has
     * a value.
     **/
public class SolrFeatureScorerIterator extends DocIdSetIterator {
final private DocIdSetIterator filterIterator;
final private DocIdSetIterator scorerFilter;
SolrFeatureScorerIterator(DocIdSetIterator filterIterator,
DocIdSetIterator scorerFilter) {
this.filterIterator = filterIterator;
this.scorerFilter = scorerFilter;
}
@Override
public int docID() {
return filterIterator.docID();
}
@Override
public int nextDoc() throws IOException {
int docID = filterIterator.nextDoc();
scorerFilter.advance(docID);
return docID;
}
@Override
public int advance(int target) throws IOException {
        // We use the filter iterator to catch the scorer up, since it
        // checks whether the target id matches the query plus all the filters
int docID = filterIterator.advance(target);
scorerFilter.advance(docID);
return docID;
}
@Override
public long cost() {
return filterIterator.cost() + scorerFilter.cost();
}
}
}
}
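SolrFeatureScorerIterator above keeps two iterators in lock step: the filter iterator leads, and the scorer's iterator is advanced to its docID. A toy sketch of that advance pattern, where Docs is an illustrative stand-in for Lucene's DocIdSetIterator and not part of this patch:

```java
public class LockStepIteratorDemo {

  /** Toy doc-id iterator over a sorted array; exhausted == Integer.MAX_VALUE. */
  static class Docs {
    final int[] ids;
    int pos = -1;

    Docs(int... ids) {
      this.ids = ids;
    }

    int docID() {
      return pos < 0 ? -1 : (pos >= ids.length ? Integer.MAX_VALUE : ids[pos]);
    }

    int advance(int target) {
      while (docID() < target) {
        pos++;
      }
      return docID();
    }

    int nextDoc() {
      return advance(docID() + 1);
    }
  }

  /** Mirrors SolrFeatureScorerIterator.nextDoc(): the filter leads, the scorer follows. */
  static int nextDoc(Docs filter, Docs scorer) {
    int docID = filter.nextDoc();
    scorer.advance(docID);
    return docID;
  }
}
```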


@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.solr.request.SolrQueryRequest;
/**
 * This feature returns a constant value for the current document.
*
* Example configuration:
* <pre>{
"name" : "userFromMobile",
"class" : "org.apache.solr.ltr.feature.ValueFeature",
"params" : { "value" : "${userFromMobile}", "required":true }
}</pre>
*
 * You can place a constant value like "1.3f" in the value param, but often you
 * will want to pass in external information to use per request. For instance, maybe
 * you want to rank things differently if the search came from a mobile device, or maybe
 * you want to use your external query intent system as a feature.
 * In the rerank request you can pass in rq={... efi.userFromMobile=1}, and the above
 * feature will return 1 for all the docs in that request. If required is set to true,
 * the request will return an error since you failed to pass in the efi; otherwise it will
 * just skip the feature and use a default value of 0 instead.
**/
public class ValueFeature extends Feature {
private float configValue = -1f;
private String configValueStr = null;
private Object value = null;
private Boolean required = null;
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
if (value instanceof String) {
configValueStr = (String) value;
} else if (value instanceof Double) {
configValue = ((Double) value).floatValue();
} else if (value instanceof Float) {
configValue = ((Float) value).floatValue();
} else if (value instanceof Integer) {
configValue = ((Integer) value).floatValue();
} else if (value instanceof Long) {
configValue = ((Long) value).floatValue();
} else {
throw new FeatureException("Invalid type for 'value' in params for " + this);
}
}
public boolean isRequired() {
return Boolean.TRUE.equals(required);
}
public void setRequired(boolean required) {
this.required = required;
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
final LinkedHashMap<String,Object> params = new LinkedHashMap<>(2, 1.0f);
params.put("value", value);
if (required != null) {
params.put("required", required);
}
return params;
}
@Override
protected void validate() throws FeatureException {
if (configValueStr != null && configValueStr.trim().isEmpty()) {
throw new FeatureException("Empty field 'value' in params for " + this);
}
}
public ValueFeature(String name, Map<String,Object> params) {
super(name, params);
}
@Override
public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi)
throws IOException {
return new ValueFeatureWeight(searcher, request, originalQuery, efi);
}
public class ValueFeatureWeight extends FeatureWeight {
final protected Float featureValue;
public ValueFeatureWeight(IndexSearcher searcher,
SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
super(ValueFeature.this, searcher, request, originalQuery, efi);
if (configValueStr != null) {
final String expandedValue = macroExpander.expand(configValueStr);
if (expandedValue != null) {
featureValue = Float.parseFloat(expandedValue);
} else if (isRequired()) {
throw new FeatureException(this.getClass().getSimpleName() + " requires efi parameter that was not passed in request.");
} else {
          featureValue = null;
}
} else {
featureValue = configValue;
}
}
@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
      if (featureValue != null) {
return new ValueFeatureScorer(this, featureValue,
DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS));
} else {
return null;
}
}
}
}
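ValueFeatureWeight above resolves the configured value in three ways: a literal constant, an ${efi} placeholder supplied on the request, or, when the placeholder is missing, an error if the feature is required, else no scorer (so the default value applies). A simplified sketch of those semantics, where resolve() is an illustrative helper and not the actual Solr macro expander:

```java
import java.util.Map;

public class ValueFeatureResolveDemo {

  /**
   * Returns the feature value, or null when an ${efi} placeholder was not
   * supplied and the feature is not required (the real code then returns no
   * scorer, so the default value is used instead).
   */
  static Float resolve(String configured, Map<String,String> efi, boolean required) {
    if (configured.startsWith("${") && configured.endsWith("}")) {
      final String key = configured.substring(2, configured.length() - 1);
      final String supplied = efi.get(key);
      if (supplied == null) {
        if (required) {
          throw new RuntimeException("required efi parameter " + key + " was not passed in request");
        }
        return null;
      }
      return Float.parseFloat(supplied);
    }
    return Float.parseFloat(configured);
  }
}
```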


@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Contains Feature related classes
*/
package org.apache.solr.ltr.feature;


@@ -0,0 +1,298 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;
/**
* A scoring model computes scores that can be used to rerank documents.
* <p>
* A scoring model consists of
* <ul>
* <li> a list of features ({@link Feature}) and
* <li> a list of normalizers ({@link Normalizer}) plus
* <li> parameters or configuration to represent the scoring algorithm.
* </ul>
* <p>
* Example configuration (snippet):
* <pre>{
"class" : "...",
"name" : "myModelName",
"features" : [
{
"name" : "isBook"
},
{
"name" : "originalScore",
"norm": {
"class" : "org.apache.solr.ltr.norm.StandardNormalizer",
"params" : { "avg":"100", "std":"10" }
}
},
{
"name" : "price",
"norm": {
"class" : "org.apache.solr.ltr.norm.MinMaxNormalizer",
"params" : { "min":"0", "max":"1000" }
}
}
],
"params" : {
...
}
}</pre>
* <p>
* {@link LTRScoringModel} is an abstract class and concrete classes must
* implement the {@link #score(float[])} and
* {@link #explain(LeafReaderContext, int, float, List)} methods.
*/
public abstract class LTRScoringModel {
protected final String name;
private final String featureStoreName;
protected final List<Feature> features;
private final List<Feature> allFeatures;
private final Map<String,Object> params;
private final List<Normalizer> norms;
public static LTRScoringModel getInstance(SolrResourceLoader solrResourceLoader,
String className, String name, List<Feature> features,
List<Normalizer> norms,
String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) throws ModelException {
final LTRScoringModel model;
try {
// create an instance of the model
model = solrResourceLoader.newInstance(
className,
LTRScoringModel.class,
new String[0], // no sub packages
new Class[] { String.class, List.class, List.class, String.class, List.class, Map.class },
new Object[] { name, features, norms, featureStoreName, allFeatures, params });
if (params != null) {
SolrPluginUtils.invokeSetters(model, params.entrySet());
}
} catch (final Exception e) {
throw new ModelException("Model type does not exist " + className, e);
}
model.validate();
return model;
}
public LTRScoringModel(String name, List<Feature> features,
List<Normalizer> norms,
String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) {
this.name = name;
this.features = features;
this.featureStoreName = featureStoreName;
this.allFeatures = allFeatures;
this.params = params;
this.norms = norms;
}
/**
* Validate that settings make sense and throws
* {@link ModelException} if they do not make sense.
*/
protected void validate() throws ModelException {
if (features.isEmpty()) {
throw new ModelException("no features declared for model "+name);
}
final HashSet<String> featureNames = new HashSet<>();
for (final Feature feature : features) {
final String featureName = feature.getName();
if (!featureNames.add(featureName)) {
throw new ModelException("duplicated feature "+featureName+" in model "+name);
}
}
if (features.size() != norms.size()) {
throw new ModelException("counted "+features.size()+" features and "+norms.size()+" norms in model "+name);
}
}
/**
* @return the norms
*/
public List<Normalizer> getNorms() {
return Collections.unmodifiableList(norms);
}
/**
* @return the name
*/
public String getName() {
return name;
}
/**
* @return the features
*/
public List<Feature> getFeatures() {
return Collections.unmodifiableList(features);
}
public Map<String,Object> getParams() {
return params;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = (prime * result) + ((features == null) ? 0 : features.hashCode());
result = (prime * result) + ((name == null) ? 0 : name.hashCode());
result = (prime * result) + ((params == null) ? 0 : params.hashCode());
result = (prime * result) + ((norms == null) ? 0 : norms.hashCode());
result = (prime * result) + ((featureStoreName == null) ? 0 : featureStoreName.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
final LTRScoringModel other = (LTRScoringModel) obj;
if (features == null) {
if (other.features != null) {
return false;
}
} else if (!features.equals(other.features)) {
return false;
}
if (norms == null) {
if (other.norms != null) {
return false;
}
} else if (!norms.equals(other.norms)) {
return false;
}
if (name == null) {
if (other.name != null) {
return false;
}
} else if (!name.equals(other.name)) {
return false;
}
if (params == null) {
if (other.params != null) {
return false;
}
} else if (!params.equals(other.params)) {
return false;
}
if (featureStoreName == null) {
if (other.featureStoreName != null) {
return false;
}
} else if (!featureStoreName.equals(other.featureStoreName)) {
return false;
}
return true;
}
public boolean hasParams() {
return !((params == null) || params.isEmpty());
}
public Collection<Feature> getAllFeatures() {
return allFeatures;
}
public String getFeatureStoreName() {
return featureStoreName;
}
/**
* Given a list of normalized values for all features a scoring algorithm
* cares about, calculate and return a score.
*
* @param modelFeatureValuesNormalized
* List of normalized feature values. Each feature is identified by
* its id, which is the index in the array
* @return The final score for a document
*/
public abstract float score(float[] modelFeatureValuesNormalized);
/**
* Similar to the score() function, except it returns an explanation of how
* the features were used to calculate the score.
*
* @param context
* Context the document is in
* @param doc
* Document to explain
* @param finalScore
* Original score
* @param featureExplanations
* Explanations for each feature calculation
* @return Explanation for the scoring of a document
*/
public abstract Explanation explain(LeafReaderContext context, int doc,
float finalScore, List<Explanation> featureExplanations);
@Override
public String toString() {
return getClass().getSimpleName() + "(name="+getName()+")";
}
/**
* Goes through all the stored feature values, and calculates the normalized
* values for all the features that will be used for scoring.
*/
public void normalizeFeaturesInPlace(float[] modelFeatureValues) {
float[] modelFeatureValuesNormalized = modelFeatureValues;
if (modelFeatureValues.length != norms.size()) {
throw new FeatureException("Must have normalizer for every feature");
}
for(int idx = 0; idx < modelFeatureValuesNormalized.length; ++idx) {
modelFeatureValuesNormalized[idx] =
norms.get(idx).normalize(modelFeatureValuesNormalized[idx]);
}
}
public Explanation getNormalizerExplanation(Explanation e, int idx) {
Normalizer n = norms.get(idx);
if (n != IdentityNormalizer.INSTANCE) {
return n.explain(e);
}
return e;
}
}
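normalizeFeaturesInPlace() above applies exactly one normalizer per feature, in place, before score() combines the values. A standalone sketch using DoubleUnaryOperator as a stand-in for org.apache.solr.ltr.norm.Normalizer; not part of this patch:

```java
import java.util.List;
import java.util.function.DoubleUnaryOperator;

public class NormalizeThenScoreDemo {

  /** In-place normalization, one normalizer per feature (as validate() enforces). */
  static void normalizeInPlace(float[] values, List<DoubleUnaryOperator> norms) {
    if (values.length != norms.size()) {
      throw new IllegalArgumentException("Must have normalizer for every feature");
    }
    for (int i = 0; i < values.length; i++) {
      values[i] = (float) norms.get(i).applyAsDouble(values[i]);
    }
  }
}
```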


@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
/**
* A scoring model that computes scores using a dot product.
* Example models are RankSVM and Pranking.
* <p>
* Example configuration:
* <pre>{
"class" : "org.apache.solr.ltr.model.LinearModel",
"name" : "myModelName",
"features" : [
{ "name" : "userTextTitleMatch" },
{ "name" : "originalScore" },
{ "name" : "isBook" }
],
"params" : {
"weights" : {
"userTextTitleMatch" : 1.0,
"originalScore" : 0.5,
"isBook" : 0.1
}
}
}</pre>
* <p>
* Background reading:
 * <ul>
 * <li> <a href="http://www.cs.cornell.edu/people/tj/publications/joachims_02c.pdf">
 * Thorsten Joachims. Optimizing Search Engines Using Clickthrough Data.
 * Proceedings of the ACM Conference on Knowledge Discovery and Data Mining (KDD), ACM, 2002.</a>
 * <li> <a href="https://papers.nips.cc/paper/2023-pranking-with-ranking.pdf">
 * Koby Crammer and Yoram Singer. Pranking with Ranking.
 * Advances in Neural Information Processing Systems (NIPS), 2001.</a>
 * </ul>
*/
public class LinearModel extends LTRScoringModel {
protected Float[] featureToWeight;
  @SuppressWarnings("unchecked")
  public void setWeights(Object weights) {
    final Map<String,Double> modelWeights = (Map<String,Double>) weights;
    for (int ii = 0; ii < features.size(); ++ii) {
      final String key = features.get(ii).getName();
      final Double val = modelWeights.get(key);
      featureToWeight[ii] = (val == null ? null : Float.valueOf(val.floatValue()));
    }
  }
public LinearModel(String name, List<Feature> features,
List<Normalizer> norms,
String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
featureToWeight = new Float[features.size()];
}
@Override
protected void validate() throws ModelException {
super.validate();
final ArrayList<String> missingWeightFeatureNames = new ArrayList<String>();
for (int i = 0; i < features.size(); ++i) {
if (featureToWeight[i] == null) {
missingWeightFeatureNames.add(features.get(i).getName());
}
}
if (missingWeightFeatureNames.size() == features.size()) {
throw new ModelException("Model " + name + " doesn't contain any weights");
}
if (!missingWeightFeatureNames.isEmpty()) {
throw new ModelException("Model " + name + " lacks weight(s) for "+missingWeightFeatureNames);
}
}
@Override
public float score(float[] modelFeatureValuesNormalized) {
float score = 0;
for (int i = 0; i < modelFeatureValuesNormalized.length; ++i) {
score += modelFeatureValuesNormalized[i] * featureToWeight[i];
}
return score;
}
@Override
public Explanation explain(LeafReaderContext context, int doc,
float finalScore, List<Explanation> featureExplanations) {
final List<Explanation> details = new ArrayList<>();
int index = 0;
for (final Explanation featureExplain : featureExplanations) {
final List<Explanation> featureDetails = new ArrayList<>();
featureDetails.add(Explanation.match(featureToWeight[index],
"weight on feature"));
featureDetails.add(featureExplain);
details.add(Explanation.match(featureExplain.getValue()
* featureToWeight[index], "prod of:", featureDetails));
index++;
}
return Explanation.match(finalScore, toString()
+ " model applied to features, sum of:", details);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
sb.append("(name=").append(getName());
sb.append(",featureWeights=[");
for (int ii = 0; ii < features.size(); ++ii) {
if (ii>0) {
sb.append(',');
}
final String key = features.get(ii).getName();
sb.append(key).append('=').append(featureToWeight[ii]);
}
sb.append("])");
return sb.toString();
}
}
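LinearModel.score() above is a plain dot product of the normalized feature values and their configured weights. Sketched standalone, not part of this patch:

```java
public class LinearScoreDemo {

  // Mirrors LinearModel.score(): sum of value[i] * weight[i].
  static float score(float[] values, float[] weights) {
    float score = 0f;
    for (int i = 0; i < values.length; i++) {
      score += values[i] * weights[i];
    }
    return score;
  }
}
```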


@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
public class ModelException extends RuntimeException {
private static final long serialVersionUID = 1L;
public ModelException(String message) {
super(message);
}
public ModelException(String message, Exception cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,377 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;
/**
* A scoring model that computes scores based on the summation of multiple weighted trees.
 * Example models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
* <p>
* Example configuration:
<pre>{
"class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name" : "multipleadditivetreesmodel",
"features":[
{ "name" : "userTextTitleMatch"},
{ "name" : "originalScore"}
],
"params" : {
"trees" : [
{
"weight" : 1,
"root": {
"feature" : "userTextTitleMatch",
"threshold" : 0.5,
"left" : {
"value" : -100
},
"right" : {
"feature" : "originalScore",
"threshold" : 10.0,
"left" : {
"value" : 50
},
"right" : {
"value" : 75
}
}
}
},
{
"weight" : 2,
"root" : {
"value" : -10
}
}
]
}
}</pre>
* <p>
* Background reading:
 * <ul>
 * <li> <a href="http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf">
 * Christopher J.C. Burges. From RankNet to LambdaRank to LambdaMART: An Overview.
 * Microsoft Research Technical Report MSR-TR-2010-82.</a>
 * <li> <a href="https://papers.nips.cc/paper/3305-a-general-boosting-method-and-its-application-to-learning-ranking-functions-for-web-search.pdf">
 * Z. Zheng, H. Zha, T. Zhang, O. Chapelle, K. Chen, and G. Sun. A General Boosting Method and its Application to Learning Ranking Functions for Web Search.
 * Advances in Neural Information Processing Systems (NIPS), 2007.</a>
 * </ul>
*/
public class MultipleAdditiveTreesModel extends LTRScoringModel {
private final HashMap<String,Integer> fname2index;
private List<RegressionTree> trees;
private RegressionTree createRegressionTree(Map<String,Object> map) {
final RegressionTree rt = new RegressionTree();
if (map != null) {
SolrPluginUtils.invokeSetters(rt, map.entrySet());
}
return rt;
}
private RegressionTreeNode createRegressionTreeNode(Map<String,Object> map) {
final RegressionTreeNode rtn = new RegressionTreeNode();
if (map != null) {
SolrPluginUtils.invokeSetters(rtn, map.entrySet());
}
return rtn;
}
public class RegressionTreeNode {
private static final float NODE_SPLIT_SLACK = 1E-6f;
private float value = 0f;
private String feature;
private int featureIndex = -1;
private Float threshold;
private RegressionTreeNode left;
private RegressionTreeNode right;
public void setValue(float value) {
this.value = value;
}
public void setValue(String value) {
this.value = Float.parseFloat(value);
}
public void setFeature(String feature) {
this.feature = feature;
final Integer idx = fname2index.get(this.feature);
// this happens if the tree specifies a feature that does not exist
// this could be due to lambdaSmart building off of pre-existing trees
// that use a feature that is no longer output during feature extraction
featureIndex = (idx == null) ? -1 : idx;
}
public void setThreshold(float threshold) {
this.threshold = threshold + NODE_SPLIT_SLACK;
}
public void setThreshold(String threshold) {
this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
}
public void setLeft(Object left) {
this.left = createRegressionTreeNode((Map<String,Object>) left);
}
public void setRight(Object right) {
this.right = createRegressionTreeNode((Map<String,Object>) right);
}
public boolean isLeaf() {
return feature == null;
}
public float score(float[] featureVector) {
if (isLeaf()) {
return value;
}
// unsupported feature (tree is looking for a feature that does not exist)
if ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
return 0f;
}
if (featureVector[featureIndex] <= threshold) {
return left.score(featureVector);
} else {
return right.score(featureVector);
}
}
public String explain(float[] featureVector) {
if (isLeaf()) {
return "val: " + value;
}
// unsupported feature (tree is looking for a feature that does not exist)
if ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
return "'" + feature + "' does not exist in FV, Return Zero";
}
// could store extra information about how much training data supported
// each branch and report
// that here
if (featureVector[featureIndex] <= threshold) {
String rval = "'" + feature + "':" + featureVector[featureIndex] + " <= "
+ threshold + ", Go Left | ";
return rval + left.explain(featureVector);
} else {
String rval = "'" + feature + "':" + featureVector[featureIndex] + " > "
+ threshold + ", Go Right | ";
return rval + right.explain(featureVector);
}
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
if (isLeaf()) {
sb.append(value);
} else {
sb.append("(feature=").append(feature);
sb.append(",threshold=").append(threshold.floatValue()-NODE_SPLIT_SLACK);
sb.append(",left=").append(left);
sb.append(",right=").append(right);
sb.append(')');
}
return sb.toString();
}
public RegressionTreeNode() {
}
public void validate() throws ModelException {
if (isLeaf()) {
if (left != null || right != null) {
throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left="+left+" and right="+right);
}
return;
}
if (null == threshold) {
throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
}
if (null == left) {
throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
} else {
left.validate();
}
if (null == right) {
throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
} else {
right.validate();
}
}
}
public class RegressionTree {
private Float weight;
private RegressionTreeNode root;
    public void setWeight(float weight) {
      this.weight = Float.valueOf(weight);
    }
    public void setWeight(String weight) {
      this.weight = Float.valueOf(weight);
    }
public void setRoot(Object root) {
this.root = createRegressionTreeNode((Map<String,Object>)root);
}
public float score(float[] featureVector) {
return weight.floatValue() * root.score(featureVector);
}
public String explain(float[] featureVector) {
return root.explain(featureVector);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder();
sb.append("(weight=").append(weight);
sb.append(",root=").append(root);
sb.append(")");
return sb.toString();
}
public RegressionTree() {
}
public void validate() throws ModelException {
if (weight == null) {
throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
}
if (root == null) {
throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
} else {
root.validate();
}
}
}
public void setTrees(Object trees) {
this.trees = new ArrayList<RegressionTree>();
for (final Object o : (List<Object>) trees) {
final RegressionTree rt = createRegressionTree((Map<String,Object>) o);
this.trees.add(rt);
}
}
public MultipleAdditiveTreesModel(String name, List<Feature> features,
List<Normalizer> norms,
String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
fname2index = new HashMap<String,Integer>();
for (int i = 0; i < features.size(); ++i) {
final String key = features.get(i).getName();
fname2index.put(key, i);
}
}
@Override
protected void validate() throws ModelException {
super.validate();
if (trees == null) {
throw new ModelException("no trees declared for model "+name);
}
for (RegressionTree tree : trees) {
tree.validate();
}
}
@Override
public float score(float[] modelFeatureValuesNormalized) {
float score = 0;
for (final RegressionTree t : trees) {
score += t.score(modelFeatureValuesNormalized);
}
return score;
}
// /////////////////////////////////////////
// produces a string that looks like:
// 40.0 = multipleadditivetreesmodel [ org.apache.solr.ltr.model.MultipleAdditiveTreesModel ]
// model applied to
// features, sum of:
  // 50.0 = tree 0 | 'matchedTitle':1.0 > 0.500001, Go Right |
  // 'this_feature_doesnt_exist' does not
  // exist in FV, Return Zero
// -10.0 = tree 1 | val: -10.0
@Override
public Explanation explain(LeafReaderContext context, int doc,
float finalScore, List<Explanation> featureExplanations) {
final float[] fv = new float[featureExplanations.size()];
int index = 0;
for (final Explanation featureExplain : featureExplanations) {
fv[index] = featureExplain.getValue();
index++;
}
final List<Explanation> details = new ArrayList<>();
index = 0;
for (final RegressionTree t : trees) {
final float score = t.score(fv);
final Explanation p = Explanation.match(score, "tree " + index + " | "
+ t.explain(fv));
details.add(p);
index++;
}
return Explanation.match(finalScore, toString()
+ " model applied to features, sum of:", details);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
sb.append("(name=").append(getName());
sb.append(",trees=[");
for (int ii = 0; ii < trees.size(); ++ii) {
if (ii>0) {
sb.append(',');
}
sb.append(trees.get(ii));
}
sb.append("])");
return sb.toString();
}
}
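The ensemble score above is the weighted sum of each tree's traversal result: at every non-leaf node, go left when the feature value is at or below the threshold, otherwise go right. A standalone sketch hard-coding the two-tree example from the class javadoc (feature index 0 = userTextTitleMatch, index 1 = originalScore; the `Node` helper names are illustrative, not from the Solr API):

```java
// Standalone sketch (not part of this patch) of the tree-ensemble scoring
// performed by MultipleAdditiveTreesModel, using the javadoc's example trees.
public class TreeEnsembleSketch {

  static final class Node {
    Integer featureIndex; // null marks a leaf
    float threshold;
    float value;          // leaf value

    Node left, right;

    static Node leaf(float value) {
      Node n = new Node();
      n.value = value;
      return n;
    }

    static Node split(int featureIndex, float threshold, Node left, Node right) {
      Node n = new Node();
      n.featureIndex = featureIndex;
      n.threshold = threshold;
      n.left = left;
      n.right = right;
      return n;
    }

    float score(float[] fv) {
      if (featureIndex == null) {
        return value; // leaf: return stored value
      }
      // go left when at or below threshold, otherwise go right
      return fv[featureIndex] <= threshold ? left.score(fv) : right.score(fv);
    }
  }

  public static void main(String[] args) {
    // tree 1 (weight 1): split on userTextTitleMatch, then on originalScore
    Node tree1 = Node.split(0, 0.5f,
        Node.leaf(-100f),
        Node.split(1, 10.0f, Node.leaf(50f), Node.leaf(75f)));
    // tree 2 (weight 2): a single leaf
    Node tree2 = Node.leaf(-10f);

    float[] fv = {1.0f, 20.0f}; // userTextTitleMatch=1, originalScore=20
    float score = 1f * tree1.score(fv) + 2f * tree2.score(fv);
    System.out.println(score); // 75 + (2 * -10) = 55.0
  }
}
```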

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Contains Model related classes
*/
package org.apache.solr.ltr.model;

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.norm;
import java.util.LinkedHashMap;
/**
 * A Normalizer that normalizes a feature value to itself. This is the
 * default normalizer class; if no normalizer is configured, the
 * IdentityNormalizer will be used.
*/
public class IdentityNormalizer extends Normalizer {
public static final IdentityNormalizer INSTANCE = new IdentityNormalizer();
public IdentityNormalizer() {
}
@Override
public float normalize(float value) {
return value;
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
return null;
}
@Override
protected void validate() throws NormalizerException {
}
@Override
public String toString() {
return getClass().getSimpleName();
}
}

View File

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.norm;
import java.util.LinkedHashMap;
/**
* A Normalizer to scale a feature value using a (min,max) range.
* <p>
* Example configuration:
<pre>
"norm" : {
"class" : "org.apache.solr.ltr.norm.MinMaxNormalizer",
"params" : { "min":"0", "max":"50" }
}
</pre>
* Example normalizations:
* <ul>
* <li>-5 will be normalized to -0.1
* <li>55 will be normalized to 1.1
* <li>+5 will be normalized to +0.1
* </ul>
*/
public class MinMaxNormalizer extends Normalizer {
private float min = Float.NEGATIVE_INFINITY;
private float max = Float.POSITIVE_INFINITY;
private float delta = max - min;
private void updateDelta() {
delta = max - min;
}
public float getMin() {
return min;
}
public void setMin(float min) {
this.min = min;
updateDelta();
}
public void setMin(String min) {
this.min = Float.parseFloat(min);
updateDelta();
}
public float getMax() {
return max;
}
public void setMax(float max) {
this.max = max;
updateDelta();
}
public void setMax(String max) {
this.max = Float.parseFloat(max);
updateDelta();
}
@Override
protected void validate() throws NormalizerException {
if (delta == 0f) {
throw
new NormalizerException("MinMax Normalizer delta must not be zero " +
"| min = " + min + ",max = " + max + ",delta = " + delta);
}
}
@Override
public float normalize(float value) {
return (value - min) / delta;
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
final LinkedHashMap<String,Object> params = new LinkedHashMap<>(2, 1.0f);
params.put("min", min);
params.put("max", max);
return params;
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder(64); // default initialCapacity of 16 won't be enough
sb.append(getClass().getSimpleName()).append('(');
sb.append("min=").append(min);
sb.append(",max=").append(max).append(')');
return sb.toString();
}
}
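The MinMaxNormalizer javadoc's example normalizations can be checked by hand with the formula in normalize(): `(value - min) / (max - min)`. A standalone sketch (class name illustrative, not from the Solr API):

```java
// Standalone sketch (not part of this patch) verifying the MinMaxNormalizer
// javadoc examples: with min=0 and max=50, normalize(v) = (v - min) / (max - min).
public class MinMaxSketch {

  public static float normalize(float value, float min, float max) {
    return (value - min) / (max - min);
  }

  public static void main(String[] args) {
    System.out.println(normalize(-5f, 0f, 50f)); // -0.1
    System.out.println(normalize(55f, 0f, 50f)); // 1.1
    System.out.println(normalize(5f, 0f, 50f));  // 0.1
  }
}
```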

View File

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.norm;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.lucene.search.Explanation;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.util.SolrPluginUtils;
/**
 * A normalizer normalizes the value of a feature. After the feature values
 * have been computed, the {@link Normalizer#normalize(float)} method will
 * be called and the resulting values will be used by the model.
*/
public abstract class Normalizer {
public abstract float normalize(float value);
public abstract LinkedHashMap<String,Object> paramsToMap();
public Explanation explain(Explanation explain) {
final float normalized = normalize(explain.getValue());
final String explainDesc = "normalized using " + toString();
return Explanation.match(normalized, explainDesc, explain);
}
public static Normalizer getInstance(SolrResourceLoader solrResourceLoader,
String className, Map<String,Object> params) {
final Normalizer f = solrResourceLoader.newInstance(className, Normalizer.class);
if (params != null) {
SolrPluginUtils.invokeSetters(f, params.entrySet());
}
f.validate();
return f;
}
/**
* As part of creation of a normalizer instance, this function confirms
* that the normalizer parameters are valid.
*
* @throws NormalizerException
   *           if the normalizer parameters are not valid
*/
protected abstract void validate() throws NormalizerException;
}
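Normalizer.getInstance instantiates the configured class by name and then pushes each "params" entry through a matching setter via SolrPluginUtils.invokeSetters. A simplified standalone sketch of that reflection-based setter injection (the class, method, and field names here are illustrative; the real invokeSetters handles more cases, such as type coercion):

```java
// Standalone sketch (not part of this patch) of reflection-based setter
// injection: for each params entry, find a one-argument public setter whose
// parameter type accepts the value, and invoke it.
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.Map;

public class SetterInjectionSketch {

  /** Hypothetical configurable bean with overloaded setters, as the normalizers have. */
  public static class Example {
    public float min;
    public void setMin(float min) { this.min = min; }
    public void setMin(String min) { this.min = Float.parseFloat(min); }
  }

  static void invokeSetters(Object bean, Map<String,Object> params) {
    outer:
    for (Map.Entry<String,Object> e : params.entrySet()) {
      // "min" -> "setMin"
      String setter = "set" + Character.toUpperCase(e.getKey().charAt(0))
          + e.getKey().substring(1);
      for (Method m : bean.getClass().getMethods()) {
        if (m.getName().equals(setter) && m.getParameterCount() == 1
            && m.getParameterTypes()[0].isInstance(e.getValue())) {
          try {
            m.invoke(bean, e.getValue());
          } catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
          }
          continue outer; // setter found and applied; next param
        }
      }
      throw new RuntimeException("no setter found for parameter: " + e.getKey());
    }
  }

  public static void main(String[] args) {
    Example ex = new Example();
    Map<String,Object> params = new LinkedHashMap<>();
    params.put("min", "0.5"); // JSON params arrive as strings
    invokeSetters(ex, params); // picks the setMin(String) overload
    System.out.println(ex.min); // 0.5
  }
}
```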

View File

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.norm;
public class NormalizerException extends RuntimeException {
private static final long serialVersionUID = 1L;
public NormalizerException(String message) {
super(message);
}
public NormalizerException(String message, Exception cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.norm;
import java.util.LinkedHashMap;
/**
 * A Normalizer to scale a feature value using the average and standard deviation of a distribution.
* <p>
* Example configuration:
<pre>
"norm" : {
"class" : "org.apache.solr.ltr.norm.StandardNormalizer",
"params" : { "avg":"42", "std":"6" }
}
</pre>
* <p>
* Example normalizations:
* <ul>
* <li>39 will be normalized to -0.5
* <li>42 will be normalized to 0
* <li>45 will be normalized to +0.5
* </ul>
*/
public class StandardNormalizer extends Normalizer {
private float avg = 0f;
private float std = 1f;
public float getAvg() {
return avg;
}
public void setAvg(float avg) {
this.avg = avg;
}
public float getStd() {
return std;
}
public void setStd(float std) {
this.std = std;
}
public void setAvg(String avg) {
this.avg = Float.parseFloat(avg);
}
public void setStd(String std) {
this.std = Float.parseFloat(std);
}
@Override
public float normalize(float value) {
return (value - avg) / std;
}
@Override
protected void validate() throws NormalizerException {
if (std <= 0f) {
throw
new NormalizerException("Standard Normalizer standard deviation must "
+ "be positive | avg = " + avg + ",std = " + std);
}
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
final LinkedHashMap<String,Object> params = new LinkedHashMap<>(2, 1.0f);
params.put("avg", avg);
params.put("std", std);
return params;
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder(64); // default initialCapacity of 16 won't be enough
sb.append(getClass().getSimpleName()).append('(');
sb.append("avg=").append(avg);
    sb.append(",std=").append(std).append(')');
return sb.toString();
}
}
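As with the min/max case, the StandardNormalizer javadoc examples follow directly from the formula in normalize(): `(value - avg) / std`. A short standalone sketch (class name illustrative):

```java
// Standalone sketch (not part of this patch) verifying the StandardNormalizer
// javadoc examples: with avg=42 and std=6, normalize(v) = (v - avg) / std.
public class StandardNormSketch {

  public static float normalize(float value, float avg, float std) {
    return (value - avg) / std;
  }

  public static void main(String[] args) {
    System.out.println(normalize(39f, 42f, 6f)); // -0.5
    System.out.println(normalize(42f, 42f, 6f)); // 0.0
    System.out.println(normalize(45f, 42f, 6f)); // 0.5
  }
}
```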

View File

@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * A normalizer normalizes the value of a feature. After the feature values
 * have been computed, the normalizer will be applied and the resulting
 * values will be used by the model.
*/
package org.apache.solr.ltr.norm;

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* <p>
* This package contains the main logic for performing the reranking using
* a Learning to Rank model.
* </p>
* <p>
 * A model is applied to each document through a {@link org.apache.solr.ltr.LTRScoringQuery}, a
 * subclass of {@link org.apache.lucene.search.Query}. As with a normal query,
 * the learned model produces a new score
 * for each reranked document.
* </p>
* <p>
* A {@link org.apache.solr.ltr.LTRScoringQuery} is created by providing an instance of
* {@link org.apache.solr.ltr.model.LTRScoringModel}. An instance of
* {@link org.apache.solr.ltr.model.LTRScoringModel}
* defines how to combine the features in order to create a new
* score for a document. A new Learning to Rank model is plugged
 * into the framework by extending {@link org.apache.solr.ltr.model.LTRScoringModel}
 * (see for example {@link org.apache.solr.ltr.model.MultipleAdditiveTreesModel} and {@link org.apache.solr.ltr.model.LinearModel}).
* </p>
* <p>
* The {@link org.apache.solr.ltr.LTRScoringQuery} will take care of computing the values of
* all the features (see {@link org.apache.solr.ltr.feature.Feature}) and then will delegate the final score
* generation to the {@link org.apache.solr.ltr.model.LTRScoringModel}, by calling the method
* {@link org.apache.solr.ltr.model.LTRScoringModel#score(float[] modelFeatureValuesNormalized) score(float[] modelFeatureValuesNormalized)}.
* </p>
*/
package org.apache.solr.ltr;

View File

@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.store;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
public class FeatureStore {
/** the name of the default feature store **/
public static final String DEFAULT_FEATURE_STORE_NAME = "_DEFAULT_";
private final LinkedHashMap<String,Feature> store = new LinkedHashMap<>(); // LinkedHashMap because we need predictable iteration order
private final String name;
public FeatureStore(String name) {
this.name = name;
}
public String getName() {
return name;
}
public Feature get(String name) {
return store.get(name);
}
public void add(Feature feature) {
final String name = feature.getName();
if (store.containsKey(name)) {
throw new FeatureException(name
+ " already contained in the store, please use a different name");
}
feature.setIndex(store.size());
store.put(name, feature);
}
public List<Feature> getFeatures() {
final List<Feature> storeValues = new ArrayList<Feature>(store.values());
return Collections.unmodifiableList(storeValues);
}
@Override
public String toString() {
return "FeatureStore [features=" + store.keySet() + "]";
}
}
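FeatureStore.add assigns each feature an index equal to its insertion position and relies on the LinkedHashMap's predictable iteration order so that getFeatures() returns features in index order. A standalone sketch of that invariant, including the duplicate-name rejection (names here are illustrative, not from the Solr API):

```java
// Standalone sketch (not part of this patch) of the FeatureStore invariants:
// index = insertion position, iteration order matches index order, and
// duplicate names are rejected.
import java.util.LinkedHashMap;

public class FeatureIndexSketch {

  private final LinkedHashMap<String,Integer> store = new LinkedHashMap<>();

  public void add(String name) {
    if (store.containsKey(name)) {
      throw new IllegalArgumentException(name + " already contained in the store");
    }
    store.put(name, store.size()); // index = insertion position
  }

  public Integer indexOf(String name) {
    return store.get(name);
  }

  public static void main(String[] args) {
    FeatureIndexSketch s = new FeatureIndexSketch();
    s.add("userTextTitleMatch");
    s.add("originalScore");
    // LinkedHashMap keeps insertion order, so keys come back in index order
    System.out.println(s.store.keySet()); // [userTextTitleMatch, originalScore]
    System.out.println(s.indexOf("originalScore")); // 1
  }
}
```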

View File

@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.store;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
/**
 * Contains the models that have been declared.
*/
public class ModelStore {
private final Map<String,LTRScoringModel> availableModels;
public ModelStore() {
availableModels = new HashMap<>();
}
public synchronized LTRScoringModel getModel(String name) {
return availableModels.get(name);
}
public void clear() {
availableModels.clear();
}
public List<LTRScoringModel> getModels() {
final List<LTRScoringModel> availableModelsValues =
new ArrayList<LTRScoringModel>(availableModels.values());
return Collections.unmodifiableList(availableModelsValues);
}
@Override
public String toString() {
return "ModelStore [availableModels=" + availableModels.keySet() + "]";
}
public LTRScoringModel delete(String modelName) {
return availableModels.remove(modelName);
}
public synchronized void addModel(LTRScoringModel modeldata)
throws ModelException {
final String name = modeldata.getName();
if (availableModels.containsKey(name)) {
throw new ModelException("model '" + name
+ "' already exists. Please use a different name");
}
availableModels.put(modeldata.getName(), modeldata);
}
}

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Contains feature and model store related classes.
*/
package org.apache.solr.ltr.store;

View File

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.store.rest;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.rest.BaseSolrResource;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.rest.ManagedResourceStorage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Managed resource for storing features.
*/
public class ManagedFeatureStore extends ManagedResource implements ManagedResource.ChildResourceSupport {
public static void registerManagedFeatureStore(SolrResourceLoader solrResourceLoader,
ManagedResourceObserver managedResourceObserver) {
solrResourceLoader.getManagedResourceRegistry().registerManagedResource(
REST_END_POINT,
ManagedFeatureStore.class,
managedResourceObserver);
}
public static ManagedFeatureStore getManagedFeatureStore(SolrCore core) {
return (ManagedFeatureStore) core.getRestManager()
.getManagedResource(REST_END_POINT);
}
/** the feature store rest endpoint **/
public static final String REST_END_POINT = "/schema/feature-store";
// TODO: reduce from public to package visibility (once tests no longer need public access)
/** name of the attribute containing the feature class **/
static final String CLASS_KEY = "class";
/** name of the attribute containing the feature name **/
static final String NAME_KEY = "name";
/** name of the attribute containing the feature params **/
static final String PARAMS_KEY = "params";
/** name of the attribute containing the feature store used **/
static final String FEATURE_STORE_NAME_KEY = "store";
private final Map<String,FeatureStore> stores = new HashMap<>();
/**
* Managed feature store: the name of the attribute containing all the feature
* stores
**/
private static final String FEATURE_STORE_JSON_FIELD = "featureStores";
/**
* Managed feature store: the name of the attribute containing all the
* features of a feature store
**/
private static final String FEATURES_JSON_FIELD = "features";
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
public ManagedFeatureStore(String resourceId, SolrResourceLoader loader,
ManagedResourceStorage.StorageIO storageIO) throws SolrException {
super(resourceId, loader, storageIO);
}
public synchronized FeatureStore getFeatureStore(String name) {
if (name == null) {
name = FeatureStore.DEFAULT_FEATURE_STORE_NAME;
}
if (!stores.containsKey(name)) {
stores.put(name, new FeatureStore(name));
}
return stores.get(name);
}
@Override
protected void onManagedDataLoadedFromStorage(NamedList<?> managedInitArgs,
Object managedData) throws SolrException {
stores.clear();
log.info("------ managed feature ~ loading ------");
if (managedData instanceof List) {
@SuppressWarnings("unchecked")
final List<Map<String,Object>> up = (List<Map<String,Object>>) managedData;
for (final Map<String,Object> u : up) {
final String featureStore = (String) u.get(FEATURE_STORE_NAME_KEY);
addFeature(u, featureStore);
}
}
}
public synchronized void addFeature(Map<String,Object> map, String featureStore) {
log.info("register feature based on {}", map);
final FeatureStore fstore = getFeatureStore(featureStore);
final Feature feature = fromFeatureMap(solrResourceLoader, map);
fstore.add(feature);
}
@SuppressWarnings("unchecked")
@Override
public Object applyUpdatesToManagedData(Object updates) {
if (updates instanceof List) {
final List<Map<String,Object>> up = (List<Map<String,Object>>) updates;
for (final Map<String,Object> u : up) {
final String featureStore = (String) u.get(FEATURE_STORE_NAME_KEY);
addFeature(u, featureStore);
}
}
if (updates instanceof Map) {
// a unique feature
Map<String,Object> updatesMap = (Map<String,Object>) updates;
final String featureStore = (String) updatesMap.get(FEATURE_STORE_NAME_KEY);
addFeature(updatesMap, featureStore);
}
final List<Object> features = new ArrayList<>();
for (final FeatureStore fs : stores.values()) {
features.addAll(featuresAsManagedResources(fs));
}
return features;
}
@Override
public synchronized void doDeleteChild(BaseSolrResource endpoint, String childId) {
if (childId.equals("*")) {
stores.clear();
}
if (stores.containsKey(childId)) {
stores.remove(childId);
}
storeManagedData(applyUpdatesToManagedData(null));
}
/**
* Called to retrieve a named part (the given childId) of the resource at the
* given endpoint. If no childId is given, the names of all the feature stores
* are returned; otherwise the features of the named store are returned.
*/
@Override
public void doGet(BaseSolrResource endpoint, String childId) {
final SolrQueryResponse response = endpoint.getSolrResponse();
// If no feature store specified, show all the feature stores available
if (childId == null) {
response.add(FEATURE_STORE_JSON_FIELD, stores.keySet());
} else {
final FeatureStore store = getFeatureStore(childId);
if (store == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"missing feature store [" + childId + "]");
}
response.add(FEATURES_JSON_FIELD,
featuresAsManagedResources(store));
}
}
private static List<Object> featuresAsManagedResources(FeatureStore store) {
final List<Feature> storedFeatures = store.getFeatures();
final List<Object> features = new ArrayList<Object>(storedFeatures.size());
for (final Feature f : storedFeatures) {
final LinkedHashMap<String,Object> m = toFeatureMap(f);
m.put(FEATURE_STORE_NAME_KEY, store.getName());
features.add(m);
}
return features;
}
private static LinkedHashMap<String,Object> toFeatureMap(Feature feat) {
final LinkedHashMap<String,Object> o = new LinkedHashMap<>(4, 1.0f); // 1 extra for caller to add store
o.put(NAME_KEY, feat.getName());
o.put(CLASS_KEY, feat.getClass().getCanonicalName());
o.put(PARAMS_KEY, feat.paramsToMap());
return o;
}
private static Feature fromFeatureMap(SolrResourceLoader solrResourceLoader,
Map<String,Object> featureMap) {
final String className = (String) featureMap.get(CLASS_KEY);
final String name = (String) featureMap.get(NAME_KEY);
@SuppressWarnings("unchecked")
final Map<String,Object> params = (Map<String,Object>) featureMap.get(PARAMS_KEY);
return Feature.getInstance(solrResourceLoader, className, name, params);
}
}
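The store above round-trips each feature between a `Feature` object and a plain map (see `toFeatureMap` and `fromFeatureMap`), with the store name added separately by `featuresAsManagedResources`. As a rough illustration of that JSON shape, here is a hedged Python sketch; the key names mirror the `NAME_KEY`/`CLASS_KEY`/`PARAMS_KEY`/`FEATURE_STORE_NAME_KEY` constants, while the feature class and the `_DEFAULT_` store name are illustrative assumptions, not taken from this file.

```python
import json

def to_feature_map(name, class_name, params):
    # insertion order matches the LinkedHashMap used in toFeatureMap() above
    return {"name": name, "class": class_name, "params": params}

feature = to_feature_map(
    "originalScore",
    "org.apache.solr.ltr.feature.OriginalScoreFeature",  # illustrative class
    {},
)
# the store key is added by the caller, as in featuresAsManagedResources();
# the store name here is an assumption
feature["store"] = "_DEFAULT_"

print(json.dumps(feature))
```

A list of such maps is what `applyUpdatesToManagedData` returns for persistence to disk or ZooKeeper.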


@@ -0,0 +1,319 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.store.rest;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.ModelStore;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.rest.BaseSolrResource;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.rest.ManagedResourceStorage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Managed resource for storing a model
*/
public class ManagedModelStore extends ManagedResource implements ManagedResource.ChildResourceSupport {
public static void registerManagedModelStore(SolrResourceLoader solrResourceLoader,
ManagedResourceObserver managedResourceObserver) {
solrResourceLoader.getManagedResourceRegistry().registerManagedResource(
REST_END_POINT,
ManagedModelStore.class,
managedResourceObserver);
}
public static ManagedModelStore getManagedModelStore(SolrCore core) {
return (ManagedModelStore) core.getRestManager()
.getManagedResource(REST_END_POINT);
}
/** the model store rest endpoint **/
public static final String REST_END_POINT = "/schema/model-store";
// TODO: reduce from public to package visibility (once tests no longer need public access)
/**
* Managed model store: the name of the attribute containing all the models of
* a model store
**/
private static final String MODELS_JSON_FIELD = "models";
/** name of the attribute containing a class **/
static final String CLASS_KEY = "class";
/** name of the attribute containing the features **/
static final String FEATURES_KEY = "features";
/** name of the attribute containing a name **/
static final String NAME_KEY = "name";
/** name of the attribute containing a normalizer **/
static final String NORM_KEY = "norm";
/** name of the attribute containing parameters **/
static final String PARAMS_KEY = "params";
/** name of the attribute containing a store **/
static final String STORE_KEY = "store";
private final ModelStore store;
private ManagedFeatureStore managedFeatureStore;
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
public ManagedModelStore(String resourceId, SolrResourceLoader loader,
ManagedResourceStorage.StorageIO storageIO) throws SolrException {
super(resourceId, loader, storageIO);
store = new ModelStore();
}
public void setManagedFeatureStore(ManagedFeatureStore managedFeatureStore) {
log.info("INIT model store");
this.managedFeatureStore = managedFeatureStore;
}
public ManagedFeatureStore getManagedFeatureStore() {
return managedFeatureStore;
}
private Object managedData;
@SuppressWarnings("unchecked")
@Override
protected void onManagedDataLoadedFromStorage(NamedList<?> managedInitArgs,
Object managedData) throws SolrException {
store.clear();
// the managed models on disk or on ZooKeeper will be loaded lazily, since we
// need to set the managed features first (unfortunately managed resources do
// not decouple the creation of a managed resource from the reading of the
// data from the storage)
this.managedData = managedData;
}
public void loadStoredModels() {
log.info("------ managed models ~ loading ------");
if ((managedData != null) && (managedData instanceof List)) {
final List<Map<String,Object>> up = (List<Map<String,Object>>) managedData;
for (final Map<String,Object> u : up) {
try {
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, u, managedFeatureStore);
addModel(algo);
} catch (final ModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
}
}
}
public synchronized void addModel(LTRScoringModel ltrScoringModel) throws ModelException {
try {
log.info("adding model {}", ltrScoringModel.getName());
store.addModel(ltrScoringModel);
} catch (final ModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
}
@SuppressWarnings("unchecked")
@Override
protected Object applyUpdatesToManagedData(Object updates) {
if (updates instanceof List) {
final List<Map<String,Object>> up = (List<Map<String,Object>>) updates;
for (final Map<String,Object> u : up) {
try {
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, u, managedFeatureStore);
addModel(algo);
} catch (final ModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
}
}
if (updates instanceof Map) {
final Map<String,Object> map = (Map<String,Object>) updates;
try {
final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, map, managedFeatureStore);
addModel(algo);
} catch (final ModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
}
return modelsAsManagedResources(store.getModels());
}
@Override
public synchronized void doDeleteChild(BaseSolrResource endpoint, String childId) {
if (childId.equals("*")) {
store.clear();
} else {
store.delete(childId);
}
storeManagedData(applyUpdatesToManagedData(null));
}
/**
* Called to retrieve a named part (the given childId) of the resource at the
* given endpoint. Note: since we have a unique child managed store we ignore
* the childId.
*/
@Override
public void doGet(BaseSolrResource endpoint, String childId) {
final SolrQueryResponse response = endpoint.getSolrResponse();
response.add(MODELS_JSON_FIELD,
modelsAsManagedResources(store.getModels()));
}
public LTRScoringModel getModel(String modelName) {
// this function replicates getModelStore().getModel(modelName), but
// it simplifies the testing (we can avoid also having to mock a ModelStore).
return store.getModel(modelName);
}
@Override
public String toString() {
return "ManagedModelStore [store=" + store + ", featureStores="
+ managedFeatureStore + "]";
}
/**
* Returns the available models as a list of map objects. After an update the
* managed resource needs to return the resources in this format so that they
* can be stored as JSON somewhere (ZooKeeper, disk, ...).
*
* @return the available models as a list of map objects
*/
private static List<Object> modelsAsManagedResources(List<LTRScoringModel> models) {
final List<Object> list = new ArrayList<>(models.size());
for (final LTRScoringModel model : models) {
list.add(toLTRScoringModelMap(model));
}
return list;
}
@SuppressWarnings("unchecked")
public static LTRScoringModel fromLTRScoringModelMap(SolrResourceLoader solrResourceLoader,
Map<String,Object> modelMap, ManagedFeatureStore managedFeatureStore) {
final FeatureStore featureStore =
managedFeatureStore.getFeatureStore((String) modelMap.get(STORE_KEY));
final List<Feature> features = new ArrayList<>();
final List<Normalizer> norms = new ArrayList<>();
final List<Object> featureList = (List<Object>) modelMap.get(FEATURES_KEY);
if (featureList != null) {
for (final Object feature : featureList) {
final Map<String,Object> featureMap = (Map<String,Object>) feature;
features.add(lookupFeatureFromFeatureMap(featureMap, featureStore));
norms.add(createNormalizerFromFeatureMap(solrResourceLoader, featureMap));
}
}
return LTRScoringModel.getInstance(solrResourceLoader,
(String) modelMap.get(CLASS_KEY), // modelClassName
(String) modelMap.get(NAME_KEY), // modelName
features,
norms,
featureStore.getName(),
featureStore.getFeatures(),
(Map<String,Object>) modelMap.get(PARAMS_KEY));
}
private static LinkedHashMap<String,Object> toLTRScoringModelMap(LTRScoringModel model) {
final LinkedHashMap<String,Object> modelMap = new LinkedHashMap<>(5, 1.0f);
modelMap.put(NAME_KEY, model.getName());
modelMap.put(CLASS_KEY, model.getClass().getCanonicalName());
modelMap.put(STORE_KEY, model.getFeatureStoreName());
final List<Map<String,Object>> features = new ArrayList<>();
final List<Feature> featuresList = model.getFeatures();
final List<Normalizer> normsList = model.getNorms();
for (int ii=0; ii<featuresList.size(); ++ii) {
features.add(toFeatureMap(featuresList.get(ii), normsList.get(ii)));
}
modelMap.put(FEATURES_KEY, features);
modelMap.put(PARAMS_KEY, model.getParams());
return modelMap;
}
private static Feature lookupFeatureFromFeatureMap(Map<String,Object> featureMap,
FeatureStore featureStore) {
final String featureName = (String)featureMap.get(NAME_KEY);
return (featureName == null ? null
: featureStore.get(featureName));
}
@SuppressWarnings("unchecked")
private static Normalizer createNormalizerFromFeatureMap(SolrResourceLoader solrResourceLoader,
Map<String,Object> featureMap) {
final Map<String,Object> normMap = (Map<String,Object>)featureMap.get(NORM_KEY);
return (normMap == null ? IdentityNormalizer.INSTANCE
: fromNormalizerMap(solrResourceLoader, normMap));
}
private static LinkedHashMap<String,Object> toFeatureMap(Feature feature, Normalizer norm) {
final LinkedHashMap<String,Object> map = new LinkedHashMap<String,Object>(2, 1.0f);
map.put(NAME_KEY, feature.getName());
map.put(NORM_KEY, toNormalizerMap(norm));
return map;
}
private static Normalizer fromNormalizerMap(SolrResourceLoader solrResourceLoader,
Map<String,Object> normMap) {
final String className = (String) normMap.get(CLASS_KEY);
@SuppressWarnings("unchecked")
final Map<String,Object> params = (Map<String,Object>) normMap.get(PARAMS_KEY);
return Normalizer.getInstance(solrResourceLoader, className, params);
}
private static LinkedHashMap<String,Object> toNormalizerMap(Normalizer norm) {
final LinkedHashMap<String,Object> normalizer = new LinkedHashMap<>(2, 1.0f);
normalizer.put(CLASS_KEY, norm.getClass().getCanonicalName());
final LinkedHashMap<String,Object> params = norm.paramsToMap();
if (params != null) {
normalizer.put(PARAMS_KEY, params);
}
return normalizer;
}
}
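The model store serializes each model via `toLTRScoringModelMap`, nesting per-feature maps built by `toFeatureMap(feature, norm)` and `toNormalizerMap(norm)`. The following hedged Python sketch shows the resulting shape; the key names mirror the `NAME_KEY`/`CLASS_KEY`/`STORE_KEY`/`FEATURES_KEY`/`PARAMS_KEY` constants above, but the model class, store name, feature name, and params content are illustrative assumptions.

```python
# Sketch of the map produced by toLTRScoringModelMap() above; values are
# examples only. Note that toNormalizerMap() omits the "params" key when the
# normalizer has no params, as with an identity normalizer.
model_map = {
    "name": "myModel",
    "class": "org.apache.solr.ltr.model.LinearModel",  # illustrative class
    "store": "myFeatureStore",
    "features": [
        # each entry mirrors toFeatureMap(feature, norm): feature name plus
        # a normalizer map (class, and params when present)
        {"name": "originalScore",
         "norm": {"class": "org.apache.solr.ltr.norm.IdentityNormalizer"}},
    ],
    "params": {"weights": {"originalScore": 1.0}},
}
```

`applyUpdatesToManagedData` accepts either a single such map or a list of them, and returns the whole store in this list-of-maps form for storage.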


@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Contains the {@link org.apache.solr.rest.ManagedResource} implementations
* that encapsulate the feature and the model stores.
*/
package org.apache.solr.ltr.store.rest;


@@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.response.transform;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.ltr.FeatureLogger;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.ResultContext;
import org.apache.solr.search.LTRQParserPlugin;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.SolrPluginUtils;
/**
* This transformer generates and appends to the response the features
* declared in the feature store of the current model. The class is useful if
* you are not interested in the reranking itself (e.g., when bootstrapping a
* machine learning framework).
*/
public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
// used inside fl to specify the output format (csv/json) of the extracted features
private static final String FV_RESPONSE_WRITER = "fvwt";
// used inside fl to specify the format (dense|sparse) of the extracted features
private static final String FV_FORMAT = "format";
// used inside fl to specify the feature store to use for the feature extraction
private static final String FV_STORE = "store";
private static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model";
private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME;
private String defaultFvStore;
private String defaultFvwt;
private String defaultFvFormat;
private LTRThreadModule threadManager = null;
public void setLoggingModelName(String loggingModelName) {
this.loggingModelName = loggingModelName;
}
public void setStore(String defaultFvStore) {
this.defaultFvStore = defaultFvStore;
}
public void setFvwt(String defaultFvwt) {
this.defaultFvwt = defaultFvwt;
}
public void setFormat(String defaultFvFormat) {
this.defaultFvFormat = defaultFvFormat;
}
@Override
public void init(@SuppressWarnings("rawtypes") NamedList args) {
super.init(args);
threadManager = LTRThreadModule.getInstance(args);
SolrPluginUtils.invokeSetters(this, args);
}
@Override
public DocTransformer create(String name, SolrParams params,
SolrQueryRequest req) {
// Hint to enable feature vector cache since we are requesting features
SolrQueryRequestContextUtils.setIsExtractingFeatures(req);
// Communicate which feature store we are requesting features for
SolrQueryRequestContextUtils.setFvStoreName(req, params.get(FV_STORE, defaultFvStore));
// Create and supply the feature logger to be used
SolrQueryRequestContextUtils.setFeatureLogger(req,
FeatureLogger.createFeatureLogger(
params.get(FV_RESPONSE_WRITER, defaultFvwt),
params.get(FV_FORMAT, defaultFvFormat)));
return new FeatureTransformer(name, params, req);
}
class FeatureTransformer extends DocTransformer {
final private String name;
final private SolrParams params;
final private SolrQueryRequest req;
private List<LeafReaderContext> leafContexts;
private SolrIndexSearcher searcher;
private LTRScoringQuery scoringQuery;
private LTRScoringQuery.ModelWeight modelWeight;
private FeatureLogger<?> featureLogger;
private boolean docsWereNotReranked;
/**
* @param name
* Name of the field to be added in a document representing the
* feature vectors
*/
public FeatureTransformer(String name, SolrParams params,
SolrQueryRequest req) {
this.name = name;
this.params = params;
this.req = req;
}
@Override
public String getName() {
return name;
}
@Override
public void setContext(ResultContext context) {
super.setContext(context);
if (context == null) {
return;
}
if (context.getRequest() == null) {
return;
}
searcher = context.getSearcher();
if (searcher == null) {
throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"searcher is null");
}
leafContexts = searcher.getTopReaderContext().leaves();
// Setup LTRScoringQuery
scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req);
docsWereNotReranked = (scoringQuery == null);
String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) {
// if store is set in the transformer we should overwrite the logger
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
final FeatureStore store = fr.getFeatureStore(featureStoreName);
featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name
try {
final LoggingModel lm = new LoggingModel(loggingModelName,
featureStoreName, store.getFeatures());
scoringQuery = new LTRScoringQuery(lm,
LTRQParserPlugin.extractEFIParams(params),
true,
threadManager); // request feature weights to be created for all features
// Local transformer efi if provided
scoringQuery.setOriginalQuery(context.getQuery());
} catch (final Exception e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"retrieving the feature store "+featureStoreName, e);
}
}
if (scoringQuery.getFeatureLogger() == null){
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
}
scoringQuery.setRequest(req);
featureLogger = scoringQuery.getFeatureLogger();
try {
modelWeight = scoringQuery.createWeight(searcher, true, 1f);
} catch (final IOException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
}
if (modelWeight == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"error logging the features, model weight is null");
}
}
@Override
public void transform(SolrDocument doc, int docid, float score)
throws IOException {
Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher);
if (fv == null) { // FV for this document was not in the cache
fv = featureLogger.makeFeatureVector(
LTRRescorer.extractFeaturesInfo(
modelWeight,
docid,
(docsWereNotReranked ? Float.valueOf(score) : null),
leafContexts));
}
doc.addField(name, fv);
}
}
private static class LoggingModel extends LTRScoringModel {
public LoggingModel(String name, String featureStoreName, List<Feature> allFeatures){
this(name, Collections.emptyList(), Collections.emptyList(),
featureStoreName, allFeatures, Collections.emptyMap());
}
protected LoggingModel(String name, List<Feature> features,
List<Normalizer> norms, String featureStoreName,
List<Feature> allFeatures, Map<String,Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
}
@Override
public float score(float[] modelFeatureValuesNormalized) {
return 0;
}
@Override
public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
return Explanation.match(finalScore, toString()
+ " logging model, used only for logging the features");
}
}
}
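The transformer above is driven by `fl` local params: `store` (`FV_STORE`) picks the feature store, `format` (`FV_FORMAT`) selects dense or sparse output, and `fvwt` (`FV_RESPONSE_WRITER`) selects csv or json, e.g. `fl=[features store=myStore format=dense]`, assuming the transformer is registered in solrconfig.xml under the name `features`. The actual string assembly lives in `FeatureLogger` (not shown in this file), so the following Python sketch of a csv-style feature vector is a hedged assumption about the separators, meant only to illustrate the dense/sparse distinction.

```python
# Hedged sketch: a csv-style feature vector built from logged (name, value)
# pairs. Sparse output drops zero-valued features; dense keeps all of them.
# The "name=value" and comma separators are assumptions about the format.
def make_feature_vector(values, fmt="sparse"):
    if fmt == "sparse":
        values = {name: value for name, value in values.items() if value != 0.0}
    return ",".join(f"{name}={value}" for name, value in values.items())

sparse = make_feature_vector({"originalScore": 9.5, "isBook": 0.0})
dense = make_feature_vector({"originalScore": 9.5, "isBook": 0.0}, fmt="dense")
```

This mirrors why `transform()` first checks the feature vector cache and only recomputes via `LTRRescorer.extractFeaturesInfo` on a miss.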


@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* APIs and implementations of {@link org.apache.solr.response.transform.DocTransformer} for modifying documents in Solr request responses
*/
package org.apache.solr.response.transform;


@@ -0,0 +1,233 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.lucene.analysis.util.ResourceLoader;
import org.apache.lucene.analysis.util.ResourceLoaderAware;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.ManagedModelStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.util.SolrPluginUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Plugs a rerank model into Solr.
*
* Learning to Rank Query Parser Syntax: rq={!ltr model=6029760550880411648 reRankDocs=300
* efi.myCompanyQueryIntent=0.98}
*
*/
public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver {
public static final String NAME = "ltr";
private static Query defaultQuery = new MatchAllDocsQuery();
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
// params for setting custom external info that features can use, like query
// intent
static final String EXTERNAL_FEATURE_INFO = "efi.";
private ManagedFeatureStore fr = null;
private ManagedModelStore mr = null;
private LTRThreadModule threadManager = null;
/** query parser plugin: the name of the attribute for setting the model **/
public static final String MODEL = "model";
/** query parser plugin: default number of documents to rerank **/
public static final int DEFAULT_RERANK_DOCS = 200;
/**
* query parser plugin: the param that selects the number of documents
* to rerank
**/
public static final String RERANK_DOCS = "reRankDocs";
@Override
public void init(@SuppressWarnings("rawtypes") NamedList args) {
super.init(args);
threadManager = LTRThreadModule.getInstance(args);
SolrPluginUtils.invokeSetters(this, args);
}
@Override
public QParser createParser(String qstr, SolrParams localParams,
SolrParams params, SolrQueryRequest req) {
return new LTRQParser(qstr, localParams, params, req);
}
/**
* Given a set of local SolrParams, extract all of the efi.key=value params into a map
* @param localParams Local request parameters that might contain efi params
* @return Map of efi params, where the key is the name of the efi param, and the
* value is the value of the efi param
*/
public static Map<String,String[]> extractEFIParams(SolrParams localParams) {
final Map<String,String[]> externalFeatureInfo = new HashMap<>();
for (final Iterator<String> it = localParams.getParameterNamesIterator(); it
.hasNext();) {
final String name = it.next();
if (name.startsWith(EXTERNAL_FEATURE_INFO)) {
externalFeatureInfo.put(
name.substring(EXTERNAL_FEATURE_INFO.length()),
new String[] {localParams.get(name)});
}
}
return externalFeatureInfo;
}
@Override
public void inform(ResourceLoader loader) throws IOException {
final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader;
ManagedFeatureStore.registerManagedFeatureStore(solrResourceLoader, this);
ManagedModelStore.registerManagedModelStore(solrResourceLoader, this);
}
@Override
public void onManagedResourceInitialized(NamedList<?> args, ManagedResource res) throws SolrException {
if (res instanceof ManagedFeatureStore) {
fr = (ManagedFeatureStore)res;
}
if (res instanceof ManagedModelStore){
mr = (ManagedModelStore)res;
}
if (mr != null && fr != null){
mr.setManagedFeatureStore(fr);
// now we can safely load the models
mr.loadStoredModels();
}
}
public class LTRQParser extends QParser {
public LTRQParser(String qstr, SolrParams localParams, SolrParams params,
SolrQueryRequest req) {
super(qstr, localParams, params, req);
}
@Override
public Query parse() throws SyntaxError {
// ReRanking Model
final String modelName = localParams.get(LTRQParserPlugin.MODEL);
if ((modelName == null) || modelName.isEmpty()) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"Must provide model in the request");
}
final LTRScoringModel ltrScoringModel = mr.getModel(modelName);
if (ltrScoringModel == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"cannot find " + LTRQParserPlugin.MODEL + " " + modelName);
}
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
final String fvStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
// Check if features are requested and if the model feature store and feature-transform feature store are the same
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(fvStoreName) || fvStoreName == null) ? extractFeatures : false;
final LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel,
extractEFIParams(localParams),
featuresRequestedFromSameStore, threadManager);
// Enable the feature vector caching if we are extracting features, and the features
// we requested are the same ones we are reranking with
if (featuresRequestedFromSameStore) {
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
}
SolrQueryRequestContextUtils.setScoringQuery(req, scoringQuery);
int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS);
reRankDocs = Math.max(1, reRankDocs);
// External features
scoringQuery.setRequest(req);
return new LTRQuery(scoringQuery, reRankDocs);
}
}
/**
 * A learning to rank query that encapsulates a learning to rank model and delegates
 * the rescoring of the documents to it.
 **/
public class LTRQuery extends AbstractReRankQuery {
private final LTRScoringQuery scoringQuery;
public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
super(defaultQuery, reRankDocs, new LTRRescorer(scoringQuery));
this.scoringQuery = scoringQuery;
}
@Override
public int hashCode() {
return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs);
}
@Override
public boolean equals(Object o) {
return sameClassAs(o) && equalsTo(getClass().cast(o));
}
private boolean equalsTo(LTRQuery other) {
return (mainQuery.equals(other.mainQuery)
&& scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs));
}
@Override
public RankQuery wrap(Query _mainQuery) {
super.wrap(_mainQuery);
scoringQuery.setOriginalQuery(_mainQuery);
return this;
}
@Override
public String toString(String field) {
return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
+ scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}";
}
@Override
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
}
}
}
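The `parse()` method above reads the model name and `reRankDocs` from the local parameters of an `{!ltr ...}` query supplied through Solr's `rq` rerank parameter. As a rough sketch (not part of the Solr codebase; the model name, query, and `efi` values below are made up), the request parameters the parser consumes can be assembled like this:

```python
# Sketch only: builds the "rq" rerank parameter that LTRQParserPlugin
# parses above. Model name, efi keys, and values are hypothetical.
from urllib.parse import urlencode

def ltr_rerank_params(model, rerank_docs=100, efi=None):
    """Build Solr request parameters for an LTR rerank query."""
    rq = "{{!ltr model={} reRankDocs={}".format(model, rerank_docs)
    for name, value in (efi or {}).items():
        rq += " efi.{}={}".format(name, value)  # external feature info
    rq += "}"
    return {"q": "test", "rq": rq, "fl": "id,score"}

params = ltr_rerank_params("myModel", 25, {"user_text": "hello"})
print(urlencode(params))
```

Note how this mirrors the parser: an absent model name triggers the `BAD_REQUEST` error raised in `parse()`, and `reRankDocs` is clamped to at least 1.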


@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* APIs and classes for {@linkplain org.apache.solr.search.QParserPlugin parsing} and {@linkplain org.apache.solr.search.SolrIndexSearcher processing} search requests
*/
package org.apache.solr.search;


@ -0,0 +1,91 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<html>
<body>
Apache Solr Search Server: Learning to Rank Contrib
<p>
This module contains the logic <strong>to plug machine-learned ranking models into Solr</strong>.
</p>
<p>
In information retrieval systems, Learning to Rank is used to re-rank the top X
retrieved documents using trained machine learning models. The hope is
that sophisticated models can make more nuanced ranking decisions than standard ranking
functions like TF-IDF or BM25.
</p>
<p>
This module allows a reranking component to be plugged directly into Solr, enabling users
to easily build their own learning to rank systems and access the rich
matching features readily available in Solr. It also provides tools to perform
feature engineering and feature extraction.
</p>
<h2> Code structure </h2>
<p>
A Learning to Rank model is plugged into the ranking through the {@link org.apache.solr.search.LTRQParserPlugin},
a {@link org.apache.solr.search.QParserPlugin}. The plugin will
read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel})
used to perform the request plus other
parameters. The plugin will generate a {@link org.apache.solr.search.LTRQParserPlugin.LTRQuery LTRQuery}:
a particular {@link org.apache.solr.search.RankQuery}
that will encapsulate the given model and use it to
rescore and rerank the documents (using an {@link org.apache.solr.ltr.LTRRescorer}).
</p>
<p>
A model is applied to each document through a {@link org.apache.solr.ltr.LTRScoringQuery}, a
subclass of {@link org.apache.lucene.search.Query}. Like any other query,
it will produce a new score
for each reranked document.
</p>
<p>
A {@link org.apache.solr.ltr.LTRScoringQuery} is created by providing an instance of
{@link org.apache.solr.ltr.model.LTRScoringModel}. An instance of
{@link org.apache.solr.ltr.model.LTRScoringModel}
defines how to combine the features in order to create a new
score for a document. A new learning to rank model is plugged
into the framework by extending {@link org.apache.solr.ltr.model.LTRScoringModel}
(see for example {@link org.apache.solr.ltr.model.MultipleAdditiveTreesModel} and {@link org.apache.solr.ltr.model.LinearModel}).
</p>
<p>
The {@link org.apache.solr.ltr.LTRScoringQuery} will take care of computing the values of
all the features (see {@link org.apache.solr.ltr.feature.Feature}) and then will delegate the final score
generation to the {@link org.apache.solr.ltr.model.LTRScoringModel}, by calling the method
{@link org.apache.solr.ltr.model.LTRScoringModel#score(float[] modelFeatureValuesNormalized)}.
</p>
<p>
A {@link org.apache.solr.ltr.feature.Feature} will produce a particular value for each document, so
it is modeled as a {@link org.apache.lucene.search.Query}. The package
{@link org.apache.solr.ltr.feature} contains several examples
of features. One benefit of extending the Query object is that we can reuse
Query as a feature, see for example {@link org.apache.solr.ltr.feature.SolrFeature}.
Features for a document can also be returned in the response by
using the FeatureTransformer (a {@link org.apache.solr.response.transform.DocTransformer DocTransformer})
provided by {@link org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory}.
</p>
<p>
{@link org.apache.solr.ltr.store} contains all the logic to store all the features and the models.
Models are registered into a unique {@link org.apache.solr.ltr.store.ModelStore ModelStore},
and each model specifies a particular {@link org.apache.solr.ltr.store.FeatureStore FeatureStore} that
will contain a particular subset of features.
</p>
<p>
Features and models can be managed through a REST API, provided by the
{@link org.apache.solr.rest.ManagedResource Managed Resources}
{@link org.apache.solr.ltr.store.rest.ManagedFeatureStore ManagedFeatureStore}
and {@link org.apache.solr.ltr.store.rest.ManagedModelStore ManagedModelStore}.
</p>
</body>
</html>


@ -0,0 +1,37 @@
[
{ "name":"origScore",
"class":"org.apache.solr.ltr.feature.OriginalScoreFeature",
"params":{},
"store": "feature-store-6"
},
{
"name": "descriptionTermFreq",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": { "q" : "{!func}termfreq(description,${user_text})" },
"store": "feature-store-6"
},
{
"name": "popularity",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": { "q" : "{!func}normHits"},
"store": "feature-store-6"
},
{
"name": "isPopular",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {"fq" : ["{!field f=popularity}201"] },
"store": "feature-store-6"
},
{
"name": "queryPartialMatch2",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {"q": "{!dismax qf=description mm=2}${user_text}" },
"store": "feature-store-6"
},
{
"name": "queryPartialMatch2.1",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {"q": "{!dismax qf=description mm=2}${user_text}" },
"store": "feature-store-6"
}
]
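Several of the feature definitions above embed request parameters such as `${user_text}` in their queries. Solr resolves these from `efi.*` parameters with its own Java-side macro expansion; the following stand-in (an assumption of this sketch, using Python's `string.Template` rather than Solr's implementation) shows the intended effect:

```python
# Stand-in for Solr's ${...} expansion of efi.* request parameters in
# feature queries; the real substitution happens inside Solr.
from string import Template

def expand_feature_query(query, efi):
    # ${user_text} is replaced by the value of the efi.user_text parameter
    return Template(query).substitute(efi)

expanded = expand_feature_query(
    "{!func}termfreq(description,${user_text})", {"user_text": "hello"})
print(expanded)
```

With `efi.user_text=hello`, the `descriptionTermFreq` feature above would effectively run the query `{!func}termfreq(description,hello)`.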


@ -0,0 +1,51 @@
[ {
"name" : "matchedTitle",
"class" : "org.apache.solr.ltr.feature.SolrFeature",
"params" : {
"q" : "{!terms f=title}${user_query}"
}
}, {
"name" : "confidence",
"class" : "org.apache.solr.ltr.feature.ValueFeature",
"store": "fstore2",
"params" : {
"value" : "${myconf}"
}
}, {
"name":"originalScore",
"class":"org.apache.solr.ltr.feature.OriginalScoreFeature",
"store": "fstore2",
"params":{}
}, {
"name" : "occurrences",
"class" : "org.apache.solr.ltr.feature.ValueFeature",
"store": "fstore3",
"params" : {
"value" : "${myOcc}",
"required" : false
}
}, {
"name":"originalScore",
"class":"org.apache.solr.ltr.feature.OriginalScoreFeature",
"store": "fstore3",
"params":{}
}, {
"name" : "popularity",
"class" : "org.apache.solr.ltr.feature.ValueFeature",
"store": "fstore4",
"params" : {
"value" : "${myPop}",
"required" : true
}
}, {
"name":"originalScore",
"class":"org.apache.solr.ltr.feature.OriginalScoreFeature",
"store": "fstore4",
"params":{}
}, {
"name" : "titlePhraseMatch",
"class" : "org.apache.solr.ltr.feature.SolrFeature",
"params" : {
"q" : "{!field f=title}${user_query}"
}
} ]


@ -0,0 +1,18 @@
[{
"name" : "user_device_smartphone",
"class":"org.apache.solr.ltr.feature.ValueFeature",
"params" : {
"value": "${user_device_smartphone}"
}
},
{
"name" : "user_device_tablet",
"class":"org.apache.solr.ltr.feature.ValueFeature",
"params" : {
"value": "${user_device_tablet}"
}
}
]


@ -0,0 +1,17 @@
[
{
"name": "sampleConstant",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 5
}
},
{
"name" : "search_number_of_nights",
"class":"org.apache.solr.ltr.feature.ValueFeature",
"params" : {
"value": "${search_number_of_nights}"
}
}
]


@ -0,0 +1,51 @@
[
{
"name": "title",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 1
}
},
{
"name": "description",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 2
}
},
{
"name": "keywords",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 2
}
},
{
"name": "popularity",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 3
}
},
{
"name": "text",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 4
}
},
{
"name": "queryIntentPerson",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 5
}
},
{
"name": "queryIntentCompany",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 5
}
}
]


@ -0,0 +1,51 @@
[
{
"name": "constant1",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"store":"test",
"params": {
"value": 1
}
},
{
"name": "constant2",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"store":"test",
"params": {
"value": 2
}
},
{
"name": "constant3",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"store":"test",
"params": {
"value": 3
}
},
{
"name": "constant4",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"store":"test",
"params": {
"value": 4
}
},
{
"name": "constant5",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"store":"test",
"params": {
"value": 5
}
},
{
"name": "pop",
"class": "org.apache.solr.ltr.feature.FieldValueFeature",
"store":"test",
"params": {
"field": "popularity"
}
}
]


@ -0,0 +1,16 @@
[
{
"name": "matchedTitle",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {
"q": "{!terms f=title}${user_query}"
}
},
{
"name": "popularity",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 3
}
}
]


@ -0,0 +1,16 @@
[
{
"name": "matchedTitle",
"class": "org.apache.solr.ltr.feature.SolrFeature",
"params": {
"q": "{!terms f=title}${user_query}"
}
},
{
"name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs",
"class": "org.apache.solr.ltr.feature.ValueFeature",
"params": {
"value": 1
}
}
]


@ -0,0 +1,32 @@
# Logging level
log4j.rootLogger=INFO, CONSOLE
log4j.appender.CONSOLE=org.apache.log4j.ConsoleAppender
log4j.appender.CONSOLE.Target=System.err
log4j.appender.CONSOLE.layout=org.apache.log4j.EnhancedPatternLayout
log4j.appender.CONSOLE.layout.ConversionPattern=%-4r %-5p (%t) [%X{node_name} %X{collection} %X{shard} %X{replica} %X{core}] %c{1.} %m%n
log4j.logger.org.apache.zookeeper=WARN
log4j.logger.org.apache.hadoop=WARN
log4j.logger.org.apache.directory=WARN
log4j.logger.org.apache.solr.hadoop=INFO
log4j.logger.org.apache.solr.client.solrj.embedded.JettySolrRunner=DEBUG
#log4j.logger.org.apache.solr.update.processor.LogUpdateProcessor=DEBUG
#log4j.logger.org.apache.solr.update.processor.DistributedUpdateProcessor=DEBUG
#log4j.logger.org.apache.solr.update.PeerSync=DEBUG
#log4j.logger.org.apache.solr.core.CoreContainer=DEBUG
#log4j.logger.org.apache.solr.cloud.RecoveryStrategy=DEBUG
#log4j.logger.org.apache.solr.cloud.SyncStrategy=DEBUG
#log4j.logger.org.apache.solr.handler.admin.CoreAdminHandler=DEBUG
#log4j.logger.org.apache.solr.cloud.ZkController=DEBUG
#log4j.logger.org.apache.solr.update.DefaultSolrCoreState=DEBUG
#log4j.logger.org.apache.solr.common.cloud.ConnectionManager=DEBUG
#log4j.logger.org.apache.solr.update.UpdateLog=DEBUG
#log4j.logger.org.apache.solr.cloud.ChaosMonkey=DEBUG
#log4j.logger.org.apache.solr.update.TransactionLog=DEBUG
#log4j.logger.org.apache.solr.handler.ReplicationHandler=DEBUG
#log4j.logger.org.apache.solr.handler.IndexFetcher=DEBUG
#log4j.logger.org.apache.solr.common.cloud.ClusterStateUtil=DEBUG
#log4j.logger.org.apache.solr.cloud.OverseerAutoReplicaFailoverThread=DEBUG


@ -0,0 +1,12 @@
{
"class":"org.apache.solr.ltr.model.LinearModel",
"name":"externalmodel",
"features":[
{ "name": "matchedTitle"}
],
"params":{
"weights": {
"matchedTitle": 0.999
}
}
}


@ -0,0 +1,13 @@
{
"class":"org.apache.solr.ltr.model.LinearModel",
"name":"externalmodelstore",
"store": "fstore2",
"features":[
{ "name": "confidence"}
],
"params":{
"weights": {
"confidence": 0.999
}
}
}


@ -0,0 +1,20 @@
{
"class":"org.apache.solr.ltr.model.LinearModel",
"name":"fqmodel",
"features":[
{
"name":"matchedTitle",
"norm": {
"class":"org.apache.solr.ltr.norm.MinMaxNormalizer",
"params":{ "min":"0.0f", "max":"10.0f" }
}
},
{ "name":"popularity"}
],
"params":{
"weights": {
"matchedTitle": 0.5,
"popularity": 0.5
}
}
}


@ -0,0 +1,14 @@
{
"class":"org.apache.solr.ltr.model.LinearModel",
"name":"linear-efi",
"features":[
{"name":"sampleConstant"},
{"name":"search_number_of_nights"}
],
"params":{
"weights":{
"sampleConstant":1.0,
"search_number_of_nights":2.0
}
}
}


@ -0,0 +1,30 @@
{
"class":"org.apache.solr.ltr.model.LinearModel",
"name":"6029760550880411648",
"features":[
{"name":"title"},
{"name":"description"},
{"name":"keywords"},
{
"name":"popularity",
"norm": {
"class":"org.apache.solr.ltr.norm.MinMaxNormalizer",
"params":{ "min":"0.0f", "max":"10.0f" }
}
},
{"name":"text"},
{"name":"queryIntentPerson"},
{"name":"queryIntentCompany"}
],
"params":{
"weights": {
"title": 0.0000000000,
"description": 0.1000000000,
"keywords": 0.2000000000,
"popularity": 0.3000000000,
"text": 0.4000000000,
"queryIntentPerson":0.1231231,
"queryIntentCompany":0.12121211
}
}
}
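The `LinearModel` above pairs each feature with a weight, and the `popularity` feature additionally with a `MinMaxNormalizer`. A minimal sketch of the resulting score, under the assumption that normalization is applied to each feature value before the weighted sum (the feature values below are invented for illustration):

```python
# Sketch of LinearModel scoring: normalize each feature value, then take
# the weighted sum. Feature values here are invented for illustration.

def min_max_norm(value, lo, hi):
    # mirrors MinMaxNormalizer's params {"min": lo, "max": hi}
    return (value - lo) / (hi - lo)

def linear_score(feature_values, weights, norms=None):
    norms = norms or {}
    total = 0.0
    for name, value in feature_values.items():
        norm = norms.get(name, lambda v: v)  # identity when no normalizer
        total += weights.get(name, 0.0) * norm(value)
    return total

score = linear_score(
    {"description": 1.0, "popularity": 5.0},
    {"description": 0.1, "popularity": 0.3},
    {"popularity": lambda v: min_max_norm(v, 0.0, 10.0)},
)
print(score)  # weighted, normalized sum
```

With `popularity` normalized from 5.0 to 0.5 by the min/max bounds 0.0 and 10.0, the score is 0.1 * 1.0 + 0.3 * 0.5.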


@ -0,0 +1,38 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"weight" : "1f",
"root": {
"feature": "matchedTitle",
"threshold": "0.5f",
"left" : {
"value" : "-100"
},
"right": {
"feature" : "this_feature_doesnt_exist",
"threshold": "10.0f",
"left" : {
"value" : "50"
},
"right" : {
"value" : "75"
}
}
}
},
{
"weight" : "2f",
"root": {
"value" : "-10"
}
}
]
}
}
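A rough sketch of how a `MultipleAdditiveTreesModel` shaped like the JSON above could be evaluated: walk each tree from its root, going left when the feature value is at most the threshold and right otherwise, and sum the weighted leaf values. The traversal rule and the 0.0 default for a missing feature are assumptions of this sketch, and Java float literals such as `"0.5f"` are written without the `f` suffix so plain `float()` parsing works.

```python
# Sketch of additive-trees scoring over a structure shaped like the JSON
# above (thresholds written as "0.5" instead of Java's "0.5f").

def eval_node(node, features):
    if "value" in node:  # leaf node
        return float(node["value"])
    feature_value = features.get(node["feature"], 0.0)  # assumed default
    branch = "left" if feature_value <= float(node["threshold"]) else "right"
    return eval_node(node[branch], features)

def trees_score(trees, features):
    # model score = sum over trees of weight * reached leaf value
    return sum(float(t["weight"]) * eval_node(t["root"], features)
               for t in trees)

trees = [
    {"weight": "1", "root": {
        "feature": "matchedTitle", "threshold": "0.5",
        "left": {"value": "-100"},
        "right": {"feature": "this_feature_doesnt_exist", "threshold": "10.0",
                  "left": {"value": "50"}, "right": {"value": "75"}}}},
    {"weight": "2", "root": {"value": "-10"}},
]
print(trees_score(trees, {"matchedTitle": 1.0}))  # -> 30.0
```

With `matchedTitle` at 1.0 the first tree goes right, then (the nonexistent feature defaulting to 0.0) left to the leaf 50; the second tree contributes 2 * -10, giving 50 - 20 = 30.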


@ -0,0 +1,38 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"external_model_binary_feature",
"features":[
{ "name": "user_device_smartphone"},
{ "name": "user_device_tablet"}
],
"params":{
"trees": [
{
"weight" : "1f",
"root": {
"feature": "user_device_smartphone",
"threshold": "0.5f",
"left" : {
"value" : "0"
},
"right" : {
"value" : "50"
}
}},
{
"weight" : "1f",
"root": {
"feature": "user_device_tablet",
"threshold": "0.5f",
"left" : {
"value" : "0"
},
"right" : {
"value" : "65"
}
}}
]
}
}


@ -0,0 +1,24 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_feature",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"weight" : "1f",
"root": {
"threshold": "0.5f",
"left" : {
"value" : "-100"
},
"right": {
"value" : "75"
}
}
}
]
}
}


@ -0,0 +1,14 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_features",
"params":{
"trees": [
{
"weight" : "2f",
"root": {
"value" : "-10"
}
}
]
}
}


@ -0,0 +1,22 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_left",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"weight" : "1f",
"root": {
"feature": "matchedTitle",
"threshold": "0.5f",
"right": {
"value" : "75"
}
}
}
]
}
}


@ -0,0 +1,8 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_params",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
]
}


@ -0,0 +1,22 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_right",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"weight" : "1f",
"root": {
"feature": "matchedTitle",
"threshold": "0.5f",
"left" : {
"value" : "-100"
}
}
}
]
}
}


@ -0,0 +1,24 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_threshold",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"weight" : "1f",
"root": {
"feature": "matchedTitle",
"left" : {
"value" : "-100"
},
"right": {
"value" : "75"
}
}
}
]
}
}


@ -0,0 +1,15 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_tree",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"weight" : "2f"
}
]
}
}


@ -0,0 +1,10 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_trees",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
}
}


@ -0,0 +1,24 @@
{
"class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name":"multipleadditivetreesmodel_no_weight",
"features":[
{ "name": "matchedTitle"},
{ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"}
],
"params":{
"trees": [
{
"root": {
"feature": "matchedTitle",
"threshold": "0.5f",
"left" : {
"value" : "-100"
},
"right": {
"value" : "75"
}
}
}
]
}
}


@ -0,0 +1,88 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<schema name="example" version="1.5">
<fields>
<field name="id" type="string" indexed="true" stored="true" required="true" multiValued="false" />
<field name="title" type="text_general" indexed="true" stored="true"/>
<field name="description" type="text_general" indexed="true" stored="true"/>
<field name="keywords" type="text_general" indexed="true" stored="true" multiValued="true"/>
<field name="popularity" type="int" indexed="true" stored="true" />
<field name="normHits" type="float" indexed="true" stored="true" />
<field name="text" type="text_general" indexed="true" stored="false" multiValued="true"/>
<field name="_version_" type="long" indexed="true" stored="true"/>
<dynamicField name="*_s" type="string" indexed="true" stored="true" />
<dynamicField name="*_t" type="text_general" indexed="true" stored="true"/>
</fields>
<uniqueKey>id</uniqueKey>
<copyField source="title" dest="text"/>
<copyField source="description" dest="text"/>
<types>
<fieldType name="string" class="solr.StrField" sortMissingLast="true" />
<fieldType name="boolean" class="solr.BoolField" sortMissingLast="true"/>
<fieldType name="int" class="solr.TrieIntField" precisionStep="0" positionIncrementGap="0"/>
<fieldType name="float" class="solr.TrieFloatField" precisionStep="0" positionIncrementGap="0"/>
<fieldType name="long" class="solr.TrieLongField" precisionStep="0" positionIncrementGap="0"/>
<fieldType name="double" class="solr.TrieDoubleField" precisionStep="0" positionIncrementGap="0"/>
<fieldType name="date" class="solr.TrieDateField" precisionStep="0" positionIncrementGap="0"/>
<fieldtype name="binary" class="solr.BinaryField"/>
<fieldType name="text_ws" class="solr.TextField" positionIncrementGap="100">
<analyzer>
<tokenizer class="solr.WhitespaceTokenizerFactory"/>
</analyzer>
</fieldType>
<fieldType name="text_general" class="solr.TextField" positionIncrementGap="100">
<analyzer type="index">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt" />
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
<analyzer type="query">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt" />
<filter class="solr.SynonymFilterFactory" synonyms="synonyms.txt" ignoreCase="true" expand="true"/>
<filter class="solr.LowerCaseFilterFactory"/>
</analyzer>
</fieldType>
<fieldType name="text_lc" class="solr.TextField" positionIncrementGap="100">
<analyzer>
<tokenizer class="solr.KeywordTokenizerFactory"/>
<filter class="solr.LowerCaseFilterFactory" />
</analyzer>
</fieldType>
</types>
<!-- Similarity is the scoring routine for each document vs. a query.
A custom Similarity or SimilarityFactory may be specified here, but
the default is fine for most applications.
For more info: http://wiki.apache.org/solr/SchemaXml#Similarity
-->
<!--
<similarity class="com.example.solr.CustomSimilarityFactory">
<str name="paramkey">param value</str>
</similarity>
-->
</schema>


@ -0,0 +1,65 @@
<?xml version="1.0" ?>
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
license agreements. See the NOTICE file distributed with this work for additional
information regarding copyright ownership. The ASF licenses this file to
You under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License. -->
<config>
<luceneMatchVersion>6.0.0</luceneMatchVersion>
<dataDir>${solr.data.dir:}</dataDir>
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.RAMDirectoryFactory}" />
<schemaFactory class="ClassicIndexSchemaFactory" />
<!-- Query parser used to rerank top docs with a provided model -->
<queryParser name="ltr"
class="org.apache.solr.search.LTRQParserPlugin" />
<query>
<filterCache class="solr.FastLRUCache" size="4096"
initialSize="2048" autowarmCount="0" />
<cache name="QUERY_DOC_FV" class="solr.search.LRUCache" size="4096"
initialSize="2048" autowarmCount="4096" regenerator="solr.search.NoOpRegenerator" />
</query>
<!-- Add a transformer that will encode the document features in the response.
For each document the transformer will add the features as an extra field
in the response. The name of the field will be the name of the transformer
enclosed in brackets (in this case [fv]). In order to get the feature
vector you will have to specify that you want the field (e.g., fl="*,[fv]") -->
<transformer name="fv"
class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory" />
<updateHandler class="solr.DirectUpdateHandler2">
<autoCommit>
<maxTime>15000</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<autoSoftCommit>
<maxTime>1000</maxTime>
</autoSoftCommit>
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>
<requestHandler name="/update" class="solr.UpdateRequestHandler" />
<!-- Query request handler managing models and features -->
<requestHandler name="/query" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">id</str>
</lst>
</requestHandler>
</config>


@ -0,0 +1,69 @@
<?xml version="1.0" ?>
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
license agreements. See the NOTICE file distributed with this work for additional
information regarding copyright ownership. The ASF licenses this file to
You under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License. -->
<config>
<luceneMatchVersion>6.0.0</luceneMatchVersion>
<dataDir>${solr.data.dir:}</dataDir>
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.RAMDirectoryFactory}" />
<schemaFactory class="ClassicIndexSchemaFactory" />
<!-- Query parser used to rerank top docs with a provided model -->
<queryParser name="ltr" class="org.apache.solr.search.LTRQParserPlugin" >
<int name="threadModule.totalPoolThreads">10</int> <!-- Maximum threads to use for all queries -->
<int name="threadModule.numThreadsPerRequest">10</int> <!-- Maximum threads to use for a single query-->
</queryParser>
<query>
<filterCache class="solr.FastLRUCache" size="4096"
initialSize="2048" autowarmCount="0" />
<cache name="QUERY_DOC_FV" class="solr.search.LRUCache" size="4096"
initialSize="2048" autowarmCount="4096" regenerator="solr.search.NoOpRegenerator" />
</query>
<!-- Add a transformer that will encode the document features in the response.
For each document the transformer will add the features as an extra field
in the response. The name of the field will be the name of the transformer
enclosed in brackets (in this case [fv]). In order to get the feature
vector you will have to specify that you want the field (e.g., fl="*,[fv]") -->
<transformer name="fv"
class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory" />
<updateHandler class="solr.DirectUpdateHandler2">
<autoCommit>
<maxTime>15000</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<autoSoftCommit>
<maxTime>1000</maxTime>
</autoSoftCommit>
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>
<requestHandler name="/update" class="solr.UpdateRequestHandler" />
<!-- Query request handler managing models and features -->
<requestHandler name="/query" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">id</str>
</lst>
</requestHandler>
</config>


@ -0,0 +1,62 @@
<?xml version="1.0" ?>
<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
license agreements. See the NOTICE file distributed with this work for additional
information regarding copyright ownership. The ASF licenses this file to
You under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of
the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
by applicable law or agreed to in writing, software distributed under the
License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License. -->
<config>
<luceneMatchVersion>6.0.0</luceneMatchVersion>
<dataDir>${solr.data.dir:}</dataDir>
<directoryFactory name="DirectoryFactory"
class="${solr.directoryFactory:solr.RAMDirectoryFactory}" />
<schemaFactory class="ClassicIndexSchemaFactory" />
<!-- Query parser used to rerank top docs with a provided model -->
<queryParser name="ltr" class="org.apache.solr.search.LTRQParserPlugin" />
<maxBufferedDocs>1</maxBufferedDocs>
<mergePolicyFactory class="org.apache.solr.index.TieredMergePolicyFactory">
<int name="maxMergeAtOnce">10</int>
<int name="segmentsPerTier">1000</int>
</mergePolicyFactory>
<!-- Add a transformer that will encode the document features in the response.
For each document the transformer will add the features as an extra field
in the response. The name of the field will be the name of the transformer
enclosed in brackets (in this case [features]). In order to get the feature
vector you will have to specify that you want the field (e.g., fl="*,[features]") -->
<transformer name="features"
class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory" />
<updateHandler class="solr.DirectUpdateHandler2">
<autoCommit>
<maxTime>15000</maxTime>
<openSearcher>false</openSearcher>
</autoCommit>
<autoSoftCommit>
<maxTime>1000</maxTime>
</autoSoftCommit>
<updateLog>
<str name="dir">${solr.data.dir:}</str>
</updateLog>
</updateHandler>
<requestHandler name="/update" class="solr.UpdateRequestHandler" />
<!-- Query request handler managing models and features -->
<requestHandler name="/query" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">id</str>
</lst>
</requestHandler>
</config>


@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
a


@ -0,0 +1,28 @@
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-----------------------------------------------------------------------
#some test synonym mappings unlikely to appear in real input text
aaafoo => aaabar
bbbfoo => bbbfoo bbbbar
cccfoo => cccbar cccbaz
fooaaa,baraaa,bazaaa
# Some synonym groups specific to this example
GB,gib,gigabyte,gigabytes
MB,mib,megabyte,megabytes
Television, Televisions, TV, TVs
#notice we use "gib" instead of "GiB" so any WordDelimiterFilter coming
#after us won't split it into two words.
# Synonym mappings can be used for spelling correction too
pixima => pixma


@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<solr>
<str name="shareSchema">${shareSchema:false}</str>
<str name="configSetBaseDir">${configSetBaseDir:configsets}</str>
<str name="coreRootDirectory">${coreRootDirectory:.}</str>
<shardHandlerFactory name="shardHandlerFactory" class="HttpShardHandlerFactory">
<str name="urlScheme">${urlScheme:}</str>
<int name="socketTimeout">${socketTimeout:90000}</int>
<int name="connTimeout">${connTimeout:15000}</int>
</shardHandlerFactory>
<solrcloud>
<str name="host">127.0.0.1</str>
<int name="hostPort">${hostPort:8983}</int>
<str name="hostContext">${hostContext:solr}</str>
<int name="zkClientTimeout">${solr.zkclienttimeout:30000}</int>
<bool name="genericCoreNodeNames">${genericCoreNodeNames:true}</bool>
<int name="leaderVoteWait">${leaderVoteWait:10000}</int>
<int name="distribUpdateConnTimeout">${distribUpdateConnTimeout:45000}</int>
<int name="distribUpdateSoTimeout">${distribUpdateSoTimeout:340000}</int>
</solrcloud>
</solr>


@@ -0,0 +1,211 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.File;
import java.util.SortedMap;
import org.apache.commons.io.FileUtils;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.embedded.JettyConfig;
import org.apache.solr.client.solrj.request.CollectionAdminRequest;
import org.apache.solr.client.solrj.response.CollectionAdminResponse;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.cloud.AbstractDistribZkTestBase;
import org.apache.solr.cloud.MiniSolrCloudCluster;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LinearModel;
import org.eclipse.jetty.servlet.ServletHolder;
import org.junit.AfterClass;
import org.junit.Test;
public class TestLTROnSolrCloud extends TestRerankBase {
private MiniSolrCloudCluster solrCluster;
String solrconfig = "solrconfig-ltr.xml";
String schema = "schema.xml";
SortedMap<ServletHolder,String> extraServlets = null;
@Override
public void setUp() throws Exception {
super.setUp();
extraServlets = setupTestInit(solrconfig, schema, true);
System.setProperty("enable.update.log", "true");
int numberOfShards = random().nextInt(4)+1;
int numberOfReplicas = random().nextInt(2)+1;
int maxShardsPerNode = numberOfShards+random().nextInt(4)+1;
int numberOfNodes = numberOfShards * maxShardsPerNode;
setupSolrCluster(numberOfShards, numberOfReplicas, numberOfNodes, maxShardsPerNode);
}
@Override
public void tearDown() throws Exception {
restTestHarness.close();
restTestHarness = null;
jetty.stop();
jetty = null;
solrCluster.shutdown();
super.tearDown();
}
@Test
public void testSimpleQuery() throws Exception {
// setUp randomly picked a configuration with [1..4] shards and [1..2] replicas
// Test regular query, it will sort the documents by inverse
// popularity (the least popular, docid == 1, will be in the first
// position)
SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))");
query.setRequestHandler("/query");
query.setFields("*,score");
query.setParam("rows", "8");
QueryResponse queryResponse =
solrCluster.getSolrClient().query(COLLECTION,query);
assertEquals(8, queryResponse.getResults().getNumFound());
assertEquals("1", queryResponse.getResults().get(0).get("id").toString());
assertEquals("2", queryResponse.getResults().get(1).get("id").toString());
assertEquals("3", queryResponse.getResults().get(2).get("id").toString());
assertEquals("4", queryResponse.getResults().get(3).get("id").toString());
// Test re-rank and feature vectors returned
query.setFields("*,score,features:[fv]");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs=8}");
queryResponse =
solrCluster.getSolrClient().query(COLLECTION,query);
assertEquals(8, queryResponse.getResults().getNumFound());
assertEquals("8", queryResponse.getResults().get(0).get("id").toString());
assertEquals("powpularityS:64.0;c3:2.0",
queryResponse.getResults().get(0).get("features").toString());
assertEquals("7", queryResponse.getResults().get(1).get("id").toString());
assertEquals("powpularityS:49.0;c3:2.0",
queryResponse.getResults().get(1).get("features").toString());
assertEquals("6", queryResponse.getResults().get(2).get("id").toString());
assertEquals("powpularityS:36.0;c3:2.0",
queryResponse.getResults().get(2).get("features").toString());
assertEquals("5", queryResponse.getResults().get(3).get("id").toString());
assertEquals("powpularityS:25.0;c3:2.0",
queryResponse.getResults().get(3).get("features").toString());
}
private void setupSolrCluster(int numShards, int numReplicas, int numServers, int maxShardsPerNode) throws Exception {
JettyConfig jc = buildJettyConfig("/solr");
jc = JettyConfig.builder(jc).withServlets(extraServlets).build();
solrCluster = new MiniSolrCloudCluster(numServers, tmpSolrHome.toPath(), jc);
File configDir = tmpSolrHome.toPath().resolve("collection1/conf").toFile();
solrCluster.uploadConfigSet(configDir.toPath(), "conf1");
solrCluster.getSolrClient().setDefaultCollection(COLLECTION);
createCollection(COLLECTION, "conf1", numShards, numReplicas, maxShardsPerNode);
indexDocuments(COLLECTION);
createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
"/solr", true, extraServlets);
loadModelsAndFeatures();
}
private void createCollection(String name, String config, int numShards, int numReplicas, int maxShardsPerNode)
throws Exception {
CollectionAdminResponse response;
CollectionAdminRequest.Create create =
CollectionAdminRequest.createCollection(name, config, numShards, numReplicas);
create.setMaxShardsPerNode(maxShardsPerNode);
response = create.process(solrCluster.getSolrClient());
if (response.getStatus() != 0 || response.getErrorMessages() != null) {
fail("Could not create collection. Response" + response.toString());
}
ZkStateReader zkStateReader = solrCluster.getSolrClient().getZkStateReader();
AbstractDistribZkTestBase.waitForRecoveriesToFinish(name, zkStateReader, false, true, 100);
}
void indexDocument(String collection, String id, String title, String description, int popularity)
throws Exception{
SolrInputDocument doc = new SolrInputDocument();
doc.setField("id", id);
doc.setField("title", title);
doc.setField("description", description);
doc.setField("popularity", popularity);
solrCluster.getSolrClient().add(collection, doc);
}
private void indexDocuments(final String collection)
throws Exception {
final int collectionSize = 8;
for (int docId = 1; docId <= collectionSize; docId++) {
final int popularity = docId;
indexDocument(collection, String.valueOf(docId), "a1", "bloom", popularity);
}
solrCluster.getSolrClient().commit(collection);
}
private void loadModelsAndFeatures() throws Exception {
final String featureStore = "test";
final String[] featureNames = new String[] {"powpularityS","c3"};
final String jsonModelParams = "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0}}";
loadFeature(
featureNames[0],
SolrFeature.class.getCanonicalName(),
featureStore,
"{\"q\":\"{!func}pow(popularity,2)\"}"
);
loadFeature(
featureNames[1],
ValueFeature.class.getCanonicalName(),
featureStore,
"{\"value\":2}"
);
loadModel(
"powpularityS-model",
LinearModel.class.getCanonicalName(),
featureNames,
featureStore,
jsonModelParams
);
reloadCollection(COLLECTION);
}
private void reloadCollection(String collection) throws Exception {
CollectionAdminRequest.Reload reloadRequest = CollectionAdminRequest.reloadCollection(collection);
CollectionAdminResponse response = reloadRequest.process(solrCluster.getSolrClient());
assertEquals(0, response.getStatus());
assertTrue(response.isSuccess());
}
@AfterClass
public static void after() throws Exception {
FileUtils.deleteDirectory(tmpSolrHome);
System.clearProperty("managed.schema.mutable");
}
}
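For reference, the feature and model definitions that loadFeature and loadModel assemble above correspond, expressed as JSON, to roughly the following (a sketch reconstructed from the string literals in the test — the SolrFeature `powpularityS` with query `{!func}pow(popularity,2)`, the ValueFeature `c3` with value 2, and a LinearModel weighting both at 1.0; the field layout follows the managed model store format and is not verbatim from this commit):

```json
{
  "name": "powpularityS-model",
  "class": "org.apache.solr.ltr.model.LinearModel",
  "store": "test",
  "features": [ { "name": "powpularityS" }, { "name": "c3" } ],
  "params": { "weights": { "powpularityS": 1.0, "c3": 1.0 } }
}
```

With equal unit weights, the model score is simply the sum of the two feature values, which is why the assertions above expect documents ordered by pow(popularity,2).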


@@ -0,0 +1,152 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.model.LinearModel;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestLTRQParserExplain extends TestRerankBase {
@BeforeClass
public static void setup() throws Exception {
setuptest();
loadFeatures("features-store-test-model.json");
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void testRerankedExplain() throws Exception {
loadModel("linear2", LinearModel.class.getCanonicalName(), new String[] {
"constant1", "constant2", "pop"},
"{\"weights\":{\"pop\":1.0,\"constant1\":1.5,\"constant2\":3.5}}");
final SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.setParam("debugQuery", "on");
query.add("rows", "2");
query.add("rq", "{!ltr reRankDocs=2 model=linear2}");
query.add("fl", "*,score");
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/9=='\n13.5 = LinearModel(name=linear2,featureWeights=[constant1=1.5,constant2=3.5,pop=1.0]) model applied to features, sum of:\n 1.5 = prod of:\n 1.5 = weight on feature\n 1.0 = ValueFeature [name=constant1, params={value=1}]\n 7.0 = prod of:\n 3.5 = weight on feature\n 2.0 = ValueFeature [name=constant2, params={value=2}]\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = FieldValueFeature [name=pop, params={field=popularity}]\n'");
}
@Test
public void testRerankedExplainSameBetweenDifferentDocsWithSameFeatures() throws Exception {
loadFeatures("features-linear.json");
loadModels("linear-model.json");
final SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.setParam("debugQuery", "on");
query.add("rows", "4");
query.add("rq", "{!ltr reRankDocs=4 model=6029760550880411648}");
query.add("fl", "*,score");
query.add("wt", "json");
final String expectedExplainNormalizer = "normalized using MinMaxNormalizer(min=0.0,max=10.0)";
final String expectedExplain = "\n3.5116758 = LinearModel(name=6029760550880411648,featureWeights=["
+ "title=0.0,"
+ "description=0.1,"
+ "keywords=0.2,"
+ "popularity=0.3,"
+ "text=0.4,"
+ "queryIntentPerson=0.1231231,"
+ "queryIntentCompany=0.12121211"
+ "]) model applied to features, sum of:\n 0.0 = prod of:\n 0.0 = weight on feature\n 1.0 = ValueFeature [name=title, params={value=1}]\n 0.2 = prod of:\n 0.1 = weight on feature\n 2.0 = ValueFeature [name=description, params={value=2}]\n 0.4 = prod of:\n 0.2 = weight on feature\n 2.0 = ValueFeature [name=keywords, params={value=2}]\n 0.09 = prod of:\n 0.3 = weight on feature\n 0.3 = "+expectedExplainNormalizer+"\n 3.0 = ValueFeature [name=popularity, params={value=3}]\n 1.6 = prod of:\n 0.4 = weight on feature\n 4.0 = ValueFeature [name=text, params={value=4}]\n 0.6156155 = prod of:\n 0.1231231 = weight on feature\n 5.0 = ValueFeature [name=queryIntentPerson, params={value=5}]\n 0.60606056 = prod of:\n 0.12121211 = weight on feature\n 5.0 = ValueFeature [name=queryIntentCompany, params={value=5}]\n";
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/7=='"+expectedExplain+"'}");
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/9=='"+expectedExplain+"'}");
}
@Test
public void LinearScoreExplainMissingEfiFeatureShouldReturnDefaultScore() throws Exception {
loadFeatures("features-linear-efi.json");
loadModels("linear-model-efi.json");
SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.setParam("debugQuery", "on");
query.add("rows", "4");
query.add("rq", "{!ltr reRankDocs=4 model=linear-efi}");
query.add("fl", "*,score");
query.add("wt", "xml");
final String linearModelEfiString = "LinearModel(name=linear-efi,featureWeights=["
+ "sampleConstant=1.0,"
+ "search_number_of_nights=2.0])";
query.remove("wt");
query.add("wt", "json");
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/7=='\n5.0 = "+linearModelEfiString+" model applied to features, sum of:\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = ValueFeature [name=sampleConstant, params={value=5}]\n" +
" 0.0 = prod of:\n" +
" 2.0 = weight on feature\n" +
" 0.0 = The feature has no value\n'}");
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/9=='\n5.0 = "+linearModelEfiString+" model applied to features, sum of:\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = ValueFeature [name=sampleConstant, params={value=5}]\n" +
" 0.0 = prod of:\n" +
" 2.0 = weight on feature\n" +
" 0.0 = The feature has no value\n'}");
}
@Test
public void multipleAdditiveTreesScoreExplainMissingEfiFeatureShouldReturnDefaultScore() throws Exception {
loadFeatures("external_features_for_sparse_processing.json");
loadModels("multipleadditivetreesmodel_external_binary_features.json");
SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.setParam("debugQuery", "on");
query.add("rows", "4");
query.add("rq", "{!ltr reRankDocs=4 model=external_model_binary_feature efi.user_device_tablet=1}");
query.add("fl", "*,score");
final String tree1 = "(weight=1.0,root=(feature=user_device_smartphone,threshold=0.5,left=0.0,right=50.0))";
final String tree2 = "(weight=1.0,root=(feature=user_device_tablet,threshold=0.5,left=0.0,right=65.0))";
final String trees = "["+tree1+","+tree2+"]";
query.add("wt", "json");
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/7=='\n" +
"65.0 = MultipleAdditiveTreesModel(name=external_model_binary_feature,trees="+trees+") model applied to features, sum of:\n" +
" 0.0 = tree 0 | \\'user_device_smartphone\\':0.0 <= 0.500001, Go Left | val: 0.0\n" +
" 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
assertJQ(
"/query" + query.toQueryString(),
"/debug/explain/9=='\n" +
"65.0 = MultipleAdditiveTreesModel(name=external_model_binary_feature,trees="+trees+") model applied to features, sum of:\n" +
" 0.0 = tree 0 | \\'user_device_smartphone\\':0.0 <= 0.500001, Go Left | val: 0.0\n" +
" 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
}
}


@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import org.apache.solr.client.solrj.SolrQuery;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestLTRQParserPlugin extends TestRerankBase {
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
// store = getModelStore();
bulkIndex();
loadFeatures("features-linear.json");
loadModels("linear-model.json");
}
@AfterClass
public static void after() throws Exception {
aftertest();
// store.clear();
}
@Test
public void ltrModelIdMissingTest() throws Exception {
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "*, score");
query.add("rows", "4");
query.add("fv", "true");
query.add("rq", "{!ltr reRankDocs=100}");
final String res = restTestHarness.query("/query" + query.toQueryString());
assertTrue(res.contains("Must provide model in the request"));
}
@Test
public void ltrModelIdDoesNotExistTest() throws Exception {
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "*, score");
query.add("rows", "4");
query.add("fv", "true");
query.add("rq", "{!ltr model=-1 reRankDocs=100}");
final String res = restTestHarness.query("/query" + query.toQueryString());
assertTrue(res.contains("cannot find model"));
}
@Test
public void ltrMoreResultsThanReRankedTest() throws Exception {
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "*, score");
query.add("rows", "4");
query.add("fv", "true");
String nonRerankedScore = "0.09271725";
// Normal solr order
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/id=='9'",
"/response/docs/[1]/id=='8'",
"/response/docs/[2]/id=='7'",
"/response/docs/[3]/id=='6'",
"/response/docs/[3]/score=="+nonRerankedScore
);
query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3}");
// Different order for top 3 reranked, but last one is the same top nonreranked doc
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/id=='7'",
"/response/docs/[1]/id=='8'",
"/response/docs/[2]/id=='9'",
"/response/docs/[3]/id=='6'",
"/response/docs/[3]/score=="+nonRerankedScore
);
}
@Test
public void ltrNoResultsTest() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg23");
query.add("fl", "*,[fv]");
query.add("rows", "3");
query.add("debugQuery", "on");
query.add("rq", "{!ltr reRankDocs=3 model=6029760550880411648}");
assertJQ("/query" + query.toQueryString(), "/response/numFound/==0");
}
}


@@ -0,0 +1,300 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FieldValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TestLTRReRankingPipeline extends LuceneTestCase {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
private IndexSearcher getSearcher(IndexReader r) {
final IndexSearcher searcher = newSearcher(r);
return searcher;
}
private static List<Feature> makeFieldValueFeatures(int[] featureIds,
String field) {
final List<Feature> features = new ArrayList<>();
for (final int i : featureIds) {
final Map<String,Object> params = new HashMap<String,Object>();
params.put("field", field);
final Feature f = Feature.getInstance(solrResourceLoader,
FieldValueFeature.class.getCanonicalName(),
"f" + i, params);
f.setIndex(i);
features.add(f);
}
return features;
}
private class MockModel extends LTRScoringModel {
public MockModel(String name, List<Feature> features,
List<Normalizer> norms,
String featureStoreName, List<Feature> allFeatures,
Map<String,Object> params) {
super(name, features, norms, featureStoreName, allFeatures, params);
}
@Override
public float score(float[] modelFeatureValuesNormalized) {
return modelFeatureValuesNormalized[2];
}
@Override
public Explanation explain(LeafReaderContext context, int doc,
float finalScore, List<Explanation> featureExplanations) {
return null;
}
}
@Ignore
@Test
public void testRescorer() throws IOException {
final Directory dir = newDirectory();
final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(newStringField("id", "0", Field.Store.YES));
doc.add(newTextField("field", "wizard the the the the the oz",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 1.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "1", Field.Store.YES));
// 1 extra token, but wizard and oz are close;
doc.add(newTextField("field", "wizard oz the the the the the the",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 2.0f));
w.addDocument(doc);
final IndexReader r = w.getReader();
w.close();
// Do ordinary BooleanQuery:
final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
final IndexSearcher searcher = getSearcher(r);
// first run the standard query
TopDocs hits = searcher.search(bqBuilder.build(), 10);
assertEquals(2, hits.totalHits);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
"final-score");
final List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures, null);
final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
hits = rescorer.rescore(searcher, hits, 2);
// rerank using the field final-score
assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
r.close();
dir.close();
}
@Ignore
@Test
public void testDifferentTopN() throws IOException {
final Directory dir = newDirectory();
final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(newStringField("id", "0", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 1.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "1", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 2.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "2", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 3.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "3", Field.Store.YES));
doc.add(newTextField("field", "wizard oz oz the the the the ",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 4.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "4", Field.Store.YES));
doc.add(newTextField("field", "wizard oz the the the the the the",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 5.0f));
w.addDocument(doc);
final IndexReader r = w.getReader();
w.close();
// Do ordinary BooleanQuery:
final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
final IndexSearcher searcher = getSearcher(r);
// first run the standard query
TopDocs hits = searcher.search(bqBuilder.build(), 10);
assertEquals(5, hits.totalHits);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
"final-score");
final List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures, null);
final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
// rerank @ 0 should not change the order
hits = rescorer.rescore(searcher, hits, 0);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
// test rerank with different topN cuts
for (int topN = 1; topN <= 5; topN++) {
log.info("rerank {} documents ", topN);
hits = searcher.search(bqBuilder.build(), 10);
final ScoreDoc[] slice = new ScoreDoc[topN];
System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
hits = new TopDocs(hits.totalHits, slice, hits.getMaxScore());
hits = rescorer.rescore(searcher, hits, topN);
for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
.get("id"), j);
assertEquals(i,
Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001);
}
}
r.close();
dir.close();
}
@Test
public void testDocParam() throws Exception {
final Map<String,Object> test = new HashMap<String,Object>();
test.put("fake", 2);
List<Feature> features = makeFieldValueFeatures(new int[] {0},
"final-score");
List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
"final-score");
MockModel ltrScoringModel = new MockModel("test",
features, norms, "test", allFeatures, null);
LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
LTRScoringQuery.ModelWeight wgt = query.createWeight(null, true, 1f);
LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
modelScr.getDocInfo().setOriginalDocScore(Float.valueOf(1f));
for (final Scorer.ChildScorer feat : modelScr.getChildren()) {
assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
}
features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score");
norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
9}, "final-score");
ltrScoringModel = new MockModel("test", features, norms,
"test", allFeatures, null);
query = new LTRScoringQuery(ltrScoringModel);
wgt = query.createWeight(null, true, 1f);
modelScr = wgt.scorer(null);
modelScr.getDocInfo().setOriginalDocScore(Float.valueOf(1f));
for (final Scorer.ChildScorer feat : modelScr.getChildren()) {
assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
}
}
}


@@ -0,0 +1,319 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.norm.NormalizerException;
import org.junit.Test;
public class TestLTRScoringQuery extends LuceneTestCase {
public final static SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
private IndexSearcher getSearcher(IndexReader r) {
final IndexSearcher searcher = newSearcher(r, false, false);
return searcher;
}
private static List<Feature> makeFeatures(int[] featureIds) {
final List<Feature> features = new ArrayList<>();
for (final int i : featureIds) {
Map<String,Object> params = new HashMap<String,Object>();
params.put("value", i);
final Feature f = Feature.getInstance(solrResourceLoader,
ValueFeature.class.getCanonicalName(),
"f" + i, params);
f.setIndex(i);
features.add(f);
}
return features;
}
private static List<Feature> makeFilterFeatures(int[] featureIds) {
final List<Feature> features = new ArrayList<>();
for (final int i : featureIds) {
Map<String,Object> params = new HashMap<String,Object>();
params.put("value", i);
final Feature f = Feature.getInstance(solrResourceLoader,
ValueFeature.class.getCanonicalName(),
"f" + i, params);
f.setIndex(i);
features.add(f);
}
return features;
}
private static Map<String,Object> makeFeatureWeights(List<Feature> features) {
final Map<String,Object> nameParams = new HashMap<String,Object>();
final HashMap<String,Double> modelWeights = new HashMap<String,Double>();
for (final Feature feat : features) {
modelWeights.put(feat.getName(), 0.1);
}
nameParams.put("weights", modelWeights);
return nameParams;
}
private LTRScoringQuery.ModelWeight performQuery(TopDocs hits,
IndexSearcher searcher, int docid, LTRScoringQuery model) throws IOException,
ModelException {
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
.leaves();
final int n = ReaderUtil.subIndex(hits.scoreDocs[0].doc, leafContexts);
final LeafReaderContext context = leafContexts.get(n);
final int deBasedDoc = hits.scoreDocs[0].doc - context.docBase;
final Weight weight = searcher.createNormalizedWeight(model, true);
final Scorer scorer = weight.scorer(context);
// rerank using the field final-score
scorer.iterator().advance(deBasedDoc);
scorer.score();
// assertEquals(42.0f, score, 0.0001);
// assertTrue(weight instanceof AssertingWeight);
// (AssertingIndexSearcher)
assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) weight;
return modelWeight;
}
@Test
public void testLTRScoringQueryEquality() throws ModelException {
final List<Feature> features = makeFeatures(new int[] {0, 1, 2});
final List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
final List<Feature> allFeatures = makeFeatures(
new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
final Map<String,Object> modelParams = makeFeatureWeights(features);
final LTRScoringModel algorithm1 = TestLinearModel.createLinearModel(
"testModelName",
features, norms, "testStoreName", allFeatures, modelParams);
final LTRScoringQuery m0 = new LTRScoringQuery(algorithm1);
final HashMap<String,String[]> externalFeatureInfo = new HashMap<>();
externalFeatureInfo.put("queryIntent", new String[] {"company"});
externalFeatureInfo.put("user_query", new String[] {"abc"});
final LTRScoringQuery m1 = new LTRScoringQuery(algorithm1, externalFeatureInfo, false, null);
final HashMap<String,String[]> externalFeatureInfo2 = new HashMap<>();
externalFeatureInfo2.put("user_query", new String[] {"abc"});
externalFeatureInfo2.put("queryIntent", new String[] {"company"});
int totalPoolThreads = 10, numThreadsPerRequest = 10;
LTRThreadModule threadManager = new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);
final LTRScoringQuery m2 = new LTRScoringQuery(algorithm1, externalFeatureInfo2, false, threadManager);
// Models with same algorithm and efis, just in different order should be the same
assertEquals(m1, m2);
assertEquals(m1.hashCode(), m2.hashCode());
// Models with same algorithm, but different efi content should not match
assertFalse(m1.equals(m0));
assertFalse(m1.hashCode() == m0.hashCode());
final LTRScoringModel algorithm2 = TestLinearModel.createLinearModel(
"testModelName2",
features, norms, "testStoreName", allFeatures, modelParams);
final LTRScoringQuery m3 = new LTRScoringQuery(algorithm2);
assertFalse(m1.equals(m3));
assertFalse(m1.hashCode() == m3.hashCode());
final LTRScoringModel algorithm3 = TestLinearModel.createLinearModel(
"testModelName",
features, norms, "testStoreName3", allFeatures, modelParams);
final LTRScoringQuery m4 = new LTRScoringQuery(algorithm3);
assertFalse(m1.equals(m4));
assertFalse(m1.hashCode() == m4.hashCode());
}
@Test
public void testLTRScoringQuery() throws IOException, ModelException {
final Directory dir = newDirectory();
final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(newStringField("id", "0", Field.Store.YES));
doc.add(newTextField("field", "wizard the the the the the oz",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 1.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "1", Field.Store.YES));
// 1 extra token, but wizard and oz are close;
doc.add(newTextField("field", "wizard oz the the the the the the",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 2.0f));
w.addDocument(doc);
final IndexReader r = w.getReader();
w.close();
// Do ordinary BooleanQuery:
final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
final IndexSearcher searcher = getSearcher(r);
// first run the standard query
final TopDocs hits = searcher.search(bqBuilder.build(), 10);
assertEquals(2, hits.totalHits);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
List<Feature> features = makeFeatures(new int[] {0, 1, 2});
final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5,
6, 7, 8, 9});
List<Normalizer> norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures,
makeFeatureWeights(features));
LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher,
hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);
for (int i = 0; i < 3; i++) {
assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
}
int[] posVals = new int[] {0, 1, 2};
int pos = 0;
for (LTRScoringQuery.FeatureInfo fInfo:modelWeight.getFeaturesInfo()) {
if (fInfo == null){
continue;
}
assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
assertEquals("f"+posVals[pos], fInfo.getName());
pos++;
}
final int[] mixPositions = new int[] {8, 2, 4, 9, 0};
features = makeFeatures(mixPositions);
norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
ltrScoringModel = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures, makeFeatureWeights(features));
modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
new LTRScoringQuery(ltrScoringModel));
assertEquals(mixPositions.length,
modelWeight.getModelFeatureWeights().length);
for (int i = 0; i < mixPositions.length; i++) {
assertEquals(mixPositions[i],
modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
}
final ModelException expectedModelException = new ModelException("no features declared for model test");
final int[] noPositions = new int[] {};
features = makeFeatures(noPositions);
norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
try {
  ltrScoringModel = TestLinearModel.createLinearModel("test",
      features, norms, "test", allFeatures, makeFeatureWeights(features));
  fail("unexpectedly got here instead of catching "+expectedModelException);
} catch (ModelException actualModelException) {
  assertEquals(expectedModelException.toString(), actualModelException.toString());
}
// test normalizers
features = makeFilterFeatures(mixPositions);
final Normalizer norm = new Normalizer() {
@Override
public float normalize(float value) {
return 42.42f;
}
@Override
public LinkedHashMap<String,Object> paramsToMap() {
return null;
}
@Override
protected void validate() throws NormalizerException {
}
};
norms =
new ArrayList<Normalizer>(
Collections.nCopies(features.size(),norm));
final LTRScoringModel normMeta = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures,
makeFeatureWeights(features));
modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
new LTRScoringQuery(normMeta));
normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
assertEquals(mixPositions.length,
modelWeight.getModelFeatureWeights().length);
for (int i = 0; i < mixPositions.length; i++) {
assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
}
r.close();
dir.close();
}
}
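The linear model these tests exercise scores a document as the weighted sum of normalized feature values. A minimal standalone sketch (hypothetical class, not part of this commit) of that arithmetic:

```java
// Standalone sketch (hypothetical, not part of this commit) of the scoring
// that TestLTRScoringQuery asserts: a linear model computes
// score = sum_i weight_i * normalize_i(feature_i).
public class LinearModelSketch {
  public interface Normalizer { float normalize(float value); }
  public static final Normalizer IDENTITY = v -> v;

  public static float score(float[] features, float[] weights, Normalizer norm) {
    float total = 0f;
    for (int i = 0; i < features.length; i++) {
      total += weights[i] * norm.normalize(features[i]);
    }
    return total;
  }

  public static void main(String[] args) {
    float[] features = {0f, 1f, 2f};      // ValueFeature values f0..f2
    float[] weights = {0.1f, 0.1f, 0.1f}; // as produced by makeFeatureWeights
    // IdentityNormalizer leaves feature values unchanged: 0.1*(0+1+2)
    System.out.println(score(features, weights, IDENTITY));
    // a constant normalizer maps every feature to 42.42f, as in the test
    System.out.println(score(features, weights, v -> 42.42f));
  }
}
```

The constant normalizer mirrors the anonymous `Normalizer` in `testLTRScoringQuery`, which is why every normalized feature value there asserts to 42.42f regardless of the raw feature.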


@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.model.LinearModel;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestLTRWithFacet extends TestRerankBase {
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
assertU(adoc("id", "1", "title", "a1", "description", "E", "popularity",
"1"));
assertU(adoc("id", "2", "title", "a1 b1", "description",
"B", "popularity", "2"));
assertU(adoc("id", "3", "title", "a1 b1 c1", "description", "B", "popularity",
"3"));
assertU(adoc("id", "4", "title", "a1 b1 c1 d1", "description", "B", "popularity",
"4"));
assertU(adoc("id", "5", "title", "a1 b1 c1 d1 e1", "description", "E", "popularity",
"5"));
assertU(adoc("id", "6", "title", "a1 b1 c1 d1 e1 f1", "description", "B",
"popularity", "6"));
assertU(adoc("id", "7", "title", "a1 b1 c1 d1 e1 f1 g1", "description",
"C", "popularity", "7"));
assertU(adoc("id", "8", "title", "a1 b1 c1 d1 e1 f1 g1 h1", "description",
"D", "popularity", "8"));
assertU(commit());
}
@Test
public void testRankingSolrFacet() throws Exception {
// before();
loadFeature("powpularityS", SolrFeature.class.getCanonicalName(),
"{\"q\":\"{!func}pow(popularity,2)\"}");
loadModel("powpularityS-model", LinearModel.class.getCanonicalName(),
new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
final SolrQuery query = new SolrQuery();
query.setQuery("title:a1");
query.add("fl", "*, score");
query.add("rows", "4");
query.add("facet", "true");
query.add("facet.field", "description");
assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='4'");
// Normal term match
assertJQ("/query" + query.toQueryString(), ""
+ "/facet_counts/facet_fields/description=="
+ "['b', 4, 'e', 2, 'c', 1, 'd', 1]");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}");
query.set("debugQuery", "on");
assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='4'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==16.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==9.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==4.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0");
assertJQ("/query" + query.toQueryString(), ""
+ "/facet_counts/facet_fields/description=="
+ "['b', 4, 'e', 2, 'c', 1, 'd', 1]");
// aftertest();
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
}


@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.model.LinearModel;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestLTRWithSort extends TestRerankBase {
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
assertU(adoc("id", "1", "title", "a1", "description", "E", "popularity",
"1"));
assertU(adoc("id", "2", "title", "a1 b1", "description",
"B", "popularity", "2"));
assertU(adoc("id", "3", "title", "a1 b1 c1", "description", "B", "popularity",
"3"));
assertU(adoc("id", "4", "title", "a1 b1 c1 d1", "description", "B", "popularity",
"4"));
assertU(adoc("id", "5", "title", "a1 b1 c1 d1 e1", "description", "E", "popularity",
"5"));
assertU(adoc("id", "6", "title", "a1 b1 c1 d1 e1 f1", "description", "B",
"popularity", "6"));
assertU(adoc("id", "7", "title", "a1 b1 c1 d1 e1 f1 g1", "description",
"C", "popularity", "7"));
assertU(adoc("id", "8", "title", "a1 b1 c1 d1 e1 f1 g1 h1", "description",
"D", "popularity", "8"));
assertU(commit());
}
@Test
public void testRankingSolrSort() throws Exception {
// before();
loadFeature("powpularityS", SolrFeature.class.getCanonicalName(),
"{\"q\":\"{!func}pow(popularity,2)\"}");
loadModel("powpularityS-model", LinearModel.class.getCanonicalName(),
new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
final SolrQuery query = new SolrQuery();
query.setQuery("title:a1");
query.add("fl", "*, score");
query.add("rows", "4");
// Normal term match
assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='4'");
//Add sort
query.add("sort", "description desc");
assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='5'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='8'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'");
query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}");
query.set("debugQuery", "on");
assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==64.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==49.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='5'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==25.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0");
// aftertest();
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
}
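The reRankDocs behaviour asserted above can be pictured with a small standalone sketch (hypothetical helper, not part of this commit): only the top reRankDocs of the original ranking are rescored by the model, here popularity squared under a unit weight, and the tail keeps its original order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RerankSketch {
  // Re-rank only the first reRankDocs of the original ranking by model score,
  // as {!ltr reRankDocs=4} does; here the model score is popularity^2.
  public static List<Integer> rerank(List<Integer> ranked,
      Map<Integer,Integer> popularity, int reRankDocs) {
    List<Integer> head =
        new ArrayList<>(ranked.subList(0, Math.min(reRankDocs, ranked.size())));
    head.sort((a, b) -> Double.compare(
        Math.pow(popularity.get(b), 2), Math.pow(popularity.get(a), 2)));
    List<Integer> result = new ArrayList<>(head);
    result.addAll(ranked.subList(head.size(), ranked.size()));
    return result;
  }

  public static void main(String[] args) {
    Map<Integer,Integer> pop = new HashMap<>();
    for (int id = 1; id <= 8; id++) {
      pop.put(id, id); // popularity == id in the test documents
    }
    // sorted-by-description order from the test: 1, 5, 8, 7
    System.out.println(rerank(Arrays.asList(1, 5, 8, 7), pop, 4)); // [8, 7, 5, 1]
  }
}
```

This matches the test's reranked assertions: scores 64, 49, 25, 1 for ids 8, 7, 5, 1.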


@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import org.apache.solr.client.solrj.SolrQuery;
import org.junit.Test;
public class TestParallelWeightCreation extends TestRerankBase{
@Test
public void testLTRScoringQueryParallelWeightCreationResultOrder() throws Exception {
setuptest("solrconfig-ltr_Th10_10.xml", "schema.xml");
assertU(adoc("id", "1", "title", "w1 w3", "description", "w1", "popularity",
"1"));
assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
"2"));
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
"3"));
assertU(adoc("id", "4", "title", "w4 w3", "description", "w4", "popularity",
"4"));
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
"5"));
assertU(commit());
loadFeatures("external_features.json");
loadModels("external_model.json");
loadModels("external_model_store.json");
// check to make sure that the order of results will be the same when using parallel weight creation
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("fl", "*,score");
query.add("rows", "4");
query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='4'");
aftertest();
}
@Test
public void testLTRQParserThreadInitialization() throws Exception {
// a negative numThreadsPerRequest should throw an exception
String msg1 = null;
try {
  new LTRThreadModule(1,-1);
} catch (IllegalArgumentException iae) {
  msg1 = iae.getMessage();
}
assertEquals("numThreadsPerRequest cannot be less than 1", msg1);
// set totalPoolThreads to 1 and numThreadsPerRequest to 2 and verify that an exception is thrown
String msg2 = null;
try {
  new LTRThreadModule(1,2);
} catch (IllegalArgumentException iae) {
  msg2 = iae.getMessage();
}
assertEquals("numThreadsPerRequest cannot be greater than totalPoolThreads", msg2);
}
}
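The constructor checks that testLTRQParserThreadInitialization expects can be sketched as follows (hypothetical standalone class; the real validation lives in LTRThreadModule):

```java
// Hypothetical sketch of the argument validation the test expects from
// LTRThreadModule(totalPoolThreads, numThreadsPerRequest); not the actual
// implementation from this commit. Returns the error message, or null if
// the arguments are consistent.
public class ThreadModuleChecks {
  public static String validate(int totalPoolThreads, int numThreadsPerRequest) {
    if (numThreadsPerRequest < 1) {
      return "numThreadsPerRequest cannot be less than 1";
    }
    if (numThreadsPerRequest > totalPoolThreads) {
      return "numThreadsPerRequest cannot be greater than totalPoolThreads";
    }
    return null; // arguments are consistent
  }

  public static void main(String[] args) {
    System.out.println(validate(1, -1));  // rejected: negative per-request threads
    System.out.println(validate(1, 2));   // rejected: exceeds the pool size
    System.out.println(validate(10, 10)); // accepted
  }
}
```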


@ -0,0 +1,429 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.net.URL;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.util.ContentStream;
import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.LinearModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.ManagedModelStore;
import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.rest.ManagedResourceStorage;
import org.apache.solr.rest.SolrSchemaRestApi;
import org.apache.solr.util.RestTestBase;
import org.eclipse.jetty.servlet.ServletHolder;
import org.noggit.ObjectBuilder;
import org.restlet.ext.servlet.ServerServlet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class TestRerankBase extends RestTestBase {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
protected static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
protected static File tmpSolrHome;
protected static File tmpConfDir;
public static final String FEATURE_FILE_NAME = "_schema_feature-store.json";
public static final String MODEL_FILE_NAME = "_schema_model-store.json";
public static final String PARENT_ENDPOINT = "/schema/*";
protected static final String COLLECTION = "collection1";
protected static final String CONF_DIR = COLLECTION + "/conf";
protected static File fstorefile = null;
protected static File mstorefile = null;
public static void setuptest() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
bulkIndex();
}
public static void setupPersistenttest() throws Exception {
setupPersistentTest("solrconfig-ltr.xml", "schema.xml");
bulkIndex();
}
public static ManagedFeatureStore getManagedFeatureStore() {
return ManagedFeatureStore.getManagedFeatureStore(h.getCore());
}
public static ManagedModelStore getManagedModelStore() {
return ManagedModelStore.getManagedModelStore(h.getCore());
}
protected static SortedMap<ServletHolder,String> setupTestInit(
String solrconfig, String schema,
boolean isPersistent) throws Exception {
tmpSolrHome = createTempDir().toFile();
tmpConfDir = new File(tmpSolrHome, CONF_DIR);
tmpConfDir.deleteOnExit();
FileUtils.copyDirectory(new File(TEST_HOME()),
tmpSolrHome.getAbsoluteFile());
final File fstore = new File(tmpConfDir, FEATURE_FILE_NAME);
final File mstore = new File(tmpConfDir, MODEL_FILE_NAME);
if (isPersistent) {
fstorefile = fstore;
mstorefile = mstore;
}
if (fstore.exists()) {
log.info("remove feature store config file in {}",
fstore.getAbsolutePath());
Files.delete(fstore.toPath());
}
if (mstore.exists()) {
log.info("remove model store config file in {}",
mstore.getAbsolutePath());
Files.delete(mstore.toPath());
}
if (!solrconfig.equals("solrconfig.xml")) {
FileUtils.copyFile(new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/" + solrconfig),
new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/solrconfig.xml"));
}
if (!schema.equals("schema.xml")) {
FileUtils.copyFile(new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/" + schema),
new File(tmpSolrHome.getAbsolutePath()
+ "/collection1/conf/schema.xml"));
}
final SortedMap<ServletHolder,String> extraServlets = new TreeMap<>();
final ServletHolder solrRestApi = new ServletHolder("SolrSchemaRestApi",
ServerServlet.class);
solrRestApi.setInitParameter("org.restlet.application",
SolrSchemaRestApi.class.getCanonicalName());
solrRestApi.setInitParameter("storageIO",
ManagedResourceStorage.InMemoryStorageIO.class.getCanonicalName());
extraServlets.put(solrRestApi, PARENT_ENDPOINT);
System.setProperty("managed.schema.mutable", "true");
return extraServlets;
}
public static void setuptest(String solrconfig, String schema)
throws Exception {
initCore(solrconfig, schema);
SortedMap<ServletHolder,String> extraServlets =
setupTestInit(solrconfig,schema,false);
System.setProperty("enable.update.log", "false");
createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
"/solr", true, extraServlets);
}
public static void setupPersistentTest(String solrconfig, String schema)
throws Exception {
initCore(solrconfig, schema);
SortedMap<ServletHolder,String> extraServlets =
setupTestInit(solrconfig,schema,true);
createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
"/solr", true, extraServlets);
}
protected static void aftertest() throws Exception {
restTestHarness.close();
restTestHarness = null;
jetty.stop();
jetty = null;
FileUtils.deleteDirectory(tmpSolrHome);
System.clearProperty("managed.schema.mutable");
// System.clearProperty("enable.update.log");
}
public static void makeRestTestHarnessNull() {
restTestHarness = null;
}
/** produces a model encoded in json **/
public static String getModelInJson(String name, String type,
String[] features, String fstore, String params) {
final StringBuilder sb = new StringBuilder();
sb.append("{\n");
sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
sb.append("\"store\":").append('"').append(fstore).append('"')
.append(",\n");
sb.append("\"class\":").append('"').append(type).append('"').append(",\n");
sb.append("\"features\":").append('[');
for (final String feature : features) {
sb.append("\n\t{ ");
sb.append("\"name\":").append('"').append(feature).append('"')
.append("},");
}
sb.deleteCharAt(sb.length() - 1);
sb.append("\n]\n");
if (params != null) {
sb.append(",\n");
sb.append("\"params\":").append(params);
}
sb.append("\n}\n");
return sb.toString();
}
/** produces a feature encoded in json **/
public static String getFeatureInJson(String name, String type,
String fstore, String params) {
final StringBuilder sb = new StringBuilder();
sb.append("{\n");
sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
sb.append("\"store\":").append('"').append(fstore).append('"')
.append(",\n");
sb.append("\"class\":").append('"').append(type).append('"');
if (params != null) {
sb.append(",\n");
sb.append("\"params\":").append(params);
}
sb.append("\n}\n");
return sb.toString();
}
protected static void loadFeature(String name, String type, String params)
throws Exception {
final String feature = getFeatureInJson(name, type, "test", params);
log.info("loading feature \n{} ", feature);
assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
"/responseHeader/status==0");
}
protected static void loadFeature(String name, String type, String fstore,
String params) throws Exception {
final String feature = getFeatureInJson(name, type, fstore, params);
log.info("loading feature \n{} ", feature);
assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
"/responseHeader/status==0");
}
protected static void loadModel(String name, String type, String[] features,
String params) throws Exception {
loadModel(name, type, features, "test", params);
}
protected static void loadModel(String name, String type, String[] features,
String fstore, String params) throws Exception {
final String model = getModelInJson(name, type, features, fstore, params);
log.info("loading model \n{} ", model);
assertJPut(ManagedModelStore.REST_END_POINT, model,
"/responseHeader/status==0");
}
public static void loadModels(String fileName) throws Exception {
final URL url = TestRerankBase.class.getResource("/modelExamples/"
+ fileName);
final String multipleModels = FileUtils.readFileToString(
new File(url.toURI()), "UTF-8");
assertJPut(ManagedModelStore.REST_END_POINT, multipleModels,
"/responseHeader/status==0");
}
public static LTRScoringModel createModelFromFiles(String modelFileName,
String featureFileName) throws ModelException, Exception {
URL url = TestRerankBase.class.getResource("/modelExamples/"
+ modelFileName);
final String modelJson = FileUtils.readFileToString(new File(url.toURI()),
"UTF-8");
final ManagedModelStore ms = getManagedModelStore();
url = TestRerankBase.class.getResource("/featureExamples/"
+ featureFileName);
final String featureJson = FileUtils.readFileToString(
new File(url.toURI()), "UTF-8");
Object parsedFeatureJson = null;
try {
parsedFeatureJson = ObjectBuilder.fromJSON(featureJson);
} catch (final IOException ioExc) {
throw new ModelException("ObjectBuilder failed parsing json", ioExc);
}
final ManagedFeatureStore fs = getManagedFeatureStore();
// clear out any existing features; needing to do this suggests that
// getManagedFeatureStore() does not return a fresh feature store each time
fs.doDeleteChild(null, "*");
fs.applyUpdatesToManagedData(parsedFeatureJson);
ms.setManagedFeatureStore(fs);
final LTRScoringModel ltrScoringModel = ManagedModelStore.fromLTRScoringModelMap(
solrResourceLoader, mapFromJson(modelJson), ms.getManagedFeatureStore());
ms.addModel(ltrScoringModel);
return ltrScoringModel;
}
@SuppressWarnings("unchecked")
static private Map<String,Object> mapFromJson(String json) throws ModelException {
Object parsedJson = null;
try {
parsedJson = ObjectBuilder.fromJSON(json);
} catch (final IOException ioExc) {
throw new ModelException("ObjectBuilder failed parsing json", ioExc);
}
return (Map<String,Object>) parsedJson;
}
public static void loadFeatures(String fileName) throws Exception {
final URL url = TestRerankBase.class.getResource("/featureExamples/"
+ fileName);
final String multipleFeatures = FileUtils.readFileToString(
new File(url.toURI()), "UTF-8");
log.info("send \n{}", multipleFeatures);
assertJPut(ManagedFeatureStore.REST_END_POINT, multipleFeatures,
"/responseHeader/status==0");
}
protected List<Feature> getFeatures(List<String> names)
throws FeatureException {
final List<Feature> features = new ArrayList<>();
int pos = 0;
for (final String name : names) {
final Map<String,Object> params = new HashMap<String,Object>();
params.put("value", 10);
final Feature f = Feature.getInstance(solrResourceLoader,
ValueFeature.class.getCanonicalName(),
name, params);
f.setIndex(pos);
features.add(f);
++pos;
}
return features;
}
protected List<Feature> getFeatures(String[] names) throws FeatureException {
return getFeatures(Arrays.asList(names));
}
protected static void loadModelAndFeatures(String name, int allFeatureCount,
int modelFeatureCount) throws Exception {
final String[] features = new String[modelFeatureCount];
final String[] weights = new String[modelFeatureCount];
for (int i = 0; i < allFeatureCount; i++) {
final String featureName = "c" + i;
if (i < modelFeatureCount) {
features[i] = featureName;
weights[i] = "\"" + featureName + "\":1.0";
}
loadFeature(featureName, ValueFeature.ValueFeatureWeight.class.getCanonicalName(),
"{\"value\":" + i + "}");
}
loadModel(name, LinearModel.class.getCanonicalName(), features,
"{\"weights\":{" + StringUtils.join(weights, ",") + "}}");
}
protected static void bulkIndex() throws Exception {
assertU(adoc("title", "bloomberg different bla", "description",
"bloomberg", "id", "6", "popularity", "1"));
assertU(adoc("title", "bloomberg bloomberg ", "description", "bloomberg",
"id", "7", "popularity", "2"));
assertU(adoc("title", "bloomberg bloomberg bloomberg", "description",
"bloomberg", "id", "8", "popularity", "3"));
assertU(adoc("title", "bloomberg bloomberg bloomberg bloomberg",
"description", "bloomberg", "id", "9", "popularity", "5"));
assertU(commit());
}
protected static void bulkIndex(String filePath) throws Exception {
final SolrQueryRequestBase req = lrf.makeRequest(
CommonParams.STREAM_CONTENTTYPE, "application/xml");
final List<ContentStream> streams = new ArrayList<ContentStream>();
final File file = new File(filePath);
streams.add(new ContentStreamBase.FileStream(file));
req.setContentStreams(streams);
try {
final SolrQueryResponse res = new SolrQueryResponse();
h.updater.handleRequest(req, res);
} catch (final Throwable ex) {
// Ignore: just log the exception and continue
log.error(ex.getMessage(), ex);
}
assertU(commit());
}
protected static void buildIndexUsingAdoc(String filepath)
throws FileNotFoundException {
final Scanner scn = new Scanner(new File(filepath), "UTF-8");
StringBuffer buff = new StringBuffer();
scn.nextLine();
scn.nextLine();
scn.nextLine(); // Skip the first 3 lines then add everything else
final ArrayList<String> docsToAdd = new ArrayList<String>();
while (scn.hasNext()) {
String curLine = scn.nextLine();
if (curLine.contains("</doc>")) {
buff.append(curLine + "\n");
docsToAdd.add(buff.toString().replace("</add>", "")
.replace("<doc>", "<add>\n<doc>")
.replace("</doc>", "</doc>\n</add>"));
if (!scn.hasNext()) {
break;
} else {
curLine = scn.nextLine();
}
buff = new StringBuffer();
}
buff.append(curLine + "\n");
}
for (final String doc : docsToAdd) {
assertU(doc.trim());
}
assertU(commit());
scn.close();
}
}
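For reference, the payloads that `loadFeature` PUTs to `ManagedFeatureStore.REST_END_POINT` follow the LTR plugin's managed-resource JSON conventions. The `getFeatureInJson` helper is not shown in this excerpt, so the exact output is an assumption, but for `loadFeature("c0", ValueFeature.class.getCanonicalName(), "{\"value\":0}")` the payload is approximately:

```json
{
  "name": "c0",
  "class": "org.apache.solr.ltr.feature.ValueFeature",
  "store": "test",
  "params": { "value": 0 }
}
```

The model payload built by `getModelInJson` is analogous, additionally carrying a `"features"` list of `{ "name": ... }` entries and a `"params"` object such as `{"weights": {"c0": 1.0}}`, as assembled in `loadModelAndFeatures` above.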

@@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FloatDocValuesField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.ValueFeature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestSelectiveWeightCreation extends TestRerankBase {
private IndexSearcher getSearcher(IndexReader r) {
final IndexSearcher searcher = newSearcher(r, false, false);
return searcher;
}
private static List<Feature> makeFeatures(int[] featureIds) {
final List<Feature> features = new ArrayList<>();
for (final int i : featureIds) {
Map<String,Object> params = new HashMap<String,Object>();
params.put("value", i);
final Feature f = Feature.getInstance(solrResourceLoader,
ValueFeature.class.getCanonicalName(),
"f" + i, params);
f.setIndex(i);
features.add(f);
}
return features;
}
private static Map<String,Object> makeFeatureWeights(List<Feature> features) {
final Map<String,Object> nameParams = new HashMap<String,Object>();
final HashMap<String,Double> modelWeights = new HashMap<String,Double>();
for (final Feature feat : features) {
modelWeights.put(feat.getName(), 0.1);
}
nameParams.put("weights", modelWeights);
return nameParams;
}
private LTRScoringQuery.ModelWeight performQuery(TopDocs hits,
IndexSearcher searcher, int docid, LTRScoringQuery model) throws IOException,
ModelException {
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
.leaves();
final int n = ReaderUtil.subIndex(hits.scoreDocs[0].doc, leafContexts);
final LeafReaderContext context = leafContexts.get(n);
final int deBasedDoc = hits.scoreDocs[0].doc - context.docBase;
final Weight weight = searcher.createNormalizedWeight(model, true);
final Scorer scorer = weight.scorer(context);
// rerank using the field final-score
scorer.iterator().advance(deBasedDoc);
scorer.score();
assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) weight;
return modelWeight;
}
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
assertU(adoc("id", "1", "title", "w1 w3", "description", "w1", "popularity",
"1"));
assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
"2"));
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
"3"));
assertU(adoc("id", "4", "title", "w4 w3", "description", "w4", "popularity",
"4"));
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
"5"));
assertU(commit());
loadFeatures("external_features.json");
loadModels("external_model.json");
loadModels("external_model_store.json");
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void testScoringQueryWeightCreation() throws IOException, ModelException {
final Directory dir = newDirectory();
final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
doc.add(newStringField("id", "0", Field.Store.YES));
doc.add(newTextField("field", "wizard the the the the the oz",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 1.0f));
w.addDocument(doc);
doc = new Document();
doc.add(newStringField("id", "1", Field.Store.YES));
// 1 extra token, but wizard and oz are close;
doc.add(newTextField("field", "wizard oz the the the the the the",
Field.Store.NO));
doc.add(new FloatDocValuesField("final-score", 2.0f));
w.addDocument(doc);
final IndexReader r = w.getReader();
w.close();
// Do ordinary BooleanQuery:
final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
final IndexSearcher searcher = getSearcher(r);
// first run the standard query
final TopDocs hits = searcher.search(bqBuilder.build(), 10);
assertEquals(2, hits.totalHits);
assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
List<Feature> features = makeFeatures(new int[] {0, 1, 2});
final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5,
6, 7, 8, 9});
final List<Normalizer> norms = new ArrayList<>();
for (int k=0; k < features.size(); ++k){
norms.add(IdentityNormalizer.INSTANCE);
}
// when features are NOT requested in the response, only the modelFeature weights should be created
final LTRScoringModel ltrScoringModel1 = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures,
makeFeatureWeights(features));
LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher,
hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel1, false)); // features not requested in response
LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo();
assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
int validFeatures = 0;
for (int i=0; i < featuresInfo.length; ++i){
if (featuresInfo[i] != null && featuresInfo[i].isUsed()){
validFeatures += 1;
}
}
assertEquals(validFeatures, features.size());
// when features are requested in the response, weights should be created for all features
final LTRScoringModel ltrScoringModel2 = TestLinearModel.createLinearModel("test",
features, norms, "test", allFeatures,
makeFeatureWeights(features));
modelWeight = performQuery(hits, searcher,
hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel2, true)); // features requested in response
featuresInfo = modelWeight.getFeaturesInfo();
assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length);
validFeatures = 0;
for (int i=0; i < featuresInfo.length; ++i){
if (featuresInfo[i] != null && featuresInfo[i].isUsed()){
validFeatures += 1;
}
}
assertEquals(validFeatures, allFeatures.size());
assertU(delI("0"));
assertU(delI("1"));
r.close();
dir.close();
}
@Test
public void testSelectiveWeightsRequestFeaturesFromDifferentStore() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("fl", "*,score");
query.add("rows", "4");
query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
query.add("fl", "fv:[fv]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='4'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='matchedTitle:1.0;titlePhraseMatch:0.40254828'"); // extract all features in default store
query.remove("fl");
query.remove("rq");
query.add("fl", "*,score");
query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
query.add("fl", "fv:[fv store=fstore4 efi.myPop=3]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.999");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='popularity:3.0;originalScore:1.0'"); // extract all features from fstore4
query.remove("fl");
query.remove("rq");
query.add("fl", "*,score");
query.add("rq", "{!ltr reRankDocs=4 model=externalmodelstore efi.user_query=w3 efi.myconf=0.8}");
query.add("fl", "fv:[fv store=fstore4 efi.myPop=3]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); // score using fstore2 used by externalmodelstore
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.7992");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='popularity:3.0;originalScore:1.0'"); // extract all features from fstore4
}
}
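The store-selection behaviour exercised in the last test boils down to three request parameters. As a sketch, with values taken from the assertions above:

```text
q=*:*&rows=4
rq={!ltr reRankDocs=4 model=externalmodelstore efi.user_query=w3 efi.myconf=0.8}
fl=*,score,fv:[fv store=fstore4 efi.myPop=3]
```

The model reranks using the features of its own store (hence the score of 0.7992 driven by `efi.myconf`), while the `[fv]` transformer independently extracts all features from `fstore4` using `efi.myPop`.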

@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.model.LinearModel;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestEdisMaxSolrFeature extends TestRerankBase {
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity",
"1"));
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
"w2 2asd asdd didid", "popularity", "2"));
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
"3"));
assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity",
"4"));
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
"5"));
assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2",
"popularity", "6"));
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
"w1 w2 w3 w4 w5 w8", "popularity", "7"));
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
"w1 w1 w1 w2 w2", "popularity", "8"));
assertU(commit());
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void testEdisMaxSolrFeature() throws Exception {
loadFeature(
"SomeEdisMax",
SolrFeature.class.getCanonicalName(),
"{\"q\":\"{!edismax qf='title description' pf='description' mm=100% boost='pow(popularity, 0.1)' v='w1' tie=0.1}\"}");
loadModel("EdisMax-model", LinearModel.class.getCanonicalName(),
new String[] {"SomeEdisMax"}, "{\"weights\":{\"SomeEdisMax\":1.0}}");
final SolrQuery query = new SolrQuery();
query.setQuery("title:w1");
query.add("fl", "*, score");
query.add("rows", "4");
query.add("rq", "{!ltr model=EdisMax-model reRankDocs=4}");
query.set("debugQuery", "on");
restTestHarness.query("/query" + query.toQueryString());
assertJQ("/query" + query.toQueryString(), "/response/numFound/==4");
}
}
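The escaped `params` string passed to `loadFeature` above is easier to read as the JSON document it produces. Assuming the standard SolrFeature definition shape (the default feature store is implied when none is given), it amounts to:

```json
{
  "name": "SomeEdisMax",
  "class": "org.apache.solr.ltr.feature.SolrFeature",
  "params": {
    "q": "{!edismax qf='title description' pf='description' mm=100% boost='pow(popularity, 0.1)' v='w1' tie=0.1}"
  }
}
```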

@@ -0,0 +1,157 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestExternalFeatures extends TestRerankBase {
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity",
"1"));
assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
"2"));
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
"3"));
assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity",
"4"));
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
"5"));
assertU(commit());
loadFeatures("external_features.json");
loadModels("external_model.json");
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void testEfiInTransformerShouldNotChangeOrderOfRerankedResults() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("fl", "*,score");
query.add("rows", "3");
// Regular scores
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==1.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==1.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==1.0");
query.add("fl", "[fv]");
query.add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query=w3}");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.999");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==0.0");
// Adding an efi in the transformer should not affect the rq ranking with a
// different value for efi of the same parameter
query.remove("fl");
query.add("fl", "id,[fv efi.user_query=w2]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='3'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='1'");
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'");
}
@Test
public void testFeaturesUseStopwordQueryReturnEmptyFeatureVector() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("fl", "*,score,fv:[fv]");
query.add("rows", "1");
// Stopword only query passed in
query.add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query='a'}");
// Features are query title matches, which remove stopwords, leaving blank query, so no matches
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''");
}
@Test
public void testEfiFeatureExtraction() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("rows", "1");
// Features we're extracting depend on external feature info not passed in
query.add("fl", "[fv]");
assertJQ("/query" + query.toQueryString(), "/error/msg=='Exception from createWeight for SolrFeature [name=matchedTitle, params={q={!terms f=title}${user_query}}] SolrFeatureWeight requires efi parameter that was not passed in request.'");
// Adding efi in features section should make it work
query.remove("fl");
query.add("fl", "score,fvalias:[fv store=fstore2 efi.myconf=2.3]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='confidence:2.3;originalScore:1.0'");
// Adding efi in transformer + rq should still use the transformer's params for feature extraction
query.remove("fl");
query.add("fl", "score,fvalias:[fv store=fstore2 efi.myconf=2.3]");
query.add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query=w3}");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='confidence:2.3;originalScore:1.0'");
}
@Test
public void featureExtraction_valueFeatureImplicitlyNotRequired_shouldNotScoreFeature() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("rows", "1");
// Efi is implicitly not required, so we do not score the feature
query.remove("fl");
query.add("fl", "fvalias:[fv store=fstore2]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='originalScore:0.0'");
}
@Test
public void featureExtraction_valueFeatureExplicitlyNotRequired_shouldNotScoreFeature() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("rows", "1");
// Efi is explicitly not required, so we do not score the feature
query.remove("fl");
query.add("fl", "fvalias:[fv store=fstore3]");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='originalScore:0.0'");
}
@Test
public void featureExtraction_valueFeatureRequired_shouldThrowException() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("rows", "1");
// Using nondefault store should still result in error with no efi when it is required (myPop)
query.remove("fl");
query.add("fl", "fvalias:[fv store=fstore4]");
assertJQ("/query" + query.toQueryString(), "/error/msg=='Exception from createWeight for ValueFeature [name=popularity, params={value=${myPop}, required=true}] ValueFeatureWeight requires efi parameter that was not passed in request.'");
}
}
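The `matchedTitle` feature referenced in the error assertion above can be reconstructed from that message (`params={q={!terms f=title}${user_query}}`); the surrounding JSON shape is an assumption based on the plugin's conventions. External feature information (`efi.*`) supplied on the request is substituted into the `${user_query}` placeholder at query time:

```json
{
  "name": "matchedTitle",
  "class": "org.apache.solr.ltr.feature.SolrFeature",
  "params": { "q": "{!terms f=title}${user_query}" }
}
```

When no matching `efi.user_query` is passed, feature weight creation fails with the exception asserted in `testEfiFeatureExtraction`.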

@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestExternalValueFeatures extends TestRerankBase {
@BeforeClass
public static void before() throws Exception {
setuptest("solrconfig-ltr.xml", "schema.xml");
assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity",
"1"));
assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
"2"));
assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
"3"));
assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity",
"4"));
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
"5"));
assertU(commit());
loadFeatures("external_features_for_sparse_processing.json");
loadModels("multipleadditivetreesmodel_external_binary_features.json");
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void efiFeatureProcessing_oneEfiMissing_shouldNotCalculateMissingFeature() throws Exception {
SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("fl", "*,score,features:[fv]");
query.add("rows", "3");
query.add("fl", "[fv]");
query.add("rq", "{!ltr reRankDocs=3 model=external_model_binary_feature efi.user_device_tablet=1}");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/features=='user_device_tablet:1.0'");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/score==65.0");
}
@Test
public void efiFeatureProcessing_allEfisMissing_shouldReturnZeroScore() throws Exception {
SolrQuery query = new SolrQuery();
query.setQuery("*:*");
query.add("fl", "*,score,features:[fv]");
query.add("rows", "3");
query.add("fl", "[fv]");
query.add("rq", "{!ltr reRankDocs=3 model=external_model_binary_feature}");
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/features==''");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/score==0.0");
}
}

@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.util.List;
import java.util.Map;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.noggit.ObjectBuilder;
public class TestFeatureExtractionFromMultipleSegments extends TestRerankBase {
static final String AB = "abcdefghijklmnopqrstuvwxyz";
static String randomString(int len) {
final StringBuilder sb = new StringBuilder(len);
for (int i = 0; i < len; i++) {
sb.append(AB.charAt(random().nextInt(AB.length())));
}
return sb.toString();
}
@BeforeClass
public static void before() throws Exception {
// solrconfig-multiseg.xml contains the merge policy to restrict merging
setuptest("solrconfig-multiseg.xml", "schema.xml");
// index documents with ids spread over [0, 400), 17 per block of 20 ids
for (int i = 0; i < 400; i += 20) {
assertU(adoc("id", Integer.toString(i), "popularity", "201", "description", "apple is a company " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+1), "popularity", "201", "description", "d " + randomString(i%6+3), "normHits", "0.11"));
assertU(adoc("id", Integer.toString(i+2), "popularity", "201", "description", "apple is a company too " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+3), "popularity", "201", "description", "new york city is big apple " + randomString(i%6+3), "normHits", "0.11"));
assertU(adoc("id", Integer.toString(i+6), "popularity", "301", "description", "function name " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+7), "popularity", "301", "description", "function " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+8), "popularity", "301", "description", "This is a sample function for testing " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+9), "popularity", "301", "description", "Function to check out stock prices " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+10), "popularity", "301", "description", "Some descriptions " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+11), "popularity", "201", "description", "apple apple is a company " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+12), "popularity", "201", "description", "Big Apple is New York.", "normHits", "0.01"));
assertU(adoc("id", Integer.toString(i+13), "popularity", "201", "description", "New some York is Big. " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+14), "popularity", "201", "description", "apple apple is a company " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+15), "popularity", "201", "description", "Big Apple is New York.", "normHits", "0.01"));
assertU(adoc("id", Integer.toString(i+16), "popularity", "401", "description", "barack h", "normHits", "0.0"));
assertU(adoc("id", Integer.toString(i+17), "popularity", "201", "description", "red delicious apple " + randomString(i%6+3), "normHits", "0.1"));
assertU(adoc("id", Integer.toString(i+18), "popularity", "201", "description", "nyc " + randomString(i%6+3), "normHits", "0.11"));
}
assertU(commit());
loadFeatures("comp_features.json");
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void testFeatureExtractionFromMultipleSegments() throws Exception {
final SolrQuery query = new SolrQuery();
query.setQuery("{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}");
// Request 100 rows; rows fetched from the second and subsequent segments verify that LTRRescorer.extractFeaturesInfo() advances the doc iterator correctly
int numRows = 100;
query.add("rows", Integer.toString(numRows));
query.add("wt", "json");
query.add("fq", "popularity:201");
query.add("fl", "*, score,id,normHits,description,fv:[features store='feature-store-6' format='dense' efi.user_text='apple']");
String res = restTestHarness.query("/query" + query.toQueryString());
Map<String,Object> resultJson = (Map<String,Object>) ObjectBuilder.fromJSON(res);
List<Map<String,Object>> docs = (List<Map<String,Object>>)((Map<String,Object>)resultJson.get("response")).get("docs");
int passCount = 0;
for (final Map<String,Object> doc : docs) {
String features = (String)doc.get("fv");
assert(features.length() > 0);
++passCount;
}
assert(passCount == numRows);
}
}

@@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.model.LinearModel;
import org.apache.solr.ltr.store.FeatureStore;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestFeatureLogging extends TestRerankBase {
@BeforeClass
public static void setup() throws Exception {
setuptest();
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void testGeneratedFeatures() throws Exception {
loadFeature("c1", ValueFeature.class.getCanonicalName(), "test1",
"{\"value\":1.0}");
loadFeature("c2", ValueFeature.class.getCanonicalName(), "test1",
"{\"value\":2.0}");
loadFeature("c3", ValueFeature.class.getCanonicalName(), "test1",
"{\"value\":3.0}");
loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "test1",
"{\"field\":\"popularity\"}");
loadFeature("nomatch", SolrFeature.class.getCanonicalName(), "test1",
"{\"q\":\"{!terms f=title}foobarbat\"}");
loadFeature("yesmatch", SolrFeature.class.getCanonicalName(), "test1",
"{\"q\":\"{!terms f=popularity}2\"}");
loadModel("sum1", LinearModel.class.getCanonicalName(), new String[] {
"c1", "c2", "c3"}, "test1",
"{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}");
final SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.add("fl", "title,description,id,popularity,[fv]");
query.add("rows", "3");
query.add("debugQuery", "on");
query.add("rq", "{!ltr reRankDocs=3 model=sum1}");
restTestHarness.query("/query" + query.toQueryString());
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/=={'title':'bloomberg bloomberg ', 'description':'bloomberg','id':'7', 'popularity':2, '[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}");
query.remove("fl");
query.add("fl", "[fv]");
query.add("rows", "3");
query.add("rq", "{!ltr reRankDocs=3 model=sum1}");
restTestHarness.query("/query" + query.toQueryString());
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/=={'[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}");
query.remove("rq");
// no reranking-with-logging flag is set, but the feature vector is still requested in fl; extraction should work anyway
query.add("rq", "{!ltr reRankDocs=3 model=sum1}");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/=={'[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}");
}
@Test
public void testDefaultStoreFeatureExtraction() throws Exception {
loadFeature("defaultf1", ValueFeature.class.getCanonicalName(),
FeatureStore.DEFAULT_FEATURE_STORE_NAME,
"{\"value\":1.0}");
loadFeature("store8f1", ValueFeature.class.getCanonicalName(),
"store8",
"{\"value\":2.0}");
loadFeature("store9f1", ValueFeature.class.getCanonicalName(),
"store9",
"{\"value\":3.0}");
loadModel("store9m1", LinearModel.class.getCanonicalName(),
new String[] {"store9f1"},
"store9",
"{\"weights\":{\"store9f1\":1.0}}");
final SolrQuery query = new SolrQuery();
query.setQuery("id:7");
query.add("rows", "1");
// No store specified, use default store for extraction
query.add("fl", "fv:[fv]");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/=={'fv':'defaultf1:1.0'}");
// Store specified, use store for extraction
query.remove("fl");
query.add("fl", "fv:[fv store=store8]");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/=={'fv':'store8f1:2.0'}");
// Store specified + model specified, use store for extraction
query.add("rq", "{!ltr reRankDocs=3 model=store9m1}");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/=={'fv':'store8f1:2.0'}");
// No store specified + model specified, use model store for extraction
query.remove("fl");
query.add("fl", "fv:[fv]");
assertJQ("/query" + query.toQueryString(),
"/response/docs/[0]/=={'fv':'store9f1:3.0'}");
}
@Test
public void testGeneratedGroup() throws Exception {
loadFeature("c1", ValueFeature.class.getCanonicalName(), "testgroup",
"{\"value\":1.0}");
loadFeature("c2", ValueFeature.class.getCanonicalName(), "testgroup",
"{\"value\":2.0}");
loadFeature("c3", ValueFeature.class.getCanonicalName(), "testgroup",
"{\"value\":3.0}");
loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "testgroup",
"{\"field\":\"popularity\"}");
loadModel("sumgroup", LinearModel.class.getCanonicalName(), new String[] {
"c1", "c2", "c3"}, "testgroup",
"{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}");
final SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.add("fl", "*,[fv]");
query.add("debugQuery", "on");
query.remove("fl");
query.add("fl", "fv:[fv]");
query.add("rows", "3");
query.add("group", "true");
query.add("group.field", "title");
query.add("rq", "{!ltr reRankDocs=3 model=sumgroup}");
restTestHarness.query("/query" + query.toQueryString());
assertJQ(
"/query" + query.toQueryString(),
"/grouped/title/groups/[0]/doclist/docs/[0]/=={'fv':'c1:1.0;c2:2.0;c3:3.0;pop:5.0'}");
query.remove("fl");
query.add("fl", "fv:[fv fvwt=json]");
restTestHarness.query("/query" + query.toQueryString());
assertJQ(
"/query" + query.toQueryString(),
"/grouped/title/groups/[0]/doclist/docs/[0]/fv/=={'c1':1.0,'c2':2.0,'c3':3.0,'pop':5.0}");
query.remove("fl");
query.add("fl", "fv:[fv fvwt=json]");
assertJQ(
"/query" + query.toQueryString(),
"/grouped/title/groups/[0]/doclist/docs/[0]/fv/=={'c1':1.0,'c2':2.0,'c3':3.0,'pop':5.0}");
}
@Test
public void testSparseDenseFeatures() throws Exception {
loadFeature("match", SolrFeature.class.getCanonicalName(), "test4",
"{\"q\":\"{!terms f=title}different\"}");
loadFeature("c4", ValueFeature.class.getCanonicalName(), "test4",
"{\"value\":1.0}");
loadModel("sum4", LinearModel.class.getCanonicalName(), new String[] {
"match"}, "test4",
"{\"weights\":{\"match\":1.0}}");
//json - no feature format specified (defaults to sparse)
final SolrQuery query = new SolrQuery();
query.setQuery("title:bloomberg");
query.add("rows", "10");
query.add("fl", "*,score,fv:[fv store=test4 fvwt=json]");
query.add("rq", "{!ltr reRankDocs=10 model=sum4}");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/fv/=={'match':1.0,'c4':1.0}");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[1]/fv/=={'c4':1.0}");
//json - sparse feature format check
query.remove("fl");
query.add("fl", "*,score,fv:[fv store=test4 format=sparse fvwt=json]");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/fv/=={'match':1.0,'c4':1.0}");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[1]/fv/=={'c4':1.0}");
//json - dense feature format check
query.remove("fl");
query.add("fl", "*,score,fv:[fv store=test4 format=dense fvwt=json]");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/fv/=={'match':1.0,'c4':1.0}");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[1]/fv/=={'match':0.0,'c4':1.0}");
//csv - no feature format specified (defaults to sparse)
query.remove("fl");
query.add("fl", "*,score,fv:[fv store=test4 fvwt=csv]");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/fv/=='match:1.0;c4:1.0'");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[1]/fv/=='c4:1.0'");
//csv - sparse feature format check
query.remove("fl");
query.add("fl", "*,score,fv:[fv store=test4 format=sparse fvwt=csv]");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/fv/=='match:1.0;c4:1.0'");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[1]/fv/=='c4:1.0'");
//csv - dense feature format check
query.remove("fl");
query.add("fl", "*,score,fv:[fv store=test4 format=dense fvwt=csv]");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[0]/fv/=='match:1.0;c4:1.0'");
assertJQ(
"/query" + query.toQueryString(),
"/response/docs/[1]/fv/=='match:0.0;c4:1.0'");
}
}


@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.TestManagedFeatureStore;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestFeatureLtrScoringModel extends TestRerankBase {
static ManagedFeatureStore store = null;
@BeforeClass
public static void setup() throws Exception {
setuptest();
store = getManagedFeatureStore();
}
@AfterClass
public static void after() throws Exception {
aftertest();
}
@Test
public void getInstanceTest() throws FeatureException
{
store.addFeature(TestManagedFeatureStore.createMap("test",
OriginalScoreFeature.class.getCanonicalName(), null),
"testFstore");
final Feature feature = store.getFeatureStore("testFstore").get("test");
assertNotNull(feature);
assertEquals("test", feature.getName());
assertEquals(OriginalScoreFeature.class.getCanonicalName(), feature
.getClass().getCanonicalName());
}
@Test
public void getInvalidInstanceTest()
{
final String nonExistingClassName = "org.apache.solr.ltr.feature.LOLFeature";
final ClassNotFoundException expectedException =
new ClassNotFoundException(nonExistingClassName);
try {
store.addFeature(TestManagedFeatureStore.createMap("test",
nonExistingClassName, null),
"testFstore2");
fail("getInvalidInstanceTest failed to throw exception: "+expectedException);
} catch (Exception actualException) {
Throwable rootError = getRootCause(actualException);
assertEquals(expectedException.toString(), rootError.toString());
}
}
}


@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.feature;
import java.util.HashMap;
import java.util.Map;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.TestManagedFeatureStore;
import org.junit.BeforeClass;
import org.junit.Test;
public class TestFeatureStore extends TestRerankBase {
static ManagedFeatureStore fstore = null;
@BeforeClass
public static void setup() throws Exception {
setuptest();
fstore = getManagedFeatureStore();
}
@Test
public void testDefaultFeatureStoreName()
{
assertEquals("_DEFAULT_", FeatureStore.DEFAULT_FEATURE_STORE_NAME);
final FeatureStore expectedFeatureStore = fstore.getFeatureStore(FeatureStore.DEFAULT_FEATURE_STORE_NAME);
final FeatureStore actualFeatureStore = fstore.getFeatureStore(null);
assertEquals("getFeatureStore(null) should return the default feature store", expectedFeatureStore, actualFeatureStore);
}
@Test
public void testFeatureStoreAdd() throws FeatureException
{
final FeatureStore fs = fstore.getFeatureStore("fstore-testFeature");
for (int i = 0; i < 5; i++) {
final String name = "c" + i;
fstore.addFeature(TestManagedFeatureStore.createMap(name,
OriginalScoreFeature.class.getCanonicalName(), null),
"fstore-testFeature");
final Feature f = fs.get(name);
assertNotNull(f);
}
assertEquals(5, fs.getFeatures().size());
}
@Test
public void testFeatureStoreGet() throws FeatureException
{
final FeatureStore fs = fstore.getFeatureStore("fstore-testFeature2");
for (int i = 0; i < 5; i++) {
Map<String,Object> params = new HashMap<>();
params.put("value", i);
final String name = "c" + i;
fstore.addFeature(TestManagedFeatureStore.createMap(name,
ValueFeature.class.getCanonicalName(), params),
"fstore-testFeature2");
}
for (int i = 0; i < 5; i++) {
final Feature f = fs.get("c" + i);
assertEquals("c" + i, f.getName());
assertTrue(f instanceof ValueFeature);
final ValueFeature vf = (ValueFeature)f;
assertEquals(i, vf.getValue());
}
}
@Test
public void testMissingFeatureReturnsNull() {
final FeatureStore fs = fstore.getFeatureStore("fstore-testFeature3");
for (int i = 0; i < 5; i++) {
Map<String,Object> params = new HashMap<>();
params.put("value", i);
final String name = "testc" + (float) i;
fstore.addFeature(TestManagedFeatureStore.createMap(name,
ValueFeature.class.getCanonicalName(), params),
"fstore-testFeature3");
}
assertNull(fs.get("missing_feature_name"));
}
}

Some files were not shown because too many files have changed in this diff.