From 5a66b3bc089e4b3e73b1c41c4cdcd89b183b85e7 Mon Sep 17 00:00:00 2001 From: Christine Poerschke Date: Tue, 1 Nov 2016 17:50:14 +0000 Subject: [PATCH] SOLR-8542: Adds Solr Learning to Rank (LTR) plugin for reranking results with machine learning models. (Michael Nilsson, Diego Ceccarelli, Joshua Pantony, Jon Dorando, Naveen Santhapuri, Alessandro Benedetti, David Grohmann, Christine Poerschke) --- dev-tools/idea/.idea/modules.xml | 1 + dev-tools/idea/solr/contrib/ltr/ltr.iml | 37 + solr/CHANGES.txt | 3 + solr/contrib/ltr/README.md | 406 ++++ solr/contrib/ltr/README.txt | 1 + solr/contrib/ltr/build.xml | 30 + solr/contrib/ltr/example/config.json | 14 + solr/contrib/ltr/example/libsvm_formatter.py | 124 ++ solr/contrib/ltr/example/solrconfig.xml | 1722 +++++++++++++++++ .../ltr/example/techproducts-features.json | 26 + .../ltr/example/techproducts-model.json | 18 + .../example/train_and_upload_demo_model.py | 163 ++ solr/contrib/ltr/example/user_queries.txt | 8 + solr/contrib/ltr/ivy.xml | 32 + .../src/java/org/apache/solr/ltr/DocInfo.java | 42 + .../org/apache/solr/ltr/FeatureLogger.java | 193 ++ .../java/org/apache/solr/ltr/LTRRescorer.java | 249 +++ .../org/apache/solr/ltr/LTRScoringQuery.java | 738 +++++++ .../org/apache/solr/ltr/LTRThreadModule.java | 163 ++ .../ltr/SolrQueryRequestContextUtils.java | 83 + .../org/apache/solr/ltr/feature/Feature.java | 335 ++++ .../solr/ltr/feature/FeatureException.java | 31 + .../solr/ltr/feature/FieldLengthFeature.java | 152 ++ .../solr/ltr/feature/FieldValueFeature.java | 141 ++ .../ltr/feature/OriginalScoreFeature.java | 118 ++ .../apache/solr/ltr/feature/SolrFeature.java | 320 +++ .../apache/solr/ltr/feature/ValueFeature.java | 148 ++ .../apache/solr/ltr/feature/package-info.java | 21 + .../solr/ltr/model/LTRScoringModel.java | 298 +++ .../apache/solr/ltr/model/LinearModel.java | 147 ++ .../apache/solr/ltr/model/ModelException.java | 31 + .../ltr/model/MultipleAdditiveTreesModel.java | 377 ++++ 
.../apache/solr/ltr/model/package-info.java | 21 + .../solr/ltr/norm/IdentityNormalizer.java | 53 + .../solr/ltr/norm/MinMaxNormalizer.java | 107 + .../org/apache/solr/ltr/norm/Normalizer.java | 64 + .../solr/ltr/norm/NormalizerException.java | 31 + .../solr/ltr/norm/StandardNormalizer.java | 99 + .../apache/solr/ltr/norm/package-info.java | 23 + .../org/apache/solr/ltr/package-info.java | 45 + .../apache/solr/ltr/store/FeatureStore.java | 67 + .../org/apache/solr/ltr/store/ModelStore.java | 74 + .../apache/solr/ltr/store/package-info.java | 21 + .../ltr/store/rest/ManagedFeatureStore.java | 215 ++ .../ltr/store/rest/ManagedModelStore.java | 319 +++ .../solr/ltr/store/rest/package-info.java | 22 + .../LTRFeatureLoggerTransformerFactory.java | 254 +++ .../solr/response/transform/package-info.java | 23 + .../apache/solr/search/LTRQParserPlugin.java | 233 +++ .../org/apache/solr/search/package-info.java | 23 + solr/contrib/ltr/src/java/overview.html | 91 + .../featureExamples/comp_features.json | 37 + .../featureExamples/external_features.json | 51 + ...ternal_features_for_sparse_processing.json | 18 + .../featureExamples/features-linear-efi.json | 17 + .../featureExamples/features-linear.json | 51 + .../features-store-test-model.json | 51 + .../featureExamples/fq_features.json | 16 + .../multipleadditivetreesmodel_features.json | 16 + .../ltr/src/test-files/log4j.properties | 32 + .../modelExamples/external_model.json | 12 + .../modelExamples/external_model_store.json | 13 + .../test-files/modelExamples/fq-model.json | 20 + .../modelExamples/linear-model-efi.json | 14 + .../modelExamples/linear-model.json | 30 + .../multipleadditivetreesmodel.json | 38 + ...vetreesmodel_external_binary_features.json | 38 + ...multipleadditivetreesmodel_no_feature.json | 24 + ...ultipleadditivetreesmodel_no_features.json | 14 + .../multipleadditivetreesmodel_no_left.json | 22 + .../multipleadditivetreesmodel_no_params.json | 8 + .../multipleadditivetreesmodel_no_right.json | 22 + 
...ltipleadditivetreesmodel_no_threshold.json | 24 + .../multipleadditivetreesmodel_no_tree.json | 15 + .../multipleadditivetreesmodel_no_trees.json | 10 + .../multipleadditivetreesmodel_no_weight.json | 24 + .../solr/collection1/conf/schema.xml | 88 + .../solr/collection1/conf/solrconfig-ltr.xml | 65 + .../conf/solrconfig-ltr_Th10_10.xml | 69 + .../collection1/conf/solrconfig-multiseg.xml | 62 + .../solr/collection1/conf/stopwords.txt | 16 + .../solr/collection1/conf/synonyms.txt | 28 + solr/contrib/ltr/src/test-files/solr/solr.xml | 42 + .../apache/solr/ltr/TestLTROnSolrCloud.java | 211 ++ .../solr/ltr/TestLTRQParserExplain.java | 152 ++ .../apache/solr/ltr/TestLTRQParserPlugin.java | 114 ++ .../solr/ltr/TestLTRReRankingPipeline.java | 300 +++ .../apache/solr/ltr/TestLTRScoringQuery.java | 319 +++ .../org/apache/solr/ltr/TestLTRWithFacet.java | 103 + .../org/apache/solr/ltr/TestLTRWithSort.java | 102 + .../solr/ltr/TestParallelWeightCreation.java | 77 + .../org/apache/solr/ltr/TestRerankBase.java | 429 ++++ .../solr/ltr/TestSelectiveWeightCreation.java | 251 +++ .../ltr/feature/TestEdisMaxSolrFeature.java | 76 + .../ltr/feature/TestExternalFeatures.java | 157 ++ .../feature/TestExternalValueFeatures.java | 86 + ...FeatureExtractionFromMultipleSegments.java | 105 + .../solr/ltr/feature/TestFeatureLogging.java | 254 +++ .../feature/TestFeatureLtrScoringModel.java | 71 + .../solr/ltr/feature/TestFeatureStore.java | 106 + .../ltr/feature/TestFieldLengthFeature.java | 156 ++ .../ltr/feature/TestFieldValueFeature.java | 173 ++ .../ltr/feature/TestFilterSolrFeature.java | 105 + .../ltr/feature/TestNoMatchSolrFeature.java | 192 ++ .../ltr/feature/TestOriginalScoreFeature.java | 148 ++ .../solr/ltr/feature/TestRankingFeature.java | 123 ++ .../ltr/feature/TestUserTermScoreWithQ.java | 74 + .../ltr/feature/TestUserTermScorerQuery.java | 74 + .../ltr/feature/TestUserTermScorereQDF.java | 75 + .../solr/ltr/feature/TestValueFeature.java | 165 ++ 
.../solr/ltr/model/TestLinearModel.java | 207 ++ .../model/TestMultipleAdditiveTreesModel.java | 246 +++ .../solr/ltr/norm/TestMinMaxNormalizer.java | 120 ++ .../solr/ltr/norm/TestStandardNormalizer.java | 132 ++ .../store/rest/TestManagedFeatureStore.java | 36 + .../solr/ltr/store/rest/TestModelManager.java | 163 ++ .../rest/TestModelManagerPersistence.java | 121 ++ 117 files changed, 14167 insertions(+) create mode 100644 dev-tools/idea/solr/contrib/ltr/ltr.iml create mode 100644 solr/contrib/ltr/README.md create mode 120000 solr/contrib/ltr/README.txt create mode 100644 solr/contrib/ltr/build.xml create mode 100644 solr/contrib/ltr/example/config.json create mode 100644 solr/contrib/ltr/example/libsvm_formatter.py create mode 100644 solr/contrib/ltr/example/solrconfig.xml create mode 100644 solr/contrib/ltr/example/techproducts-features.json create mode 100644 solr/contrib/ltr/example/techproducts-model.json create mode 100755 solr/contrib/ltr/example/train_and_upload_demo_model.py create mode 100644 solr/contrib/ltr/example/user_queries.txt create mode 100644 solr/contrib/ltr/ivy.xml create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/DocInfo.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldLengthFeature.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java 
create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/SolrFeature.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ValueFeature.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LTRScoringModel.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LinearModel.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/ModelException.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/model/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/IdentityNormalizer.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/MinMaxNormalizer.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/Normalizer.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/NormalizerException.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/StandardNormalizer.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/store/FeatureStore.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/store/ModelStore.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/store/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedFeatureStore.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java create mode 100644 
solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/response/transform/LTRFeatureLoggerTransformerFactory.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/response/transform/package-info.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/search/LTRQParserPlugin.java create mode 100644 solr/contrib/ltr/src/java/org/apache/solr/search/package-info.java create mode 100644 solr/contrib/ltr/src/java/overview.html create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/comp_features.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/external_features.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/external_features_for_sparse_processing.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/features-linear-efi.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/features-linear.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/fq_features.json create mode 100644 solr/contrib/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features.json create mode 100644 solr/contrib/ltr/src/test-files/log4j.properties create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/external_model.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/external_model_store.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/fq-model.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/linear-model-efi.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/linear-model.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_external_binary_features.json 
create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_feature.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_features.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_left.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_params.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_right.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_threshold.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_tree.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_trees.json create mode 100644 solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_weight.json create mode 100644 solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml create mode 100644 solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml create mode 100644 solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml create mode 100644 solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml create mode 100644 solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt create mode 100644 solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt create mode 100644 solr/contrib/ltr/src/test-files/solr/solr.xml create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java create mode 100644 
solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestEdisMaxSolrFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalFeatures.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalValueFeatures.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLtrScoringModel.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldLengthFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFilterSolrFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestNoMatchSolrFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestRankingFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScoreWithQ.java create mode 100644 
solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorerQuery.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorereQDF.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestValueFeature.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestMinMaxNormalizer.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestStandardNormalizer.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestManagedFeatureStore.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManager.java create mode 100644 solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java diff --git a/dev-tools/idea/.idea/modules.xml b/dev-tools/idea/.idea/modules.xml index 6fbe496772f..5d2d106a449 100644 --- a/dev-tools/idea/.idea/modules.xml +++ b/dev-tools/idea/.idea/modules.xml @@ -60,6 +60,7 @@ + diff --git a/dev-tools/idea/solr/contrib/ltr/ltr.iml b/dev-tools/idea/solr/contrib/ltr/ltr.iml new file mode 100644 index 00000000000..efc505d8d6f --- /dev/null +++ b/dev-tools/idea/solr/contrib/ltr/ltr.iml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index fd4d2af1526..16cae8cab59 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -93,6 +93,9 @@ New Features SOLR_HOME on every node. Editing config through API is supported but affects only that one node. (janhoy) +* SOLR-8542: Adds Solr Learning to Rank (LTR) plugin for reranking results with machine learning models. 
+ (Michael Nilsson, Diego Ceccarelli, Joshua Pantony, Jon Dorando, Naveen Santhapuri, Alessandro Benedetti, David Grohmann, Christine Poerschke) + Optimizations ---------------------- * SOLR-9704: Facet Module / JSON Facet API: Optimize blockChildren facets that have diff --git a/solr/contrib/ltr/README.md b/solr/contrib/ltr/README.md new file mode 100644 index 00000000000..5fe0087ba86 --- /dev/null +++ b/solr/contrib/ltr/README.md @@ -0,0 +1,406 @@ +Apache Solr Learning to Rank +======== + +This is the main [learning to rank integrated into solr](http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp) +repository. +[Read up on learning to rank](https://en.wikipedia.org/wiki/Learning_to_rank) + +Apache Solr Learning to Rank (LTR) provides a way for you to extract features +directly inside Solr for use in training a machine learned model. You can then +deploy that model to Solr and use it to rerank your top X search results. + +# Test the plugin with solr/example/techproducts in a few easy steps! + +Solr provides some simple example of indices. In order to test the plugin with +the techproducts example please follow these steps. + +1. Compile solr and the examples + + `cd solr` + `ant dist` + `ant server` + +2. Run the example to setup the index + + `./bin/solr -e techproducts` + +3. Stop solr and install the plugin: + 1. Stop solr + + `./bin/solr stop` + 2. Create the lib folder + + `mkdir example/techproducts/solr/techproducts/lib` + 3. Install the plugin in the lib folder + + `cp build/contrib/ltr/solr-ltr-7.0.0-SNAPSHOT.jar example/techproducts/solr/techproducts/lib/` + 4. Replace the original solrconfig with one importing all the ltr components + + `cp contrib/ltr/example/solrconfig.xml example/techproducts/solr/techproducts/conf/` + +4. Run the example again + + `./bin/solr -e techproducts` + + Note you could also have just restarted your collection using the admin page. 
+ You can find more detailed instructions [here](https://wiki.apache.org/solr/SolrPlugins). + +5. Deploy features and a model + + `curl -XPUT 'http://localhost:8983/solr/techproducts/schema/feature-store' --data-binary "@./contrib/ltr/example/techproducts-features.json" -H 'Content-type:application/json'` + + `curl -XPUT 'http://localhost:8983/solr/techproducts/schema/model-store' --data-binary "@./contrib/ltr/example/techproducts-model.json" -H 'Content-type:application/json'` + +6. Have fun ! + + * Access to the default feature store + + http://localhost:8983/solr/techproducts/schema/feature-store/\_DEFAULT\_ + * Access to the model store + + http://localhost:8983/solr/techproducts/schema/model-store + * Perform a reranking query using the model, and retrieve the features + + http://localhost:8983/solr/techproducts/query?indent=on&q=test&wt=json&rq={!ltr%20model=linear%20reRankDocs=25%20efi.user_query=%27test%27}&fl=[features],price,score,name + + +BONUS: Train an actual machine learning model + +1. Download and install [liblinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) + +2. Change `contrib/ltr/example/config.json` "trainingLibraryLocation" to point to the train directory where you installed liblinear. + +3. Extract features, train a reranking model, and deploy it to Solr. + + `cd contrib/ltr/example` + + `python train_and_upload_demo_model.py -c config.json` + + This script deploys your features from `config.json` "featuresFile" to Solr. Then it takes the relevance judged query + document pairs of "userQueriesFile" and merges it with the features extracted from Solr into a training + file. That file is used to train a linear model, which is then deployed to Solr for you to rerank results. + +4. 
Search and rerank the results using the trained model + + http://localhost:8983/solr/techproducts/query?indent=on&q=test&wt=json&rq={!ltr%20model=ExampleModel%20reRankDocs=25%20efi.user_query=%27test%27}&fl=price,score,name + +# Changes to solrconfig.xml +```xml + + ... + + + + + + + + + + ... + + + + + + + +``` + +# Defining Features +In the learning to rank plugin, you can define features in a feature space +using standard Solr queries. As an example: + +###### features.json +```json +[ +{ "name": "isBook", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params":{ "fq": ["{!terms f=category}book"] } +}, +{ + "name": "documentRecency", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": { + "q": "{!func}recip( ms(NOW,publish_date), 3.16e-11, 1, 1)" + } +}, +{ + "name":"originalScore", + "class":"org.apache.solr.ltr.feature.OriginalScoreFeature", + "params":{} +}, +{ + "name" : "userTextTitleMatch", + "class" : "org.apache.solr.ltr.feature.SolrFeature", + "params" : { "q" : "{!field f=title}${user_text}" } +}, + { + "name" : "userFromMobile", + "class" : "org.apache.solr.ltr.feature.ValueFeature", + "params" : { "value" : "${userFromMobile}", "required":true } + } +] +``` + +Defines five features. Anything that is a valid Solr query can be used to define +a feature. + +### Filter Query Features +The first feature isBook fires if the term 'book' matches the category field +for the given examined document. Since in this feature q was not specified, +either the score 1 (in case of a match) or the score 0 (in case of no match) +will be returned. + +### Query Features +In the second feature (documentRecency) q was specified using a function query. +In this case the score for the feature on a given document is whatever the query +returns (1 for docs dated now, 1/2 for docs dated 1 year ago, 1/3 for docs dated +2 years ago, etc..) . 
If both an fq and q are used, documents that don't match +the fq will receive a score of 0 for the documentRecency feature; all other +documents will receive the score specified by the query for this feature. + +### Original Score Feature +The third feature (originalScore) has no parameters, and uses the +OriginalScoreFeature class instead of the SolrFeature class. Its purpose is +to simply return the score for the original search request against the current +matching document. + +### External Features +Users can specify external information that can be passed in as +part of the query to the ltr ranking framework. In this case, the +fourth feature (userTextTitleMatch) will be looking for an external field +called 'user_text' passed in through the request, and will fire if there is +a term match for the document field 'title' from the value of the external +field 'user_text'. You can provide default values for external features as +well by specifying ${myField:myDefault}, similar to how you would in a Solr config. +In this case, the fifth feature (userFromMobile) will be looking for an external parameter +called 'userFromMobile' passed in through the request. If the ValueFeature is: +required=true, it will throw an exception if the external feature is not passed; +required=false, it will silently ignore the feature and avoid the scoring (at document scoring time, the model will consider 0 as the feature value). +The advantage in defining a feature as not required, where possible, is to avoid wasting caching space and time in calculating the featureScore. +See the [Run a Rerank Query](#run-a-rerank-query) section for how to pass in external information. + +### Custom Features +Custom features can be created by extending from +org.apache.solr.ltr.feature.Feature, however this is generally not recommended. +The majority of features should be possible to create using the methods described +above.
+ +# Defining Models +Currently the Learning to Rank plugin supports 2 generalized forms of +models: 1. Linear Model e.g. [RankSVM](http://www.cs.cornell.edu/people/tj/publications/joachims_02c.pdf), [Pranking](https://papers.nips.cc/paper/2023-pranking-with-ranking.pdf) +and 2. Multiple Additive Trees e.g. [LambdaMART](http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf), [Gradient Boosted Regression Trees (GBRT)](https://papers.nips.cc/paper/3305-a-general-boosting-method-and-its-application-to-learning-ranking-functions-for-web-search.pdf) + +### Linear +If you'd like to introduce a bias, set a constant feature +to the bias value you'd like and give that feature a weight of 1.0. + +###### model.json +```json +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"myModelName", + "features":[ + { "name": "userTextTitleMatch"}, + { "name": "originalScore"}, + { "name": "isBook"} + ], + "params":{ + "weights": { + "userTextTitleMatch": 1.0, + "originalScore": 0.5, + "isBook": 0.1 + } + + } +} +``` + +This is an example of a toy Linear model. Class specifies the class to be +used to interpret the model. Name is the model identifier you will use +when making requests to the ltr framework. Features specifies the feature +space that you want extracted when using this model. All features that +appear in the model params will be used for scoring and must appear in +the features list. You can add extra features to the features list that +will be computed but not used in the model for scoring, which can be useful +for logging. Params are the Linear parameters. + +Good libraries for training SVM, an example of a Linear model, are +(https://www.csie.ntu.edu.tw/~cjlin/liblinear/ , https://www.csie.ntu.edu.tw/~cjlin/libsvm/) . +You will need to convert the libSVM model format to the format specified above.
+ +### Multiple Additive Trees + +###### model2.json +```json +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel", + "features":[ + { "name": "userTextTitleMatch"}, + { "name": "originalScore"} + ], + "params":{ + "trees": [ + { + "weight" : 1, + "root": { + "feature": "userTextTitleMatch", + "threshold": 0.5, + "left" : { + "value" : -100 + }, + "right": { + "feature" : "originalScore", + "threshold": 10.0, + "left" : { + "value" : 50 + }, + "right" : { + "value" : 75 + } + } + } + }, + { + "weight" : 2, + "root": { + "value" : -10 + } + } + ] + } +} +``` +This is an example of a toy Multiple Additive Trees model. Class specifies the class to be used to +interpret the model. Name is the +model identifier you will use when making requests to the ltr framework. +Features specifies the feature space that you want extracted when using this +model. All features that appear in the model params will be used for scoring and +must appear in the features list. You can add extra features to the features +list that will be computed but not used in the model for scoring, which can +be useful for logging. Params are the Multiple Additive Trees specific parameters. In this +case we have 2 trees, one with 3 leaf nodes and one with 1 leaf node. + +A good library for training LambdaMART, an example of Multiple Additive Trees, is ( http://sourceforge.net/p/lemur/wiki/RankLib/ ). +You will need to convert the RankLib model format to the format specified above.
+ +# Deploy Models and Features +To send features run + +`curl -XPUT 'http://localhost:8983/solr/collection1/schema/feature-store' --data-binary @/path/features.json -H 'Content-type:application/json'` + +To send models run + +`curl -XPUT 'http://localhost:8983/solr/collection1/schema/model-store' --data-binary @/path/model.json -H 'Content-type:application/json'` + + +# View Models and Features +`curl -XGET 'http://localhost:8983/solr/collection1/schema/feature-store'` + +`curl -XGET 'http://localhost:8983/solr/collection1/schema/model-store'` + +# Run a Rerank Query +Add to your original solr query +`rq={!ltr model=myModelName reRankDocs=25}` + +The model name is the name of the model you sent to solr earlier. +The number of documents you want reranked, which can be larger than the +number you display, is reRankDocs. + +### Pass in external information for external features +Add to your original solr query +`rq={!ltr reRankDocs=3 model=externalmodel efi.field1='text1' efi.field2='text2'}` + +Where "field1" specifies the name of the customized field to be used by one +or more of your features, and text1 is the information to be passed in. As an +example that matches the earlier shown userTextTitleMatch feature one could do: + +`rq={!ltr reRankDocs=3 model=externalmodel efi.user_text='Casablanca' efi.user_intent='movie'}` + +# Extract features +To extract features you need to use the feature vector transformer `features` + +`fl=*,score,[features]&rq={!ltr model=yourModel reRankDocs=25}` + +If you use `[features]` together with your reranking model, it will return +the array of features used by your model. Otherwise you can just ask solr to +produce the features without doing the reranking: + +`fl=*,score,[features store=yourFeatureStore format=[dense|sparse] ]` + +This will return the values of the features in the given store. The format of the +extracted features will be based on the format parameter. The default is sparse.
+ +# Assemble training data +In order to train a learning to rank model you need training data. Training data is +what "teaches" the model what the appropriate weight for each feature is. In general +training data is a collection of queries with associated documents and what their ranking/score +should be. As an example: +``` +secretary of state|John Kerry|0.66|CROWDSOURCE +secretary of state|Cesar A. Perales|0.33|CROWDSOURCE +secretary of state|New York State|0.0|CROWDSOURCE +secretary of state|Colorado State University Secretary|0.0|CROWDSOURCE + +microsoft ceo|Satya Nadella|1.0|CLICK_LOG +microsoft ceo|Microsoft|0.0|CLICK_LOG +microsoft ceo|State|0.0|CLICK_LOG +microsoft ceo|Secretary|0.0|CLICK_LOG +``` +In this example the first column indicates the query, the second column indicates a unique id for that doc, +the third column indicates the relative importance or relevance of that doc, and the fourth column indicates the source. +There are 2 primary ways you might collect data for use with your machine learning algorithm. The first +is to collect the clicks of your users given a specific query. There are many ways of preparing this data +to train a model (http://www.cs.cornell.edu/people/tj/publications/joachims_etal_05a.pdf). The general idea +is that if a user sees multiple documents and clicks the one lower down, that document should be scored higher +than the one above it. The second way is explicitly through a crowdsourcing platform like Mechanical Turk or +CrowdFlower. These platforms allow you to show human workers documents associated with a query and have them +tell you what the correct ranking should be. + +At this point you'll need to collect feature vectors for each query document pair. You can use the information +from the Extract features section above to do this. An example script has been included in example/train_and_upload_demo_model.py.
+ +# Explanation of the core reranking logic +An LTR model is plugged into the ranking through the [LTRQParserPlugin](/solr/contrib/ltr/src/java/org/apache/solr/search/LTRQParserPlugin.java). The plugin will +read from the request the model, an instance of [LTRScoringModel](/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LTRScoringModel.java), +plus other parameters. The plugin will generate an LTRQuery, a particular [ReRankQuery](/solr/core/src/java/org/apache/solr/search/AbstractReRankQuery.java). +It wraps the original solr query for the first pass ranking, and uses the provided model in an +[LTRScoringQuery](/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java) to +rescore and rerank the top documents. The LTRScoringQuery will take care of computing the values of all the +[features](/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java) and then will delegate the final score +generation to the LTRScoringModel. + +# Speeding up the weight creation with threads +About half the time for ranking is spent in the creation of weights for each feature used in ranking. If the number of features is significantly high (say, 500 or more), this increases the ranking overhead proportionally. To alleviate this problem, parallel weight creation is provided as a configurable option. In order to use this feature, the following lines need to be added to the solrconfig.xml +```xml + +<config> + <!-- Query parser used to rerank top docs with a provided model --> + <queryParser name="ltr" class="org.apache.solr.search.LTRQParserPlugin"> + <int name="threadModule.totalPoolThreads">10</int> + <int name="threadModule.numThreadsPerRequest">5</int> + </queryParser> + + <!-- Transformer for extracting features --> + <transformer name="features" class="org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory"> + <int name="threadModule.totalPoolThreads">10</int> + <int name="threadModule.numThreadsPerRequest">5</int> + </transformer> +</config> + +``` + +The threadModule.totalPoolThreads option limits the total number of threads to be used across all query instances at any given time. threadModule.numThreadsPerRequest limits the number of threads used to process a single query. In the above example, 10 threads will be used to service all queries and a maximum of 5 threads to service a single query. If the Solr instance is expected to receive no more than one query at a time, it is best to set both these numbers to the same value.
If multiple queries need to be serviced simultaneously, the numbers can be adjusted based on the expected response times. If the value of threadModule.numThreadsPerRequest is higher, the response time for a single query will be improved up to a point. If multiple queries are serviced simultaneously, threadModule.totalPoolThreads imposes contention between the queries if (threadModule.numThreadsPerRequest*total parallel queries > threadModule.totalPoolThreads). + diff --git a/solr/contrib/ltr/README.txt b/solr/contrib/ltr/README.txt new file mode 120000 index 00000000000..42061c01a1c --- /dev/null +++ b/solr/contrib/ltr/README.txt @@ -0,0 +1 @@ +README.md \ No newline at end of file diff --git a/solr/contrib/ltr/build.xml b/solr/contrib/ltr/build.xml new file mode 100644 index 00000000000..bbd5cf3d9b1 --- /dev/null +++ b/solr/contrib/ltr/build.xml @@ -0,0 +1,30 @@ + + + + + + + + Learning to Rank Package + + + + + + + diff --git a/solr/contrib/ltr/example/config.json b/solr/contrib/ltr/example/config.json new file mode 100644 index 00000000000..483fe690e16 --- /dev/null +++ b/solr/contrib/ltr/example/config.json @@ -0,0 +1,14 @@ +{ + "host": "localhost", + "port": 8983, + "collection": "techproducts", + "requestHandler": "query", + "q": "*:*", + "otherParams": "fl=id,score,[features efi.user_query='$USERQUERY']", + "userQueriesFile": "user_queries.txt", + "trainingFile": "ClickData", + "featuresFile": "techproducts-features.json", + "trainingLibraryLocation": "liblinear/train", + "solrModelFile": "solrModel.json", + "solrModelName": "ExampleModel" +} diff --git a/solr/contrib/ltr/example/libsvm_formatter.py b/solr/contrib/ltr/example/libsvm_formatter.py new file mode 100644 index 00000000000..25cf10ba05d --- /dev/null +++ b/solr/contrib/ltr/example/libsvm_formatter.py @@ -0,0 +1,124 @@ +from subprocess import call +import os + +PAIRWISE_THRESHOLD = 1.e-1 +FEATURE_DIFF_THRESHOLD = 1.e-6 + +class LibSvmFormatter: + def
processQueryDocFeatureVector(self,docClickInfo,trainingFile): + '''Expects as input a list or generator, sorted by query, that provides the context + for each query in a tuple composed of: (query , docId , relevance , source , featureVector). + The list of documents that are part of the same query (and source) will generate comparisons + against each other for training; the resulting pairs are written to trainingFile. ''' + curQueryAndSource = ""; + with open(trainingFile,"w") as output: + self.featureNameToId = {} + self.featureIdToName = {} + self.curFeatIndex = 1; + curListOfFv = [] + for query,docId,relevance,source,featureVector in docClickInfo: + if curQueryAndSource != query + source: + #Time to flush out all the pairs + _writeRankSVMPairs(curListOfFv,output); + curListOfFv = [] + curQueryAndSource = query + source + curListOfFv.append((relevance,self._makeFeaturesMap(featureVector))) + _writeRankSVMPairs(curListOfFv,output); #This catches the last list of comparisons + + def _makeFeaturesMap(self,featureVector): + '''expects a list of strings with "feature name":"feature value" pairs. Outputs a map of map[key] = value. + Where key is now an integer.
libSVM requires the key to be an integer but not all libraries have + this requirement.''' + features = {} + for keyValuePairStr in featureVector: + featName,featValue = keyValuePairStr.split(":"); + features[self._getFeatureId(featName)] = float(featValue); + return features + + def _getFeatureId(self,key): + '''Returns the integer id for the feature name key, assigning the next free id + (ids start at 1) the first time a name is seen.''' + if key not in self.featureNameToId: + self.featureNameToId[key] = self.curFeatIndex; + self.featureIdToName[self.curFeatIndex] = key; + self.curFeatIndex += 1; + return self.featureNameToId[key]; + + def convertLibSvmModelToLtrModel(self,libSvmModelLocation, outputFile, modelName): + '''Reads the trained model file at libSvmModelLocation and writes a LinearModel JSON + document named modelName to outputFile. The weights are the lines that follow the + line 'w' in the model file; the Nth weight is given to the feature with id N.''' + with open(libSvmModelLocation, 'r') as inFile: + with open(outputFile,'w') as convertedOutFile: + convertedOutFile.write('{\n\t"class":"org.apache.solr.ltr.model.LinearModel",\n') + convertedOutFile.write('\t"name": "' + str(modelName) + '",\n') + convertedOutFile.write('\t"features": [\n') + isFirst = True; + for featKey in self.featureNameToId.keys(): + convertedOutFile.write('\t\t{ "name":"' + featKey + '"}' if isFirst else ',\n\t\t{ "name":"' + featKey + '"}' ); + isFirst = False; + convertedOutFile.write("\n\t],\n"); + convertedOutFile.write('\t"params": {\n\t\t"weights": {\n'); + + startReading = False + isFirst = True + counter = 1 + for line in inFile: + if startReading: + newParamVal = float(line.strip()) + if not isFirst: + convertedOutFile.write(',\n\t\t\t"' + self.featureIdToName[counter] + '":' + str(newParamVal)) + else: + convertedOutFile.write('\t\t\t"' + self.featureIdToName[counter] + '":' + str(newParamVal)) + isFirst = False + counter += 1 + elif line.strip() == 'w': + startReading = True + convertedOutFile.write('\n\t\t}\n\t}\n}') + +def _writeRankSVMPairs(listOfFeatures,output): + '''Given a list of (relevance, {Features Map}) where the list represents + a set of documents to be compared, this calculates all pairs and + writes the Feature Vectors in a format compatible with libSVM.
+ Ex: listOfFeatures = [ + #(relevance, {feature1:value, featureN:value}) + (4, {1:0.9, 2:0.9, 3:0.1}) + (3, {1:0.7, 2:0.9, 3:0.2}) + (1, {1:0.1, 2:0.9, 6:0.1}) + ] + ''' + for d1 in range(0,len(listOfFeatures)): + for d2 in range(d1+1,len(listOfFeatures)): + doc1,doc2 = listOfFeatures[d1], listOfFeatures[d2] + fv1,fv2 = doc1[1],doc2[1] + d1Relevance, d2Relevance = float(doc1[0]),float(doc2[0]) + if d1Relevance - d2Relevance > PAIRWISE_THRESHOLD:#d1Relevance > d2Relevance + outputLibSvmLine("+1",subtractFvMap(fv1,fv2),output); + outputLibSvmLine("-1",subtractFvMap(fv2,fv1),output); + elif d1Relevance - d2Relevance < -PAIRWISE_THRESHOLD: #d1Relevance < d2Relevance: + outputLibSvmLine("+1",subtractFvMap(fv2,fv1),output); + outputLibSvmLine("-1",subtractFvMap(fv1,fv2),output); + else: #Must be approximately equal relevance, in which case this is a useless signal and we should skip + continue; + +def subtractFvMap(fv1,fv2): + '''returns the fv from fv1 - fv2''' + retFv = fv1.copy(); + for featInd in fv2.keys(): + subVal = 0.0; + if featInd in fv1: + subVal = fv1[featInd] - fv2[featInd] + else: + subVal = -fv2[featInd] + if abs(subVal) > FEATURE_DIFF_THRESHOLD: #This ensures everything is in sparse format, and removes useless signals + retFv[featInd] = subVal; + else: + retFv.pop(featInd, None) + return retFv; + +def outputLibSvmLine(sign,fvMap,outputFile): + outputFile.write(sign) + for feat in fvMap.keys(): + outputFile.write(" " + str(feat) + ":" + str(fvMap[feat])); + outputFile.write("\n") + +def trainLibSvm(libraryLocation,trainingFileName): + if os.path.isfile(libraryLocation): + call([libraryLocation, trainingFileName]) + else: + raise Exception("NO LIBRARY FOUND: " + libraryLocation); diff --git a/solr/contrib/ltr/example/solrconfig.xml b/solr/contrib/ltr/example/solrconfig.xml new file mode 100644 index 00000000000..18d6cb83b66 --- /dev/null +++ b/solr/contrib/ltr/example/solrconfig.xml @@ -0,0 +1,1722 @@ + + + + + + + + + 6.0.0 + + + + + + + + + + + + + + + + 
+ + + + + + + ${solr.data.dir:} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ${solr.lock.type:native} + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + ${solr.ulog.dir:} + ${solr.ulog.numVersionBuckets:65536} + + + + + ${solr.autoCommit.maxTime:15000} + false + + + + + + ${solr.autoSoftCommit.maxTime:-1} + + + + + + + + + + + + + + + + 1024 + + + + -1 + + + + + + + + + + + + + + + + + + + + + + + + + + + true + + + + + + 20 + + + 200 + + + + + + + + + + + + static firstSearcher warming in solrconfig.xml + + + + + + false + + + 2 + + + + + + + + + + + + + + + + + + + + + + + explicit + 10 + + false + + + + + + + + + + + + + 10 + 10 + + + + + + + + + + explicit + json + true + text + + + + + + + + explicit + + + velocity + browse + layout + Solritas + + + edismax + + text^0.5 features^1.0 name^1.2 sku^1.5 id^10.0 manu^1.1 cat^1.4 + title^10.0 description^5.0 keywords^5.0 author^2.0 resourcename^1.0 + + 100% + *:* + 10 + *,score + + + text^0.5 features^1.0 name^1.2 sku^1.5 id^10.0 manu^1.1 cat^1.4 + title^10.0 description^5.0 keywords^5.0 author^2.0 resourcename^1.0 + + text,features,name,sku,id,manu,cat,title,description,keywords,author,resourcename + 3 + + + on + true + cat + manu_exact + content_type + author_s + ipod + GB + 1 + cat,inStock + after + price + 0 + 600 + 50 + popularity + 0 + 10 + 3 + manufacturedate_dt + NOW/YEAR-10YEARS + NOW + +1YEAR + before + after + + + on + content features title name + true + html + <b> + </b> + 0 + title + 0 + name + 3 + 200 + content + 750 + + + on + false + 5 + 2 + 5 + true + true + 5 + 3 + + + + + spellcheck + + + + + + + text + + + + + + + _src_ + + true + + + + + + + + + + true + ignored_ + + + true + links + ignored_ + + + + + + + + + + + + + + + explicit + true + + + + + + + + + text_general + + + + + + default + text + solr.DirectSolrSpellChecker + + internal + + 0.5 + + 2 + + 1 + + 5 + + 4 + + 0.01 + + + + + + wordbreak + solr.WordBreakSolrSpellChecker + name + true + true + 10 + + + + + + + 
+ + + + + + + + + + default + wordbreak + on + true + 10 + 5 + 5 + true + true + 10 + 5 + + + spellcheck + + + + + + + mySuggester + FuzzyLookupFactory + DocumentDictionaryFactory + cat + price + string + false + + + + + + true + 10 + + + suggest + + + + + + + + + + + true + + + tvComponent + + + + + + + + + lingo3g + true + com.carrotsearch.lingo3g.Lingo3GClusteringAlgorithm + clustering/carrot2 + + + + lingo + org.carrot2.clustering.lingo.LingoClusteringAlgorithm + clustering/carrot2 + + + + stc + org.carrot2.clustering.stc.STCClusteringAlgorithm + clustering/carrot2 + + + + kmeans + org.carrot2.clustering.kmeans.BisectingKMeansClusteringAlgorithm + clustering/carrot2 + + + + + + + true + true + + name + + id + + features + + true + + + + false + + + edismax + + text^0.5 features^1.0 name^1.2 sku^1.5 id^10.0 manu^1.1 cat^1.4 + + *:* + 100 + *,score + + + clustering + + + + + + + + + + true + false + + + terms + + + + + + + + string + elevate.xml + + + + + + explicit + + + elevator + + + + + + + + + + + 100 + + + + + + + + 70 + + 0.5 + + [-\w ,/\n\"']{20,200} + + + + + + + ]]> + ]]> + + + + + + + + + + + + + + + + + + + + + + + + ,, + ,, + ,, + ,, + ,]]> + ]]> + + + + + + 10 + .,!? 
+ + + + + + + WORD + + + en + US + + + + + + + + + + + + + + + + + + + + + + text/plain; charset=UTF-8 + + + + + ${velocity.template.base.dir:} + + + + + + 5 + + + + + + + + + + + + + + + + + + *:* + + + diff --git a/solr/contrib/ltr/example/techproducts-features.json b/solr/contrib/ltr/example/techproducts-features.json new file mode 100644 index 00000000000..f358f8bc968 --- /dev/null +++ b/solr/contrib/ltr/example/techproducts-features.json @@ -0,0 +1,26 @@ +[ +{ + "name": "isInStock", + "class": "org.apache.solr.ltr.feature.FieldValueFeature", + "params": { + "field": "inStock" + } +}, +{ + "name": "price", + "class": "org.apache.solr.ltr.feature.FieldValueFeature", + "params": { + "field": "price" + } +}, +{ + "name":"originalScore", + "class":"org.apache.solr.ltr.feature.OriginalScoreFeature", + "params":{} +}, +{ + "name" : "productNameMatchQuery", + "class" : "org.apache.solr.ltr.feature.SolrFeature", + "params" : { "q" : "{!field f=name}${user_query}" } +} +] diff --git a/solr/contrib/ltr/example/techproducts-model.json b/solr/contrib/ltr/example/techproducts-model.json new file mode 100644 index 00000000000..0efded7bc38 --- /dev/null +++ b/solr/contrib/ltr/example/techproducts-model.json @@ -0,0 +1,18 @@ +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"linear", + "features":[ + {"name":"isInStock"}, + {"name":"price"}, + {"name":"originalScore"}, + {"name":"productNameMatchQuery"} + ], + "params":{ + "weights":{ + "isInStock":15.0, + "price":1.0, + "originalScore":5.0, + "productNameMatchQuery":1.0 + } + } +} diff --git a/solr/contrib/ltr/example/train_and_upload_demo_model.py b/solr/contrib/ltr/example/train_and_upload_demo_model.py new file mode 100755 index 00000000000..c3762de667f --- /dev/null +++ b/solr/contrib/ltr/example/train_and_upload_demo_model.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +import sys +import json +import httplib +import urllib +import libsvm_formatter + +from optparse import OptionParser + +solrQueryUrl = "" + 
+def generateQueries(config): + '''Reads config["userQueriesFile"], one searchText|docId|score|source record per line, + and returns a list of (solrQuery, searchText, docId, score, source) tuples.''' + with open(config["userQueriesFile"]) as input: + solrQueryUrls = [] #A list of tuples with solrQuery (request url),searchText,docId,score,source + + for line in input: + line = line.strip(); + searchText,docId,score,source = line.split("|"); + solrQuery = generateHttpRequest(config,searchText,docId) + solrQueryUrls.append((solrQuery,searchText,docId,score,source)) + + return solrQueryUrls; + +def generateHttpRequest(config,searchText,docId): + '''Builds the request path and query string that selects the document with id docId, + substituting the escaped searchText for the $USERQUERY placeholder. The url prefix is + built from config once and then kept in the global solrQueryUrl.''' + global solrQueryUrl + if len(solrQueryUrl) < 1: + solrQueryUrl = "/solr/%(collection)s/%(requestHandler)s?%(otherParams)s&q=" % config + solrQueryUrl = solrQueryUrl.replace(" ","+") + solrQueryUrl += urllib.quote_plus("id:") + + + userQuery = urllib.quote_plus(searchText.strip().replace("'","\\'").replace("/","\\\\/")) + solrQuery = solrQueryUrl + '"' + urllib.quote_plus(docId) + '"' #+ solrQueryUrlEnd + solrQuery = solrQuery.replace("%24USERQUERY", userQuery).replace('$USERQUERY', urllib.quote_plus("\\'" + userQuery + "\\'")) + + return solrQuery + +def generateTrainingData(solrQueries, config): + '''Given a list of solr queries, yields a tuple of query , docId , score , source , feature vector for each query.
+ Feature Vector is a list of strings of form "key:value". + Documents whose response carries no "[features]" value are reported and skipped. + NOTE(review): the except handler prints msg, which looks unbound if the first request + fails before r.read() returns - confirm.''' + conn = httplib.HTTPConnection(config["host"], config["port"]) + headers = {"Connection":" keep-alive"} + + try: + for queryUrl,query,docId,score,source in solrQueries: + conn.request("GET", queryUrl, headers=headers) + r = conn.getresponse() + msg = r.read() + msgDict = json.loads(msg) + fv = "" + docs = msgDict['response']['docs'] + if len(docs) > 0 and "[features]" in docs[0]: + if not msgDict['response']['docs'][0]["[features]"] == None: + fv = msgDict['response']['docs'][0]["[features]"]; + else: + print "ERROR NULL FV FOR: " + docId; + print msg + continue; + else: + print "ERROR FOR: " + docId; + print msg + continue; + + if r.status == httplib.OK: + #print "http connection was ok for: " + queryUrl + yield(query,docId,score,source,fv.split(";")); + else: + raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg)) + except Exception as e: + print msg + print e + + conn.close() + +def setupSolr(config): + '''Sets up solr with the proper features for the test''' + + conn = httplib.HTTPConnection(config["host"], config["port"]) + + baseUrl = "/solr/" + config["collection"] + featureUrl = baseUrl + "/schema/feature-store" + + # CAUTION! This will delete all feature stores.
This is just for demo purposes + conn.request("DELETE", featureUrl+"/*") + r = conn.getresponse() + msg = r.read() + if (r.status != httplib.OK and + r.status != httplib.CREATED and + r.status != httplib.ACCEPTED and + r.status != httplib.NOT_FOUND): + raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg)) + + + # Add features + headers = {'Content-type': 'application/json'} + featuresBody = open(config["featuresFile"]) + + conn.request("POST", featureUrl, featuresBody, headers) + r = conn.getresponse() + msg = r.read() + if (r.status != httplib.OK and + r.status != httplib.ACCEPTED): + print r.status + print "" + print r.reason; + raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg)) + + conn.close() + + +def main(argv=None): + '''Runs the demo end to end using the JSON file given with -c/--config: uploads the + features, extracts feature vectors, trains the model, converts it and uploads it to solr. + Returns 1 (after printing the help) when no config file was given.''' + if argv is None: + argv = sys.argv + + parser = OptionParser(usage="usage: %prog [options] ", version="%prog 1.0") + parser.add_option('-c', '--config', + dest='configFile', + help='File of configuration for the test') + (options, args) = parser.parse_args() + + if options.configFile == None: + parser.print_help() + return 1 + + with open(options.configFile) as configFile: + config = json.load(configFile) + + print "Uploading feature space to Solr" + setupSolr(config) + + print "Generating feature extraction Solr queries" + reRankQueries = generateQueries(config) + + print "Extracting features" + fvGenerator = generateTrainingData(reRankQueries, config); + formatter = libsvm_formatter.LibSvmFormatter(); + formatter.processQueryDocFeatureVector(fvGenerator,config["trainingFile"]); + + print "Training ranksvm model" + libsvm_formatter.trainLibSvm(config["trainingLibraryLocation"],config["trainingFile"]) + + print "Converting ranksvm model to solr model" + formatter.convertLibSvmModelToLtrModel(config["trainingFile"] + ".model", config["solrModelFile"], config["solrModelName"]) + + print "Uploading model to solr" + uploadModel(config["collection"], config["host"], config["port"],
config["solrModelFile"]) + +def uploadModel(collection, host, port, modelFile): + modelUrl = "/solr/" + collection + "/schema/model-store" + headers = {'Content-type': 'application/json'} + with open(modelFile) as modelBody: + conn = httplib.HTTPConnection(host, port) + conn.request("POST", modelUrl, modelBody, headers) + r = conn.getresponse() + msg = r.read() + if (r.status != httplib.OK and + r.status != httplib.CREATED and + r.status != httplib.ACCEPTED): + raise Exception("Status: {0} {1}\nResponse: {2}".format(r.status, r.reason, msg)) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/solr/contrib/ltr/example/user_queries.txt b/solr/contrib/ltr/example/user_queries.txt new file mode 100644 index 00000000000..a3a345527f0 --- /dev/null +++ b/solr/contrib/ltr/example/user_queries.txt @@ -0,0 +1,8 @@ +hard drive|SP2514N|0.6666666|CLICK_LOGS +hard drive|6H500F0|0.330082034|CLICK_LOGS +hard drive|F8V7067-APL-KIT|0.0|CLICK_LOGS +hard drive|IW-02|0.0|CLICK_LOGS +ipod|MA147LL/A|1.0|EXPLICIT +ipod|F8V7067-APL-KIT|0.25|EXPLICIT +ipod|IW-02|0.25|EXPLICIT +ipod|6H500F0|0.0|EXPLICIT diff --git a/solr/contrib/ltr/ivy.xml b/solr/contrib/ltr/ivy.xml new file mode 100644 index 00000000000..68e9797bb09 --- /dev/null +++ b/solr/contrib/ltr/ivy.xml @@ -0,0 +1,32 @@ + + + + + + + + + + + + + + + diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/DocInfo.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/DocInfo.java new file mode 100644 index 00000000000..b3dfb9eee4b --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/DocInfo.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import java.util.HashMap; + +public class DocInfo extends HashMap { + + // Name of key used to store the original score of a doc + private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE"; + + public DocInfo() { + super(); + } + + public void setOriginalDocScore(Float score) { + put(ORIGINAL_DOC_SCORE, score); + } + + public Float getOriginalDocScore() { + return (Float)get(ORIGINAL_DOC_SCORE); + } + + public boolean hasOriginalDocScore() { + return containsKey(ORIGINAL_DOC_SCORE); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java new file mode 100644 index 00000000000..a5afd05952c --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import java.lang.invoke.MethodHandles; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.solr.search.SolrIndexSearcher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * FeatureLogger can be registered in a model and provide a strategy for logging + * the feature values. + */ +public abstract class FeatureLogger { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + /** the name of the cache used for storing the feature values **/ + private static final String QUERY_FV_CACHE_NAME = "QUERY_DOC_FV"; + + protected enum FeatureFormat {DENSE, SPARSE}; + protected final FeatureFormat featureFormat; + + protected FeatureLogger(FeatureFormat f) { + this.featureFormat = f; + } + + /** + * Log will be called every time that the model generates the feature values + * for a document and a query. + * + * @param docid + * Solr document id whose features we are saving + * @param scoringQuery + * the scoring query; combined with docid to form the cache key + * @param searcher + * the searcher whose cache receives the feature vector + * @param featuresInfo + * Array of all the {@link LTRScoringQuery.FeatureInfo} objects which contain name and value + * for all the features triggered by the result set + * @return true if a feature vector was built and the cache insert returned a + * non-null value, false otherwise.
+ */ + + public boolean log(int docid, LTRScoringQuery scoringQuery, + SolrIndexSearcher searcher, LTRScoringQuery.FeatureInfo[] featuresInfo) { + final FV_TYPE featureVector = makeFeatureVector(featuresInfo); + if (featureVector == null) { + return false; + } + + return searcher.cacheInsert(QUERY_FV_CACHE_NAME, + fvCacheKey(scoringQuery, docid), featureVector) != null; + } + + /** + * Returns a FeatureLogger that logs the features in the format given by the + * 'stringFormat' param: 'csv' logs the features as a single string in csv + * format; 'json' logs the features as a Map of featureName keys to + * featureValue values; if the format is null or empty, csv is selected. + * The 'featureFormat' param selects 'dense' or 'sparse' output; null, empty + * or unrecognised values default to 'sparse' (an unrecognised value is also + * logged as a warning). + * + * + * @return a feature logger for the format specified, or null if + * 'stringFormat' is not recognised. + */ + public static FeatureLogger createFeatureLogger(String stringFormat, String featureFormat) { + final FeatureFormat f; + if (featureFormat == null || featureFormat.isEmpty() || + featureFormat.equals("sparse")) { + f = FeatureFormat.SPARSE; + } + else if (featureFormat.equals("dense")) { + f = FeatureFormat.DENSE; + } + else { + f = FeatureFormat.SPARSE; + log.warn("unknown feature logger feature format {} | {}", stringFormat, featureFormat); + } + if ((stringFormat == null) || stringFormat.isEmpty()) { + return new CSVFeatureLogger(f); + } + if (stringFormat.equals("csv")) { + return new CSVFeatureLogger(f); + } + if (stringFormat.equals("json")) { + return new MapFeatureLogger(f); + } + log.warn("unknown feature logger string format {} | {}", stringFormat, featureFormat); + return null; + + } + + public abstract FV_TYPE makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo); + + private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) { + return scoringQuery.hashCode() + (31 * docid); + } + +
/** + * Looks up the feature vector previously logged for a document. + * + * @param docid + * Solr document id + * @param scoringQuery + * the scoring query; combined with docid to form the cache key + * @param searcher + * the searcher whose cache is consulted + * @return the cached feature vector for docid, in the representation + * produced by makeFeatureVector + */ + + public FV_TYPE getFeatureVector(int docid, LTRScoringQuery scoringQuery, + SolrIndexSearcher searcher) { + return (FV_TYPE) searcher.cacheLookup(QUERY_FV_CACHE_NAME, fvCacheKey(scoringQuery, docid)); + } + + + /** Logs the features as a map of feature name to feature value. */ + public static class MapFeatureLogger extends FeatureLogger> { + + public MapFeatureLogger(FeatureFormat f) { + super(f); + } + + @Override + public Map makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) { + boolean isDense = featureFormat.equals(FeatureFormat.DENSE); + Map hashmap = Collections.emptyMap(); + if (featuresInfo.length > 0) { + hashmap = new HashMap(featuresInfo.length); + for (LTRScoringQuery.FeatureInfo featInfo:featuresInfo){ + if (featInfo.isUsed() || isDense){ + hashmap.put(featInfo.getName(), featInfo.getValue()); + } + } + } + return hashmap; + } + + } + + /** + * Logs the features as one string of name/value pairs; by default each pair + * is written as name:value and pairs are separated by ';'. + * NOTE(review): the StringBuilder is an instance field reused by every call, + * so an instance does not look safe for concurrent use - confirm how + * instances are shared. + */ + public static class CSVFeatureLogger extends FeatureLogger { + StringBuilder sb = new StringBuilder(500); + char keyValueSep = ':'; + char featureSep = ';'; + + public CSVFeatureLogger(FeatureFormat f) { + super(f); + } + + public CSVFeatureLogger setKeyValueSep(char keyValueSep) { + this.keyValueSep = keyValueSep; + return this; + } + + public CSVFeatureLogger setFeatureSep(char featureSep) { + this.featureSep = featureSep; + return this; + } + + @Override + public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) { + boolean isDense = featureFormat.equals(FeatureFormat.DENSE); + for (LTRScoringQuery.FeatureInfo featInfo:featuresInfo) { + if (featInfo.isUsed() || isDense){ + sb.append(featInfo.getName()) + .append(keyValueSep) + .append(featInfo.getValue()) + .append(featureSep); + } + } + + final String features = (sb.length() > 0 ?
sb.substring(0, + sb.length() - 1) : ""); + sb.setLength(0); + + return features; + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java new file mode 100644 index 00000000000..27223b770f9 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.solr.search.SolrIndexSearcher; + + +/** + * Implements the rescoring logic. The top documents returned by solr with their + * original scores, will be processed by a {@link LTRScoringQuery} that will assign a + * new score to each document. 
The top documents will be re-sorted based on the + * new score. + * */ +public class LTRRescorer extends Rescorer { + + LTRScoringQuery scoringQuery; + public LTRRescorer(LTRScoringQuery scoringQuery) { + this.scoringQuery = scoringQuery; + } + + /** + * Sifts the element at index root down the min-heap (ordered by score, + * lowest score at the top) held in the first size entries of hits. + */ + private void heapAdjust(ScoreDoc[] hits, int size, int root) { + final ScoreDoc doc = hits[root]; + final float score = doc.score; + int i = root; + while (i <= ((size >> 1) - 1)) { + final int lchild = (i << 1) + 1; + final ScoreDoc ldoc = hits[lchild]; + final float lscore = ldoc.score; + float rscore = Float.MAX_VALUE; + final int rchild = (i << 1) + 2; + ScoreDoc rdoc = null; + if (rchild < size) { + rdoc = hits[rchild]; + rscore = rdoc.score; + } + if (lscore < score) { + if (rscore < lscore) { + hits[i] = rdoc; + hits[rchild] = doc; + i = rchild; + } else { + hits[i] = ldoc; + hits[lchild] = doc; + i = lchild; + } + } else if (rscore < score) { + hits[i] = rdoc; + hits[rchild] = doc; + i = rchild; + } else { + return; + } + } + } + + /** Turns the first size entries of hits into a min-heap ordered by score. */ + private void heapify(ScoreDoc[] hits, int size) { + for (int i = (size >> 1) - 1; i >= 0; i--) { + heapAdjust(hits, size, i); + } + } + + /** + * Rescores the first-pass documents with the scoring query and returns the + * best topN of them, ordered by new score descending and then docID ascending. + * The firstPassTopDocs.scoreDocs array is sorted by docID in place and the + * score of each of its entries is overwritten with the new score. + * NOTE(review): reranked has min(topN, totalHits) slots but only + * scoreDocs.length hits are scored; if scoreDocs can be shorter than that, + * the final sort would see null entries - confirm against callers. + * + * @param searcher + * current IndexSearcher; it is cast to SolrIndexSearcher + * @param firstPassTopDocs + * documents to rerank + * @param topN + * maximum number of documents to return + * @return firstPassTopDocs itself when topN or totalHits is 0, otherwise a + * new TopDocs holding the reranked documents + */ + @Override + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, + int topN) throws IOException { + if ((topN == 0) || (firstPassTopDocs.totalHits == 0)) { + return firstPassTopDocs; + } + final ScoreDoc[] hits = firstPassTopDocs.scoreDocs; + Arrays.sort(hits, new Comparator() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + return a.doc - b.doc; + } + }); + + topN = Math.min(topN, firstPassTopDocs.totalHits); + final ScoreDoc[] reranked = new ScoreDoc[topN]; + final List leaves = searcher.getIndexReader().leaves(); + final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
+ .createNormalizedWeight(scoringQuery, true); + + final SolrIndexSearcher solrIndexSearch = (SolrIndexSearcher) searcher; + scoreFeatures(solrIndexSearch, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked); + // Must sort all documents that we reranked, and then select the top + Arrays.sort(reranked, new Comparator() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + // Sort by score descending, then docID ascending: + if (a.score > b.score) { + return -1; + } else if (a.score < b.score) { + return 1; + } else { + // This subtraction can't overflow int + // because docIDs are >= 0: + return a.doc - b.doc; + } + } + }); + + return new TopDocs(firstPassTopDocs.totalHits, reranked, reranked[0].score); + } + + /** + * Scores every entry of hits with the model and keeps the topN best scoring + * hits in reranked: the first topN hits fill the array, after which it is + * maintained as a min-heap by score. The segment lookup only moves forward, + * so hits is expected to be sorted by ascending docID. + */ + public void scoreFeatures(SolrIndexSearcher solrIndexSearch, TopDocs firstPassTopDocs, + int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List leaves, + ScoreDoc[] reranked) throws IOException { + + int readerUpto = -1; + int endDoc = 0; + int docBase = 0; + + LTRScoringQuery.ModelWeight.ModelScorer scorer = null; + int hitUpto = 0; + final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); + + while (hitUpto < hits.length) { + final ScoreDoc hit = hits[hitUpto]; + final int docID = hit.doc; + LeafReaderContext readerContext = null; + while (docID >= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + // We advanced to another segment + if (readerContext != null) { + docBase = readerContext.docBase; + scorer = modelWeight.scorer(readerContext); + } + // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to + // call score + // even if no feature scorers match, since a model might use that info to + // return a + // non-zero score.
Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer + // past the target + // doc since the model algorithm still needs to compute a potentially + // non-zero score from blank features. + assert (scorer != null); + final int targetDoc = docID - docBase; + scorer.docID(); + scorer.iterator().advance(targetDoc); + + scorer.getDocInfo().setOriginalDocScore(new Float(hit.score)); + hit.score = scorer.score(); + if (hitUpto < topN) { + reranked[hitUpto] = hit; + // if the heap is not full, maybe I want to log the features for this + // document + if (featureLogger != null) { + featureLogger.log(hit.doc, scoringQuery, solrIndexSearch, + modelWeight.getFeaturesInfo()); + } + } else if (hitUpto == topN) { + // collected topN documents, so create the heap + heapify(reranked, topN); + } + if (hitUpto >= topN) { + // once the heap is ready, if the score of this document is not higher than + // the minimum + // I don't want to log the features. Otherwise I replace the + // minimum with it and fix the + // heap.
+ if (hit.score > reranked[0].score) { + reranked[0] = hit; + heapAdjust(reranked, topN, 0); + if (featureLogger != null) { + featureLogger.log(hit.doc, scoringQuery, solrIndexSearch, + modelWeight.getFeaturesInfo()); + } + } + } + hitUpto++; + } + } + + @Override + public Explanation explain(IndexSearcher searcher, + Explanation firstPassExplanation, int docID) throws IOException { + + final List leafContexts = searcher.getTopReaderContext() + .leaves(); + final int n = ReaderUtil.subIndex(docID, leafContexts); + final LeafReaderContext context = leafContexts.get(n); + final int deBasedDoc = docID - context.docBase; + final Weight modelWeight = searcher.createNormalizedWeight(scoringQuery, + true); + return modelWeight.explain(context, deBasedDoc); + } + + public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight, + int docid, + Float originalDocScore, + List leafContexts) + throws IOException { + final int n = ReaderUtil.subIndex(docid, leafContexts); + final LeafReaderContext atomicContext = leafContexts.get(n); + final int deBasedDoc = docid - atomicContext.docBase; + final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.scorer(atomicContext); + if ( (r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc) ) { + return new LTRScoringQuery.FeatureInfo[0]; + } else { + if (originalDocScore != null) { + // If results have not been reranked, the score passed in is the original query's + // score, which some features can use instead of recalculating it + r.getDocInfo().setOriginalDocScore(originalDocScore); + } + r.score(); + return modelWeight.getFeaturesInfo(); + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java new file mode 100644 index 00000000000..991c1edf58f --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -0,0 +1,738 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; +import java.util.concurrent.RunnableFuture; +import java.util.concurrent.Semaphore; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DisjunctionDISIApproximation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.request.SolrQueryRequest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + 
+/** + * The ranking query that is run, reranking results using the + * LTRScoringModel algorithm + */ +public class LTRScoringQuery extends Query { + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + // contains a description of the model + final private LTRScoringModel ltrScoringModel; + final private boolean extractAllFeatures; + final private LTRThreadModule ltrThreadMgr; + final private Semaphore querySemaphore; // limits the number of threads per query, so that multiple requests can be serviced simultaneously + + // feature logger to output the features. + private FeatureLogger fl; + // Map of external parameters, such as query intent, that can be used by + // features + final private Map efi; + // Original solr query used to fetch matching documents + private Query originalQuery; + // Original solr request + private SolrQueryRequest request; + + public LTRScoringQuery(LTRScoringModel ltrScoringModel) { + this(ltrScoringModel, Collections.emptyMap(), false, null); + } + + public LTRScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) { + this(ltrScoringModel, Collections.emptyMap(), extractAllFeatures, null); + } + + public LTRScoringQuery(LTRScoringModel ltrScoringModel, + Map externalFeatureInfo, + boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) { + this.ltrScoringModel = ltrScoringModel; + this.efi = externalFeatureInfo; + this.extractAllFeatures = extractAllFeatures; + this.ltrThreadMgr = ltrThreadMgr; + if (this.ltrThreadMgr != null) { + this.querySemaphore = this.ltrThreadMgr.createQuerySemaphore(); + } else{ + this.querySemaphore = null; + } + } + + public LTRScoringModel getScoringModel() { + return ltrScoringModel; + } + + public void setFeatureLogger(FeatureLogger fl) { + this.fl = fl; + } + + public FeatureLogger getFeatureLogger() { + return fl; + } + + public void setOriginalQuery(Query originalQuery) { + this.originalQuery = originalQuery; + } + + public Query 
getOriginalQuery() { + return originalQuery; + } + + public Map getExternalFeatureInfo() { + return efi; + } + + public void setRequest(SolrQueryRequest request) { + this.request = request; + } + + public SolrQueryRequest getRequest() { + return request; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = classHash(); + result = (prime * result) + ((ltrScoringModel == null) ? 0 : ltrScoringModel.hashCode()); + result = (prime * result) + + ((originalQuery == null) ? 0 : originalQuery.hashCode()); + if (efi == null) { + result = (prime * result) + 0; + } + else { + for (final Map.Entry entry : efi.entrySet()) { + final String key = entry.getKey(); + final String[] values = entry.getValue(); + result = (prime * result) + key.hashCode(); + result = (prime * result) + Arrays.hashCode(values); + } + } + result = (prime * result) + this.toString().hashCode(); + return result; + } + @Override + public boolean equals(Object o) { + return sameClassAs(o) && equalsTo(getClass().cast(o)); + } + + private boolean equalsTo(LTRScoringQuery other) { + if (ltrScoringModel == null) { + if (other.ltrScoringModel != null) { + return false; + } + } else if (!ltrScoringModel.equals(other.ltrScoringModel)) { + return false; + } + if (originalQuery == null) { + if (other.originalQuery != null) { + return false; + } + } else if (!originalQuery.equals(other.originalQuery)) { + return false; + } + if (efi == null) { + if (other.efi != null) { + return false; + } + } else { + if (other.efi == null || efi.size() != other.efi.size()) { + return false; + } + for(final Map.Entry entry : efi.entrySet()) { + final String key = entry.getKey(); + final String[] otherValues = other.efi.get(key); + if (otherValues == null || !Arrays.equals(otherValues,entry.getValue())) { + return false; + } + } + } + return true; + } + + @Override + public ModelWeight createWeight(IndexSearcher searcher, boolean needsScores, float boost) + throws IOException { + final Collection 
modelFeatures = ltrScoringModel.getFeatures(); + final Collection allFeatures = ltrScoringModel.getAllFeatures(); + int modelFeatSize = modelFeatures.size(); + + Collection features = null; + if (this.extractAllFeatures) { + features = allFeatures; + } + else{ + features = modelFeatures; + } + final Feature.FeatureWeight[] extractedFeatureWeights = new Feature.FeatureWeight[features.size()]; + final Feature.FeatureWeight[] modelFeaturesWeights = new Feature.FeatureWeight[modelFeatSize]; + List featureWeights = new ArrayList<>(features.size()); + + if (querySemaphore == null) { + createWeights(searcher, needsScores, boost, featureWeights, features); + } + else{ + createWeightsParallel(searcher, needsScores, boost, featureWeights, features); + } + int i=0, j = 0; + if (this.extractAllFeatures) { + for (final Feature.FeatureWeight fw : featureWeights) { + extractedFeatureWeights[i++] = fw; + } + for (final Feature f : modelFeatures){ + modelFeaturesWeights[j++] = extractedFeatureWeights[f.getIndex()]; // we can lookup by featureid because all features will be extracted when this.extractAllFeatures is set + } + } + else{ + for (final Feature.FeatureWeight fw: featureWeights){ + extractedFeatureWeights[i++] = fw; + modelFeaturesWeights[j++] = fw; + } + } + return new ModelWeight(modelFeaturesWeights, extractedFeatureWeights, allFeatures.size()); + } + + private void createWeights(IndexSearcher searcher, boolean needsScores, float boost, + List featureWeights, Collection features) throws IOException { + final SolrQueryRequest req = getRequest(); + // since the feature store is a linkedhashmap order is preserved + for (final Feature f : features) { + try{ + Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi); + featureWeights.add(fw); + } catch (final Exception e) { + throw new RuntimeException("Exception from createWeight for " + f.toString() + " " + + e.getMessage(), e); + } + } + } + + private class CreateWeightCallable implements 
Callable{ + final private Feature f; + final private IndexSearcher searcher; + final private boolean needsScores; + final private SolrQueryRequest req; + + public CreateWeightCallable(Feature f, IndexSearcher searcher, boolean needsScores, SolrQueryRequest req){ + this.f = f; + this.searcher = searcher; + this.needsScores = needsScores; + this.req = req; + } + + @Override + public Feature.FeatureWeight call() throws Exception{ + try { + Feature.FeatureWeight fw = f.createWeight(searcher, needsScores, req, originalQuery, efi); + return fw; + } catch (final Exception e) { + throw new RuntimeException("Exception from createWeight for " + f.toString() + " " + + e.getMessage(), e); + } finally { + querySemaphore.release(); + ltrThreadMgr.releaseLTRSemaphore(); + } + } + } // end of call CreateWeightCallable + + private void createWeightsParallel(IndexSearcher searcher, boolean needsScores, float boost, + List featureWeights, Collection features) throws RuntimeException { + + final SolrQueryRequest req = getRequest(); + List > futures = new ArrayList<>(features.size()); + try{ + for (final Feature f : features) { + CreateWeightCallable callable = new CreateWeightCallable(f, searcher, needsScores, req); + RunnableFuture runnableFuture = new FutureTask<>(callable); + querySemaphore.acquire(); // always acquire before the ltrSemaphore is acquired, to guarantee a that the current query is within the limit for max. 
threads + ltrThreadMgr.acquireLTRSemaphore();//may block and/or interrupt + ltrThreadMgr.execute(runnableFuture);//releases semaphore when done + futures.add(runnableFuture); + } + //Loop over futures to get the feature weight objects + for (final Future future : futures) { + featureWeights.add(future.get()); // future.get() will block if the job is still running + } + } catch (Exception e) { // To catch InterruptedException and ExecutionException + log.info("Error while creating weights in LTR: InterruptedException", e); + throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e); + } + } + + @Override + public String toString(String field) { + return field; + } + + public class FeatureInfo { + final private String name; + private float value; + private boolean used; + + FeatureInfo(String n, float v, boolean u){ + name = n; value = v; used = u; + } + + public void setValue(float value){ + this.value = value; + } + + public String getName(){ + return name; + } + + public float getValue(){ + return value; + } + + public boolean isUsed(){ + return used; + } + + public void setUsed(boolean used){ + this.used = used; + } + } + + public class ModelWeight extends Weight { + + // List of the model's features used for scoring. This is a subset of the + // features used for logging. + final private Feature.FeatureWeight[] modelFeatureWeights; + final private float[] modelFeatureValuesNormalized; + final private Feature.FeatureWeight[] extractedFeatureWeights; + + // List of all the feature names, values - used for both scoring and logging + /* + * What is the advantage of using a hashmap here instead of an array of objects? + * A set of arrays was used earlier and the elements were accessed using the featureId. + * With the updated logic to create weights selectively, + * the number of elements in the array can be fewer than the total number of features. + * When [features] are not requested, only the model features are extracted. 
+ * In this case, the indexing by featureId, fails. For this reason, + * we need a map which holds just the features that were triggered by the documents in the result set. + * + */ + final private FeatureInfo[] featuresInfo; + /* + * @param modelFeatureWeights + * - should be the same size as the number of features used by the model + * @param extractedFeatureWeights + * - if features are requested from the same store as model feature store, + * this will be the size of total number of features in the model feature store + * else, this will be the size of the modelFeatureWeights + * @param allFeaturesSize + * - total number of feature in the feature store used by this model + */ + public ModelWeight(Feature.FeatureWeight[] modelFeatureWeights, + Feature.FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) { + super(LTRScoringQuery.this); + this.extractedFeatureWeights = extractedFeatureWeights; + this.modelFeatureWeights = modelFeatureWeights; + this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length]; + this.featuresInfo = new FeatureInfo[allFeaturesSize]; + setFeaturesInfo(); + } + + private void setFeaturesInfo(){ + for (int i = 0; i < extractedFeatureWeights.length;++i){ + String featName = extractedFeatureWeights[i].getName(); + int featId = extractedFeatureWeights[i].getIndex(); + float value = extractedFeatureWeights[i].getDefaultValue(); + featuresInfo[featId] = new FeatureInfo(featName,value,false); + } + } + + public FeatureInfo[] getFeaturesInfo(){ + return featuresInfo; + } + + // for test use + Feature.FeatureWeight[] getModelFeatureWeights() { + return modelFeatureWeights; + } + + // for test use + float[] getModelFeatureValuesNormalized() { + return modelFeatureValuesNormalized; + } + + // for test use + Feature.FeatureWeight[] getExtractedFeatureWeights() { + return extractedFeatureWeights; + } + + /** + * Goes through all the stored feature values, and calculates the normalized + * values for all the features that will be 
used for scoring. + */ + private void makeNormalizedFeatures() { + int pos = 0; + for (final Feature.FeatureWeight feature : modelFeatureWeights) { + final int featureId = feature.getIndex(); + FeatureInfo fInfo = featuresInfo[featureId]; + if (fInfo.isUsed()) { // not checking for finfo == null as that would be a bug we should catch + modelFeatureValuesNormalized[pos] = fInfo.getValue(); + } else { + modelFeatureValuesNormalized[pos] = feature.getDefaultValue(); + } + pos++; + } + ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) + throws IOException { + + final Explanation[] explanations = new Explanation[this.featuresInfo.length]; + for (final Feature.FeatureWeight feature : extractedFeatureWeights) { + explanations[feature.getIndex()] = feature.explain(context, doc); + } + final List featureExplanations = new ArrayList<>(); + for (int idx = 0 ;idx < modelFeatureWeights.length; ++idx) { + final Feature.FeatureWeight f = modelFeatureWeights[idx]; + Explanation e = ltrScoringModel.getNormalizerExplanation(explanations[f.getIndex()], idx); + featureExplanations.add(e); + } + final ModelScorer bs = scorer(context); + bs.iterator().advance(doc); + + final float finalScore = bs.score(); + + return ltrScoringModel.explain(context, doc, finalScore, featureExplanations); + + } + + @Override + public void extractTerms(Set terms) { + for (final Feature.FeatureWeight feature : extractedFeatureWeights) { + feature.extractTerms(terms); + } + } + + protected void reset() { + for (int i = 0; i < extractedFeatureWeights.length;++i){ + int featId = extractedFeatureWeights[i].getIndex(); + float value = extractedFeatureWeights[i].getDefaultValue(); + featuresInfo[featId].setValue(value); // need to set default value everytime as the default value is used in 'dense' mode even if used=false + featuresInfo[featId].setUsed(false); + } + } + + @Override + public ModelScorer 
scorer(LeafReaderContext context) throws IOException { + + final List featureScorers = new ArrayList( + extractedFeatureWeights.length); + for (final Feature.FeatureWeight featureWeight : extractedFeatureWeights) { + final Feature.FeatureWeight.FeatureScorer scorer = featureWeight.scorer(context); + if (scorer != null) { + featureScorers.add(scorer); + } + } + // Always return a ModelScorer, even if no features match, because we + // always need to call + // score on the model for every document, since 0 features matching could + // return a + // non 0 score for a given model. + ModelScorer mscorer = new ModelScorer(this, featureScorers); + return mscorer; + + } + + public class ModelScorer extends Scorer { + final private DocInfo docInfo; + final private Scorer featureTraversalScorer; + + public DocInfo getDocInfo() { + return docInfo; + } + + public ModelScorer(Weight weight, List featureScorers) { + super(weight); + docInfo = new DocInfo(); + for (final Feature.FeatureWeight.FeatureScorer subSocer : featureScorers) { + subSocer.setDocInfo(docInfo); + } + if (featureScorers.size() <= 1) { // TODO: Allow the use of dense + // features in other cases + featureTraversalScorer = new DenseModelScorer(weight, featureScorers); + } else { + featureTraversalScorer = new SparseModelScorer(weight, featureScorers); + } + } + + @Override + public Collection getChildren() { + return featureTraversalScorer.getChildren(); + } + + @Override + public int docID() { + return featureTraversalScorer.docID(); + } + + @Override + public float score() throws IOException { + return featureTraversalScorer.score(); + } + + @Override + public int freq() throws IOException { + return featureTraversalScorer.freq(); + } + + @Override + public DocIdSetIterator iterator() { + return featureTraversalScorer.iterator(); + } + + private class SparseModelScorer extends Scorer { + final private DisiPriorityQueue subScorers; + final private ScoringQuerySparseIterator itr; + + private int targetDoc = -1; 
+ private int activeDoc = -1; + + private SparseModelScorer(Weight weight, + List featureScorers) { + super(weight); + if (featureScorers.size() <= 1) { + throw new IllegalArgumentException( + "There must be at least 2 subScorers"); + } + subScorers = new DisiPriorityQueue(featureScorers.size()); + for (final Scorer scorer : featureScorers) { + final DisiWrapper w = new DisiWrapper(scorer); + subScorers.add(w); + } + + itr = new ScoringQuerySparseIterator(subScorers); + } + + @Override + public int docID() { + return itr.docID(); + } + + @Override + public float score() throws IOException { + final DisiWrapper topList = subScorers.topList(); + // If target doc we wanted to advance to matches the actual doc + // the underlying features advanced to, perform the feature + // calculations, + // otherwise just continue with the model's scoring process with empty + // features. + reset(); + if (activeDoc == targetDoc) { + for (DisiWrapper w = topList; w != null; w = w.next) { + final Scorer subScorer = w.scorer; + Feature.FeatureWeight scFW = (Feature.FeatureWeight) subScorer.getWeight(); + final int featureId = scFW.getIndex(); + featuresInfo[featureId].setValue(subScorer.score()); + featuresInfo[featureId].setUsed(true); + } + } + makeNormalizedFeatures(); + return ltrScoringModel.score(modelFeatureValuesNormalized); + } + + @Override + public int freq() throws IOException { + final DisiWrapper subMatches = subScorers.topList(); + int freq = 1; + for (DisiWrapper w = subMatches.next; w != null; w = w.next) { + freq += 1; + } + return freq; + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + + @Override + public final Collection getChildren() { + final ArrayList children = new ArrayList<>(); + for (final DisiWrapper scorer : subScorers) { + children.add(new ChildScorer(scorer.scorer, "SHOULD")); + } + return children; + } + + private class ScoringQuerySparseIterator extends DisjunctionDISIApproximation { + + public 
ScoringQuerySparseIterator(DisiPriorityQueue subIterators) { + super(subIterators); + } + + @Override + public final int nextDoc() throws IOException { + if (activeDoc == targetDoc) { + activeDoc = super.nextDoc(); + } else if (activeDoc < targetDoc) { + activeDoc = super.advance(targetDoc + 1); + } + return ++targetDoc; + } + + @Override + public final int advance(int target) throws IOException { + // If target doc we wanted to advance to matches the actual doc + // the underlying features advanced to, perform the feature + // calculations, + // otherwise just continue with the model's scoring process with + // empty features. + if (activeDoc < target) { + activeDoc = super.advance(target); + } + targetDoc = target; + return targetDoc; + } + } + + } + + private class DenseModelScorer extends Scorer { + private int activeDoc = -1; // The doc that our scorer's are actually at + private int targetDoc = -1; // The doc we were most recently told to go to + private int freq = -1; + final private List featureScorers; + + private DenseModelScorer(Weight weight, + List featureScorers) { + super(weight); + this.featureScorers = featureScorers; + } + + @Override + public int docID() { + return targetDoc; + } + + @Override + public float score() throws IOException { + reset(); + freq = 0; + if (targetDoc == activeDoc) { + for (final Scorer scorer : featureScorers) { + if (scorer.docID() == activeDoc) { + freq++; + Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight(); + final int featureId = scFW.getIndex(); + featuresInfo[featureId].setValue(scorer.score()); + featuresInfo[featureId].setUsed(true); + } + } + } + makeNormalizedFeatures(); + return ltrScoringModel.score(modelFeatureValuesNormalized); + } + + @Override + public final Collection getChildren() { + final ArrayList children = new ArrayList<>(); + for (final Scorer scorer : featureScorers) { + children.add(new ChildScorer(scorer, "SHOULD")); + } + return children; + } + + @Override + public int 
freq() throws IOException { + return freq; + } + + @Override + public DocIdSetIterator iterator() { + return new DenseIterator(); + } + + private class DenseIterator extends DocIdSetIterator { + + @Override + public int docID() { + return targetDoc; + } + + @Override + public int nextDoc() throws IOException { + if (activeDoc <= targetDoc) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc()); + } + } + } + return ++targetDoc; + } + + @Override + public int advance(int target) throws IOException { + if (activeDoc < target) { + activeDoc = NO_MORE_DOCS; + for (final Scorer scorer : featureScorers) { + if (scorer.docID() != NO_MORE_DOCS) { + activeDoc = Math.min(activeDoc, + scorer.iterator().advance(target)); + } + } + } + targetDoc = target; + return target; + } + + @Override + public long cost() { + long sum = 0; + for (int i = 0; i < featureScorers.size(); i++) { + sum += featureScorers.get(i).iterator().cost(); + } + return sum; + } + + } + } + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java new file mode 100644 index 00000000000..8e2563f1e08 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Semaphore; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.TimeUnit; + +import org.apache.solr.common.util.ExecutorUtil; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.util.DefaultSolrThreadFactory; +import org.apache.solr.util.SolrPluginUtils; +import org.apache.solr.util.plugin.NamedListInitializedPlugin; + +final public class LTRThreadModule implements NamedListInitializedPlugin { + + public static LTRThreadModule getInstance(NamedList args) { + + final LTRThreadModule threadManager; + final NamedList threadManagerArgs = extractThreadModuleParams(args); + // if and only if there are thread module args then we want a thread module! 
+ if (threadManagerArgs.size() > 0) { + // create and initialize the new instance + threadManager = new LTRThreadModule(); + threadManager.init(threadManagerArgs); + } else { + threadManager = null; + } + + return threadManager; + } + + private static String CONFIG_PREFIX = "threadModule."; + + private static NamedList extractThreadModuleParams(NamedList args) { + + // gather the thread module args from amongst the general args + final NamedList extractedArgs = new NamedList(); + for (Iterator> it = args.iterator(); + it.hasNext(); ) { + final Map.Entry entry = it.next(); + final String key = entry.getKey(); + if (key.startsWith(CONFIG_PREFIX)) { + extractedArgs.add(key.substring(CONFIG_PREFIX.length()), entry.getValue()); + } + } + + // remove consumed keys only once iteration is complete + // since NamedList iterator does not support 'remove' + for (Object key : extractedArgs.asShallowMap().keySet()) { + args.remove(CONFIG_PREFIX+key); + } + + return extractedArgs; + } + + // settings + private int totalPoolThreads = 1; + private int numThreadsPerRequest = 1; + private int maxPoolSize = Integer.MAX_VALUE; + private long keepAliveTimeSeconds = 10; + private String threadNamePrefix = "ltrExecutor"; + + // implementation + private Semaphore ltrSemaphore; + private Executor createWeightScoreExecutor; + + public LTRThreadModule() { + } + + // For test use only. 
+ LTRThreadModule(int totalPoolThreads, int numThreadsPerRequest) { + this.totalPoolThreads = totalPoolThreads; + this.numThreadsPerRequest = numThreadsPerRequest; + init(null); + } + + @Override + public void init(NamedList args) { + if (args != null) { + SolrPluginUtils.invokeSetters(this, args); + } + validate(); + if (this.totalPoolThreads > 1 ){ + ltrSemaphore = new Semaphore(totalPoolThreads); + } else { + ltrSemaphore = null; + } + createWeightScoreExecutor = new ExecutorUtil.MDCAwareThreadPoolExecutor( + 0, + maxPoolSize, + keepAliveTimeSeconds, TimeUnit.SECONDS, // terminate idle threads after 10 sec + new SynchronousQueue(), // directly hand off tasks + new DefaultSolrThreadFactory(threadNamePrefix) + ); + } + + private void validate() { + if (totalPoolThreads <= 0){ + throw new IllegalArgumentException("totalPoolThreads cannot be less than 1"); + } + if (numThreadsPerRequest <= 0){ + throw new IllegalArgumentException("numThreadsPerRequest cannot be less than 1"); + } + if (totalPoolThreads < numThreadsPerRequest){ + throw new IllegalArgumentException("numThreadsPerRequest cannot be greater than totalPoolThreads"); + } + } + + public void setTotalPoolThreads(int totalPoolThreads) { + this.totalPoolThreads = totalPoolThreads; + } + + public void setNumThreadsPerRequest(int numThreadsPerRequest) { + this.numThreadsPerRequest = numThreadsPerRequest; + } + + public void setMaxPoolSize(int maxPoolSize) { + this.maxPoolSize = maxPoolSize; + } + + public void setKeepAliveTimeSeconds(long keepAliveTimeSeconds) { + this.keepAliveTimeSeconds = keepAliveTimeSeconds; + } + + public void setThreadNamePrefix(String threadNamePrefix) { + this.threadNamePrefix = threadNamePrefix; + } + + public Semaphore createQuerySemaphore() { + return (numThreadsPerRequest > 1 ? 
new Semaphore(numThreadsPerRequest) : null); + } + + public void acquireLTRSemaphore() throws InterruptedException { + ltrSemaphore.acquire(); + } + + public void releaseLTRSemaphore() throws InterruptedException { + ltrSemaphore.release(); + } + + public void execute(Runnable command) { + createWeightScoreExecutor.execute(command); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java new file mode 100644 index 00000000000..66426eaa05f --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr; + +import org.apache.solr.request.SolrQueryRequest; + +public class SolrQueryRequestContextUtils { + + /** key prefix to reduce possibility of clash with other code's key choices **/ + private static final String LTR_PREFIX = "ltr."; + + /** key of the feature logger in the request context **/ + private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger"; + + /** key of the scoring query in the request context **/ + private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query"; + + /** key of the isExtractingFeatures flag in the request context **/ + private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures"; + + /** key of the feature vector store name in the request context **/ + private static final String STORE = LTR_PREFIX + "store"; + + /** feature logger accessors **/ + + public static void setFeatureLogger(SolrQueryRequest req, FeatureLogger featureLogger) { + req.getContext().put(FEATURE_LOGGER, featureLogger); + } + + public static FeatureLogger getFeatureLogger(SolrQueryRequest req) { + return (FeatureLogger) req.getContext().get(FEATURE_LOGGER); + } + + /** scoring query accessors **/ + + public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) { + req.getContext().put(SCORING_QUERY, scoringQuery); + } + + public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) { + return (LTRScoringQuery) req.getContext().get(SCORING_QUERY); + } + + /** isExtractingFeatures flag accessors **/ + + public static void setIsExtractingFeatures(SolrQueryRequest req) { + req.getContext().put(IS_EXTRACTING_FEATURES, Boolean.TRUE); + } + + public static void clearIsExtractingFeatures(SolrQueryRequest req) { + req.getContext().put(IS_EXTRACTING_FEATURES, Boolean.FALSE); + } + + public static boolean isExtractingFeatures(SolrQueryRequest req) { + return Boolean.TRUE.equals(req.getContext().get(IS_EXTRACTING_FEATURES)); + } + + /** feature vector 
store name accessors **/ + + public static void setFvStoreName(SolrQueryRequest req, String fvStoreName) { + req.getContext().put(STORE, fvStoreName); + } + + public static String getFvStoreName(SolrQueryRequest req) { + return (String) req.getContext().get(STORE); + } + +} + diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java new file mode 100644 index 00000000000..228b964e6b9 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.DocInfo; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.request.macro.MacroExpander; +import org.apache.solr.util.SolrPluginUtils; + +/** + * A recipe for computing a feature. Subclass this for specialized feature calculations. + *

+ * A feature consists of + *

    + *
  • a name as the identifier + *
  • parameters to represent the specific feature + *
+ *

+ * Example configuration (snippet): + *

{
+   "class" : "...",
+   "name" : "myFeature",
+   "params" : {
+       ...
+   }
+}
+ *

+ * {@link Feature} is an abstract class and concrete classes should implement + * the {@link #validate()} function, and must implement the {@link #paramsToMap()} + * and createWeight() methods. + */ +public abstract class Feature extends Query { + + final protected String name; + private int index = -1; + private float defaultValue = 0.0f; + + final private Map params; + + public static Feature getInstance(SolrResourceLoader solrResourceLoader, + String className, String name, Map params) { + final Feature f = solrResourceLoader.newInstance( + className, + Feature.class, + new String[0], // no sub packages + new Class[] { String.class, Map.class }, + new Object[] { name, params }); + if (params != null) { + SolrPluginUtils.invokeSetters(f, params.entrySet()); + } + f.validate(); + return f; + } + + public Feature(String name, Map params) { + this.name = name; + this.params = params; + } + + /** + * As part of creation of a feature instance, this function confirms + * that the feature parameters are valid. 
+ * + * @throws FeatureException + * Feature Exception + */ + protected abstract void validate() throws FeatureException; + + @Override + public String toString(String field) { + final StringBuilder sb = new StringBuilder(64); // default initialCapacity of 16 won't be enough + sb.append(getClass().getSimpleName()); + sb.append(" [name=").append(name); + final LinkedHashMap params = paramsToMap(); + if (params != null) { + sb.append(", params=").append(params); + } + sb.append(']'); + return sb.toString(); + } + + public abstract FeatureWeight createWeight(IndexSearcher searcher, + boolean needsScores, SolrQueryRequest request, Query originalQuery, Map efi) throws IOException; + + public float getDefaultValue() { + return defaultValue; + } + + public void setDefaultValue(String value){ + defaultValue = Float.parseFloat(value); + } + + + @Override + public int hashCode() { + final int prime = 31; + int result = classHash(); + result = (prime * result) + index; + result = (prime * result) + ((name == null) ? 0 : name.hashCode()); + result = (prime * result) + ((params == null) ? 0 : params.hashCode()); + return result; + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) && equalsTo(getClass().cast(o)); + } + + private boolean equalsTo(Feature other) { + if (index != other.index) { + return false; + } + if (name == null) { + if (other.name != null) { + return false; + } + } else if (!name.equals(other.name)) { + return false; + } + if (params == null) { + if (other.params != null) { + return false; + } + } else if (!params.equals(other.params)) { + return false; + } + return true; + } + + /** + * @return the name + */ + public String getName() { + return name; + } + + /** + * @return the id + */ + public int getIndex() { + return index; + } + + /** + * @param index + * Unique ID for this feature. Similar to feature name, except it can + * be used to directly access the feature in the global list of + * features. 
+ */ + public void setIndex(int index) { + this.index = index; + } + + public abstract LinkedHashMap paramsToMap(); + /** + * Weight for a feature + **/ + public abstract class FeatureWeight extends Weight { + + final protected IndexSearcher searcher; + final protected SolrQueryRequest request; + final protected Map efi; + final protected MacroExpander macroExpander; + final protected Query originalQuery; + + /** + * Initialize a feature without the normalizer from the feature file. This is + * called on initial construction since multiple models share the same + * features, but have different normalizers. A concrete model's feature is + * copied through featForNewModel(). + * + * @param q + * Solr query associated with this FeatureWeight + * @param searcher + * Solr searcher available for features if they need them + */ + public FeatureWeight(Query q, IndexSearcher searcher, + SolrQueryRequest request, Query originalQuery, Map efi) { + super(q); + this.searcher = searcher; + this.request = request; + this.originalQuery = originalQuery; + this.efi = efi; + macroExpander = new MacroExpander(efi,true); + } + + public String getName() { + return Feature.this.getName(); + } + + public int getIndex() { + return Feature.this.getIndex(); + } + + public float getDefaultValue() { + return Feature.this.getDefaultValue(); + } + + @Override + public abstract FeatureScorer scorer(LeafReaderContext context) + throws IOException; + + @Override + public Explanation explain(LeafReaderContext context, int doc) + throws IOException { + final FeatureScorer r = scorer(context); + float score = getDefaultValue(); + if (r != null) { + r.iterator().advance(doc); + if (r.docID() == doc) { + score = r.score(); + } + return Explanation.match(score, toString()); + }else{ + return Explanation.match(score, "The feature has no value"); + } + } + + /** + * Used in the FeatureWeight's explain. Each feature should implement this + * returning properties of the specific scorer useful for an explain. 
For + * example "MyCustomClassFeature [name=" + name + "myVariable:" + myVariable + + * "]"; If not provided, a default implementation will return basic feature + * properties, which might not include query time specific values. + */ + @Override + public String toString() { + return Feature.this.toString(); + } + + @Override + public void extractTerms(Set terms) { + // needs to be implemented by query subclasses + throw new UnsupportedOperationException(); + } + + /** + * A 'recipe' for computing a feature + */ + public abstract class FeatureScorer extends Scorer { + + final protected String name; + private DocInfo docInfo; + final protected DocIdSetIterator itr; + + public FeatureScorer(Feature.FeatureWeight weight, + DocIdSetIterator itr) { + super(weight); + this.itr = itr; + name = weight.getName(); + docInfo = null; + } + + @Override + public abstract float score() throws IOException; + + /** + * Used to provide context from initial score steps to later reranking steps. + */ + public void setDocInfo(DocInfo docInfo) { + this.docInfo = docInfo; + } + + public DocInfo getDocInfo() { + return docInfo; + } + + @Override + public int freq() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int docID() { + return itr.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return itr; + } + } + + /** + * Default FeatureScorer class that returns the score passed in. Can be used + * as a simple ValueFeature, or to return a default scorer in case an + * underlying feature's scorer is null. 
+ */ + public class ValueFeatureScorer extends FeatureScorer { + float constScore; + + public ValueFeatureScorer(FeatureWeight weight, float constScore, + DocIdSetIterator itr) { + super(weight,itr); + this.constScore = constScore; + } + + @Override + public float score() { + return constScore; + } + + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java new file mode 100644 index 00000000000..6c8f82762f0 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Unchecked exception raised for feature-related failures, carrying a message
 * and optionally the underlying cause.
 */
public class FeatureException extends RuntimeException {

  private static final long serialVersionUID = 1L;

  /**
   * @param message
   *          description of the failure
   */
  public FeatureException(String message) {
    super(message);
  }

  /**
   * @param message
   *          description of the failure
   * @param cause
   *          underlying failure; any {@link Throwable} is accepted so that
   *          errors as well as exceptions can be preserved as the cause
   */
  public FeatureException(String message, Throwable cause) {
    super(message, cause);
  }

}
+ */ +package org.apache.solr.ltr.feature; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.util.SmallFloat; +import org.apache.solr.request.SolrQueryRequest; +/** + * This feature returns the length of a field (in terms) for the current document. + * Example configuration: + *

{
+  "name":  "titleLength",
+  "class": "org.apache.solr.ltr.feature.FieldLengthFeature",
+  "params": {
+      "field": "title"
+  }
+}
+ * Note: since this feature relies on norms values that are stored in a single byte + * the value of the feature could have a lightly different value. + * (see also {@link org.apache.lucene.search.similarities.ClassicSimilarity}) + **/ +public class FieldLengthFeature extends Feature { + + private String field; + + public String getField() { + return field; + } + + public void setField(String field) { + this.field = field; + } + + @Override + public LinkedHashMap paramsToMap() { + final LinkedHashMap params = new LinkedHashMap<>(1, 1.0f); + params.put("field", field); + return params; + } + + @Override + protected void validate() throws FeatureException { + if (field == null || field.isEmpty()) { + throw new FeatureException(getClass().getSimpleName()+ + ": field must be provided"); + } + } + + /** Cache of decoded bytes. */ + + private static final float[] NORM_TABLE = new float[256]; + + static { + NORM_TABLE[0] = 0; + for (int i = 1; i < 256; i++) { + float norm = SmallFloat.byte315ToFloat((byte) i); + NORM_TABLE[i] = 1.0f / (norm * norm); + } + } + + /** + * Decodes the norm value, assuming it is a single byte. 
+ * + */ + + private final float decodeNorm(long norm) { + return NORM_TABLE[(int) (norm & 0xFF)]; // & 0xFF maps negative bytes to + // positive above 127 + } + + public FieldLengthFeature(String name, Map params) { + super(name, params); + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, + SolrQueryRequest request, Query originalQuery, Map efi) + throws IOException { + + return new FieldLengthFeatureWeight(searcher, request, originalQuery, efi); + } + + + public class FieldLengthFeatureWeight extends FeatureWeight { + + public FieldLengthFeatureWeight(IndexSearcher searcher, + SolrQueryRequest request, Query originalQuery, Map efi) { + super(FieldLengthFeature.this, searcher, request, originalQuery, efi); + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + NumericDocValues norms = context.reader().getNormValues(field); + if (norms == null){ + return new ValueFeatureScorer(this, 0f, + DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS)); + } + return new FieldLengthFeatureScorer(this, norms); + } + + public class FieldLengthFeatureScorer extends FeatureScorer { + + NumericDocValues norms = null; + + public FieldLengthFeatureScorer(FeatureWeight weight, + NumericDocValues norms) throws IOException { + super(weight, norms); + this.norms = norms; + + // In the constructor, docId is -1, so using 0 as default lookup + final IndexableField idxF = searcher.doc(0).getField(field); + if (idxF.fieldType().omitNorms()) { + throw new IOException( + "FieldLengthFeatures can't be used if omitNorms is enabled (field=" + + field + ")"); + } + } + + @Override + public float score() throws IOException { + + final long l = norms.longValue(); + final float numTerms = decodeNorm(l); + return numTerms; + } + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java new file 
mode 100644 index 00000000000..279adbc3ca3 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.document.Document; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.solr.request.SolrQueryRequest; + +/** + * This feature returns the value of a field in the current document + * Example configuration: + *
{
+  "name":  "rawHits",
+  "class": "org.apache.solr.ltr.feature.FieldValueFeature",
+  "params": {
+      "field": "hits"
+  }
+}
+ */ +public class FieldValueFeature extends Feature { + + private String field; + private Set fieldAsSet; + + public String getField() { + return field; + } + + public void setField(String field) { + this.field = field; + fieldAsSet = Collections.singleton(field); + } + + @Override + public LinkedHashMap paramsToMap() { + final LinkedHashMap params = new LinkedHashMap<>(1, 1.0f); + params.put("field", field); + return params; + } + + @Override + protected void validate() throws FeatureException { + if (field == null || field.isEmpty()) { + throw new FeatureException(getClass().getSimpleName()+ + ": field must be provided"); + } + } + + public FieldValueFeature(String name, Map params) { + super(name, params); + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, + SolrQueryRequest request, Query originalQuery, Map efi) + throws IOException { + return new FieldValueFeatureWeight(searcher, request, originalQuery, efi); + } + + public class FieldValueFeatureWeight extends FeatureWeight { + + public FieldValueFeatureWeight(IndexSearcher searcher, + SolrQueryRequest request, Query originalQuery, Map efi) { + super(FieldValueFeature.this, searcher, request, originalQuery, efi); + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + return new FieldValueFeatureScorer(this, context, + DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS)); + } + + public class FieldValueFeatureScorer extends FeatureScorer { + + LeafReaderContext context = null; + + public FieldValueFeatureScorer(FeatureWeight weight, + LeafReaderContext context, DocIdSetIterator itr) { + super(weight, itr); + this.context = context; + } + + @Override + public float score() throws IOException { + + try { + final Document document = context.reader().document(itr.docID(), + fieldAsSet); + final IndexableField indexableField = document.getField(field); + if (indexableField == null) { + return getDefaultValue(); + } + final 
Number number = indexableField.numericValue(); + if (number != null) { + return number.floatValue(); + } else { + final String string = indexableField.stringValue(); + // boolean values in the index are encoded with the + // chars T/F + if (string.equals("T")) { + return 1; + } + if (string.equals("F")) { + return 0; + } + } + } catch (final IOException e) { + throw new FeatureException( + e.toString() + ": " + + "Unable to extract feature for " + + name, e); + } + return getDefaultValue(); + } + } + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java new file mode 100644 index 00000000000..125615cbb4f --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.solr.ltr.DocInfo; +import org.apache.solr.request.SolrQueryRequest; +/** + * This feature returns the original score that the document had before performing + * the reranking. + * Example configuration: + *
{
+  "name":  "originalScore",
+  "class": "org.apache.solr.ltr.feature.OriginalScoreFeature",
+  "params": { }
+}
+ **/ +public class OriginalScoreFeature extends Feature { + + public OriginalScoreFeature(String name, Map params) { + super(name, params); + } + + @Override + public LinkedHashMap paramsToMap() { + return null; + } + + @Override + protected void validate() throws FeatureException { + } + + @Override + public OriginalScoreWeight createWeight(IndexSearcher searcher, + boolean needsScores, SolrQueryRequest request, Query originalQuery, Map efi) throws IOException { + return new OriginalScoreWeight(searcher, request, originalQuery, efi); + + } + + public class OriginalScoreWeight extends FeatureWeight { + + final Weight w; + + public OriginalScoreWeight(IndexSearcher searcher, + SolrQueryRequest request, Query originalQuery, Map efi) throws IOException { + super(OriginalScoreFeature.this, searcher, request, originalQuery, efi); + w = searcher.createNormalizedWeight(originalQuery, true); + }; + + + @Override + public String toString() { + return "OriginalScoreFeature [query:" + originalQuery.toString() + "]"; + } + + + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + + final Scorer originalScorer = w.scorer(context); + return new OriginalScoreScorer(this, originalScorer); + } + + public class OriginalScoreScorer extends FeatureScorer { + final private Scorer originalScorer; + + public OriginalScoreScorer(FeatureWeight weight, Scorer originalScorer) { + super(weight,null); + this.originalScorer = originalScorer; + } + + @Override + public float score() throws IOException { + // This is done to improve the speed of feature extraction. Since this + // was already scored in step 1 + // we shouldn't need to calc original score again. + final DocInfo docInfo = getDocInfo(); + return (docInfo.hasOriginalDocScore() ? 
docInfo.getOriginalDocScore() : originalScorer.score()); + } + + @Override + public int docID() { + return originalScorer.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return originalScorer.iterator(); + } + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/SolrFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/SolrFeature.java new file mode 100644 index 00000000000..cb7c1a0c81a --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/SolrFeature.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSet; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.solr.common.params.CommonParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.request.LocalSolrQueryRequest; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.search.QParser; +import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.search.SyntaxError; +/** + * This feature allows you to reuse any Solr query as a feature. The value + * of the feature will be the score of the given query for the current document. + * See Solr documentation of other parsers you can use as a feature. + * Example configurations: + *
[{ "name": "isBook",
+  "class": "org.apache.solr.ltr.feature.SolrFeature",
+  "params":{ "fq": ["{!terms f=category}book"] }
+},
+{
+  "name":  "documentRecency",
+  "class": "org.apache.solr.ltr.feature.SolrFeature",
+  "params": {
+      "q": "{!func}recip( ms(NOW,publish_date), 3.16e-11, 1, 1)"
+  }
+}]
+ **/ +public class SolrFeature extends Feature { + + private String df; + private String q; + private List fq; + + public String getDf() { + return df; + } + + public void setDf(String df) { + this.df = df; + } + + public String getQ() { + return q; + } + + public void setQ(String q) { + this.q = q; + } + + public List getFq() { + return fq; + } + + public void setFq(List fq) { + this.fq = fq; + } + + public SolrFeature(String name, Map params) { + super(name, params); + } + + @Override + public LinkedHashMap paramsToMap() { + final LinkedHashMap params = new LinkedHashMap<>(3, 1.0f); + if (df != null) { + params.put("df", df); + } + if (q != null) { + params.put("q", q); + } + if (fq != null) { + params.put("fq", fq); + } + return params; + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, + SolrQueryRequest request, Query originalQuery, Map efi) + throws IOException { + return new SolrFeatureWeight(searcher, request, originalQuery, efi); + } + + @Override + protected void validate() throws FeatureException { + if ((q == null || q.isEmpty()) && + ((fq == null) || fq.isEmpty())) { + throw new FeatureException(getClass().getSimpleName()+ + ": Q or FQ must be provided"); + } + } + /** + * Weight for a SolrFeature + **/ + public class SolrFeatureWeight extends FeatureWeight { + Weight solrQueryWeight; + Query query; + List queryAndFilters; + + public SolrFeatureWeight(IndexSearcher searcher, + SolrQueryRequest request, Query originalQuery, Map efi) throws IOException { + super(SolrFeature.this, searcher, request, originalQuery, efi); + try { + String solrQuery = q; + final List fqs = fq; + + if ((solrQuery == null) || solrQuery.isEmpty()) { + solrQuery = "*:*"; + } + + solrQuery = macroExpander.expand(solrQuery); + if (solrQuery == null) { + throw new FeatureException(this.getClass().getSimpleName()+" requires efi parameter that was not passed in request."); + } + + final SolrQueryRequest req = 
makeRequest(request.getCore(), solrQuery, + fqs, df); + if (req == null) { + throw new IOException("ERROR: No parameters provided"); + } + + // Build the filter queries + queryAndFilters = new ArrayList(); // If there are no fqs we just want an empty list + if (fqs != null) { + for (String fq : fqs) { + if ((fq != null) && (fq.trim().length() != 0)) { + fq = macroExpander.expand(fq); + final QParser fqp = QParser.getParser(fq, req); + final Query filterQuery = fqp.getQuery(); + if (filterQuery != null) { + queryAndFilters.add(filterQuery); + } + } + } + } + + final QParser parser = QParser.getParser(solrQuery, req); + query = parser.parse(); + + // Query can be null if there was no input to parse, for instance if you + // make a phrase query with "to be", and the analyzer removes all the + // words + // leaving nothing for the phrase query to parse. + if (query != null) { + queryAndFilters.add(query); + solrQueryWeight = searcher.createNormalizedWeight(query, true); + } + } catch (final SyntaxError e) { + throw new FeatureException("Failed to parse feature query.", e); + } + } + + private LocalSolrQueryRequest makeRequest(SolrCore core, String solrQuery, + List fqs, String df) { + final NamedList returnList = new NamedList(); + if ((solrQuery != null) && !solrQuery.isEmpty()) { + returnList.add(CommonParams.Q, solrQuery); + } + if (fqs != null) { + for (final String fq : fqs) { + returnList.add(CommonParams.FQ, fq); + } + } + if ((df != null) && !df.isEmpty()) { + returnList.add(CommonParams.DF, df); + } + if (returnList.size() > 0) { + return new LocalSolrQueryRequest(core, returnList); + } else { + return null; + } + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + Scorer solrScorer = null; + if (solrQueryWeight != null) { + solrScorer = solrQueryWeight.scorer(context); + } + + final DocIdSetIterator idItr = getDocIdSetIteratorFromQueries( + queryAndFilters, context); + if (idItr != null) { + return solrScorer == 
null ? new ValueFeatureScorer(this, 1f, idItr) + : new SolrFeatureScorer(this, solrScorer, + new SolrFeatureScorerIterator(idItr, solrScorer.iterator())); + } else { + return null; + } + } + + /** + * Given a list of Solr filters/queries, return a doc iterator that + * traverses over the documents that matched all the criteria of the + * queries. + * + * @param queries + * Filtering criteria to match documents against + * @param context + * Index reader + * @return DocIdSetIterator to traverse documents that matched all filter + * criteria + */ + private DocIdSetIterator getDocIdSetIteratorFromQueries(List queries, + LeafReaderContext context) throws IOException { + final SolrIndexSearcher.ProcessedFilter pf = ((SolrIndexSearcher) searcher) + .getProcessedFilter(null, queries); + final Bits liveDocs = context.reader().getLiveDocs(); + + DocIdSetIterator idIter = null; + if (pf.filter != null) { + final DocIdSet idSet = pf.filter.getDocIdSet(context, liveDocs); + if (idSet != null) { + idIter = idSet.iterator(); + } + } + + return idIter; + } + + /** + * Scorer for a SolrFeature + **/ + public class SolrFeatureScorer extends FeatureScorer { + final private Scorer solrScorer; + + public SolrFeatureScorer(FeatureWeight weight, Scorer solrScorer, + SolrFeatureScorerIterator itr) { + super(weight, itr); + this.solrScorer = solrScorer; + } + + @Override + public float score() throws IOException { + try { + return solrScorer.score(); + } catch (UnsupportedOperationException e) { + throw new FeatureException( + e.toString() + ": " + + "Unable to extract feature for " + + name, e); + } + } + } + + /** + * An iterator that allows to iterate only on the documents for which a feature has + * a value. 
+ **/ + public class SolrFeatureScorerIterator extends DocIdSetIterator { + + final private DocIdSetIterator filterIterator; + final private DocIdSetIterator scorerFilter; + + SolrFeatureScorerIterator(DocIdSetIterator filterIterator, + DocIdSetIterator scorerFilter) { + this.filterIterator = filterIterator; + this.scorerFilter = scorerFilter; + } + + @Override + public int docID() { + return filterIterator.docID(); + } + + @Override + public int nextDoc() throws IOException { + int docID = filterIterator.nextDoc(); + scorerFilter.advance(docID); + return docID; + } + + @Override + public int advance(int target) throws IOException { + // We use iterator to catch the scorer up since + // that checks if the target id is in the query + all the filters + int docID = filterIterator.advance(target); + scorerFilter.advance(docID); + return docID; + } + + @Override + public long cost() { + return filterIterator.cost() + scorerFilter.cost(); + } + + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ValueFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ValueFeature.java new file mode 100644 index 00000000000..61aa9e5fb7d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/ValueFeature.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.solr.request.SolrQueryRequest; +/** + * This feature allows to return a constant given value for the current document. + * + * Example configuration: + *
{
+   "name" : "userFromMobile",
+   "class" : "org.apache.solr.ltr.feature.ValueFeature",
+   "params" : { "value" : "${userFromMobile}", "required":true }
+ }
+ * + *You can place a constant value like "1.3f" in the value params, but many times you + *would want to pass in external information to use per request. For instance, maybe + *you want to rank things differently if the search came from a mobile device, or maybe + *you want to use your external query intent system as a feature. + *In the rerank request you can pass in rq={... efi.userFromMobile=1}, and the above + *feature will return 1 for all the docs for that request. If required is set to true, + *the request will return an error since you failed to pass in the efi, otherwise if will + *just skip the feature and use a default value of 0 instead. + **/ +public class ValueFeature extends Feature { + private float configValue = -1f; + private String configValueStr = null; + + private Object value = null; + private Boolean required = null; + + public Object getValue() { + return value; + } + + public void setValue(Object value) { + this.value = value; + if (value instanceof String) { + configValueStr = (String) value; + } else if (value instanceof Double) { + configValue = ((Double) value).floatValue(); + } else if (value instanceof Float) { + configValue = ((Float) value).floatValue(); + } else if (value instanceof Integer) { + configValue = ((Integer) value).floatValue(); + } else if (value instanceof Long) { + configValue = ((Long) value).floatValue(); + } else { + throw new FeatureException("Invalid type for 'value' in params for " + this); + } + } + + public boolean isRequired() { + return Boolean.TRUE.equals(required); + } + + public void setRequired(boolean required) { + this.required = required; + } + + @Override + public LinkedHashMap paramsToMap() { + final LinkedHashMap params = new LinkedHashMap<>(2, 1.0f); + params.put("value", value); + if (required != null) { + params.put("required", required); + } + return params; + } + + @Override + protected void validate() throws FeatureException { + if (configValueStr != null && 
configValueStr.trim().isEmpty()) { + throw new FeatureException("Empty field 'value' in params for " + this); + } + } + + public ValueFeature(String name, Map params) { + super(name, params); + } + + @Override + public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, + SolrQueryRequest request, Query originalQuery, Map efi) + throws IOException { + return new ValueFeatureWeight(searcher, request, originalQuery, efi); + } + + public class ValueFeatureWeight extends FeatureWeight { + + final protected Float featureValue; + + public ValueFeatureWeight(IndexSearcher searcher, + SolrQueryRequest request, Query originalQuery, Map efi) { + super(ValueFeature.this, searcher, request, originalQuery, efi); + if (configValueStr != null) { + final String expandedValue = macroExpander.expand(configValueStr); + if (expandedValue != null) { + featureValue = Float.parseFloat(expandedValue); + } else if (isRequired()) { + throw new FeatureException(this.getClass().getSimpleName() + " requires efi parameter that was not passed in request."); + } else { + featureValue=null; + } + } else { + featureValue = configValue; + } + } + + @Override + public FeatureScorer scorer(LeafReaderContext context) throws IOException { + if(featureValue!=null) { + return new ValueFeatureScorer(this, featureValue, + DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS)); + } else { + return null; + } + } + + + + + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java new file mode 100644 index 00000000000..456fffcffa3 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains Feature related classes + */ +package org.apache.solr.ltr.feature; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LTRScoringModel.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LTRScoringModel.java new file mode 100644 index 00000000000..9edcfe50d0e --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LTRScoringModel.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.model; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.FeatureException; +import org.apache.solr.ltr.norm.IdentityNormalizer; +import org.apache.solr.ltr.norm.Normalizer; +import org.apache.solr.util.SolrPluginUtils; + +/** + * A scoring model computes scores that can be used to rerank documents. + *

+ * A scoring model consists of + *

    + *
  • a list of features ({@link Feature}) and + *
  • a list of normalizers ({@link Normalizer}) plus + *
  • parameters or configuration to represent the scoring algorithm. + *
+ *

+ * Example configuration (snippet): + *

{
+   "class" : "...",
+   "name" : "myModelName",
+   "features" : [
+       {
+         "name" : "isBook"
+       },
+       {
+         "name" : "originalScore",
+         "norm": {
+             "class" : "org.apache.solr.ltr.norm.StandardNormalizer",
+             "params" : { "avg":"100", "std":"10" }
+         }
+       },
+       {
+         "name" : "price",
+         "norm": {
+             "class" : "org.apache.solr.ltr.norm.MinMaxNormalizer",
+             "params" : { "min":"0", "max":"1000" }
+         }
+       }
+   ],
+   "params" : {
+       ...
+   }
+}
+ *

+ * {@link LTRScoringModel} is an abstract class and concrete classes must + * implement the {@link #score(float[])} and + * {@link #explain(LeafReaderContext, int, float, List)} methods. + */ +public abstract class LTRScoringModel { + + protected final String name; + private final String featureStoreName; + protected final List features; + private final List allFeatures; + private final Map params; + private final List norms; + + public static LTRScoringModel getInstance(SolrResourceLoader solrResourceLoader, + String className, String name, List features, + List norms, + String featureStoreName, List allFeatures, + Map params) throws ModelException { + final LTRScoringModel model; + try { + // create an instance of the model + model = solrResourceLoader.newInstance( + className, + LTRScoringModel.class, + new String[0], // no sub packages + new Class[] { String.class, List.class, List.class, String.class, List.class, Map.class }, + new Object[] { name, features, norms, featureStoreName, allFeatures, params }); + if (params != null) { + SolrPluginUtils.invokeSetters(model, params.entrySet()); + } + } catch (final Exception e) { + throw new ModelException("Model type does not exist " + className, e); + } + model.validate(); + return model; + } + + public LTRScoringModel(String name, List features, + List norms, + String featureStoreName, List allFeatures, + Map params) { + this.name = name; + this.features = features; + this.featureStoreName = featureStoreName; + this.allFeatures = allFeatures; + this.params = params; + this.norms = norms; + } + + /** + * Validate that settings make sense and throws + * {@link ModelException} if they do not make sense. 
+ */ + protected void validate() throws ModelException { + if (features.isEmpty()) { + throw new ModelException("no features declared for model "+name); + } + final HashSet featureNames = new HashSet<>(); + for (final Feature feature : features) { + final String featureName = feature.getName(); + if (!featureNames.add(featureName)) { + throw new ModelException("duplicated feature "+featureName+" in model "+name); + } + } + if (features.size() != norms.size()) { + throw new ModelException("counted "+features.size()+" features and "+norms.size()+" norms in model "+name); + } + } + + /** + * @return the norms + */ + public List getNorms() { + return Collections.unmodifiableList(norms); + } + + /** + * @return the name + */ + public String getName() { + return name; + } + + /** + * @return the features + */ + public List getFeatures() { + return Collections.unmodifiableList(features); + } + + public Map getParams() { + return params; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = (prime * result) + ((features == null) ? 0 : features.hashCode()); + result = (prime * result) + ((name == null) ? 0 : name.hashCode()); + result = (prime * result) + ((params == null) ? 0 : params.hashCode()); + result = (prime * result) + ((norms == null) ? 0 : norms.hashCode()); + result = (prime * result) + ((featureStoreName == null) ? 
0 : featureStoreName.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + final LTRScoringModel other = (LTRScoringModel) obj; + if (features == null) { + if (other.features != null) { + return false; + } + } else if (!features.equals(other.features)) { + return false; + } + if (norms == null) { + if (other.norms != null) { + return false; + } + } else if (!norms.equals(other.norms)) { + return false; + } + if (name == null) { + if (other.name != null) { + return false; + } + } else if (!name.equals(other.name)) { + return false; + } + if (params == null) { + if (other.params != null) { + return false; + } + } else if (!params.equals(other.params)) { + return false; + } + if (featureStoreName == null) { + if (other.featureStoreName != null) { + return false; + } + } else if (!featureStoreName.equals(other.featureStoreName)) { + return false; + } + + + return true; + } + + public boolean hasParams() { + return !((params == null) || params.isEmpty()); + } + + public Collection getAllFeatures() { + return allFeatures; + } + + public String getFeatureStoreName() { + return featureStoreName; + } + + /** + * Given a list of normalized values for all features a scoring algorithm + * cares about, calculate and return a score. + * + * @param modelFeatureValuesNormalized + * List of normalized feature values. Each feature is identified by + * its id, which is the index in the array + * @return The final score for a document + */ + public abstract float score(float[] modelFeatureValuesNormalized); + + /** + * Similar to the score() function, except it returns an explanation of how + * the features were used to calculate the score. 
+ * + * @param context + * Context the document is in + * @param doc + * Document to explain + * @param finalScore + * Original score + * @param featureExplanations + * Explanations for each feature calculation + * @return Explanation for the scoring of a document + */ + public abstract Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations); + + @Override + public String toString() { + return getClass().getSimpleName() + "(name="+getName()+")"; + } + + /** + * Goes through all the stored feature values, and calculates the normalized + * values for all the features that will be used for scoring. + */ + public void normalizeFeaturesInPlace(float[] modelFeatureValues) { + float[] modelFeatureValuesNormalized = modelFeatureValues; + if (modelFeatureValues.length != norms.size()) { + throw new FeatureException("Must have normalizer for every feature"); + } + for(int idx = 0; idx < modelFeatureValuesNormalized.length; ++idx) { + modelFeatureValuesNormalized[idx] = + norms.get(idx).normalize(modelFeatureValuesNormalized[idx]); + } + } + + public Explanation getNormalizerExplanation(Explanation e, int idx) { + Normalizer n = norms.get(idx); + if (n != IdentityNormalizer.INSTANCE) { + return n.explain(e); + } + return e; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LinearModel.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LinearModel.java new file mode 100644 index 00000000000..57fc5ad43c4 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/LinearModel.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.model; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.norm.Normalizer; + +/** + * A scoring model that computes scores using a dot product. + * Example models are RankSVM and Pranking. + *

+ * Example configuration: + *

{
+   "class" : "org.apache.solr.ltr.model.LinearModel",
+   "name" : "myModelName",
+   "features" : [
+       { "name" : "userTextTitleMatch" },
+       { "name" : "originalScore" },
+       { "name" : "isBook" }
+   ],
+   "params" : {
+       "weights" : {
+           "userTextTitleMatch" : 1.0,
+           "originalScore" : 0.5,
+           "isBook" : 0.1
+       }
+   }
+}
+ *

+ * Background reading: + *

+ * + */ +public class LinearModel extends LTRScoringModel { + + protected Float[] featureToWeight; + + public void setWeights(Object weights) { + final Map modelWeights = (Map) weights; + for (int ii = 0; ii < features.size(); ++ii) { + final String key = features.get(ii).getName(); + final Double val = modelWeights.get(key); + featureToWeight[ii] = (val == null ? null : new Float(val.floatValue())); + } + } + + public LinearModel(String name, List features, + List norms, + String featureStoreName, List allFeatures, + Map params) { + super(name, features, norms, featureStoreName, allFeatures, params); + featureToWeight = new Float[features.size()]; + } + + @Override + protected void validate() throws ModelException { + super.validate(); + + final ArrayList missingWeightFeatureNames = new ArrayList(); + for (int i = 0; i < features.size(); ++i) { + if (featureToWeight[i] == null) { + missingWeightFeatureNames.add(features.get(i).getName()); + } + } + if (missingWeightFeatureNames.size() == features.size()) { + throw new ModelException("Model " + name + " doesn't contain any weights"); + } + if (!missingWeightFeatureNames.isEmpty()) { + throw new ModelException("Model " + name + " lacks weight(s) for "+missingWeightFeatureNames); + } + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + float score = 0; + for (int i = 0; i < modelFeatureValuesNormalized.length; ++i) { + score += modelFeatureValuesNormalized[i] * featureToWeight[i]; + } + return score; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations) { + final List details = new ArrayList<>(); + int index = 0; + + for (final Explanation featureExplain : featureExplanations) { + final List featureDetails = new ArrayList<>(); + featureDetails.add(Explanation.match(featureToWeight[index], + "weight on feature")); + featureDetails.add(featureExplain); + + details.add(Explanation.match(featureExplain.getValue() + * 
featureToWeight[index], "prod of:", featureDetails)); + index++; + } + + return Explanation.match(finalScore, toString() + + " model applied to features, sum of:", details); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(getClass().getSimpleName()); + sb.append("(name=").append(getName()); + sb.append(",featureWeights=["); + for (int ii = 0; ii < features.size(); ++ii) { + if (ii>0) { + sb.append(','); + } + final String key = features.get(ii).getName(); + sb.append(key).append('=').append(featureToWeight[ii]); + } + sb.append("])"); + return sb.toString(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/ModelException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/ModelException.java new file mode 100644 index 00000000000..de8786d81c0 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/ModelException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Unchecked exception signalling that a model could not be created or that
 * its configuration is invalid.
 */
public class ModelException extends RuntimeException {

  private static final long serialVersionUID = 1L;

  /**
   * @param message description of the problem
   */
  public ModelException(String message) {
    super(message);
  }

  /**
   * @param message description of the problem
   * @param cause the exception that triggered this one
   */
  public ModelException(String message, Exception cause) {
    super(message, cause);
  }

  /**
   * Variant accepting any {@link Throwable}, so that causes which are not
   * {@link Exception}s (for example an {@link Error}) can be chained too.
   *
   * @param message description of the problem
   * @param cause the throwable that triggered this exception
   */
  public ModelException(String message, Throwable cause) {
    super(message, cause);
  }

}
+ * Example models are LambdaMART and Gradient Boosted Regression Trees (GBRT) . + *

+ * Example configuration: +

{
+   "class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
+   "name" : "multipleadditivetreesmodel",
+   "features":[
+       { "name" : "userTextTitleMatch"},
+       { "name" : "originalScore"}
+   ],
+   "params" : {
+       "trees" : [
+           {
+               "weight" : 1,
+               "root": {
+                   "feature" : "userTextTitleMatch",
+                   "threshold" : 0.5,
+                   "left" : {
+                       "value" : -100
+                   },
+                   "right" : {
+                       "feature" : "originalScore",
+                       "threshold" : 10.0,
+                       "left" : {
+                           "value" : 50
+                       },
+                       "right" : {
+                           "value" : 75
+                       }
+                   }
+               }
+           },
+           {
+               "weight" : 2,
+               "root" : {
+                   "value" : -10
+               }
+           }
+       ]
+   }
+}
+ *

+ * Background reading: + *

+ * + */ +public class MultipleAdditiveTreesModel extends LTRScoringModel { + + private final HashMap fname2index; + private List trees; + + private RegressionTree createRegressionTree(Map map) { + final RegressionTree rt = new RegressionTree(); + if (map != null) { + SolrPluginUtils.invokeSetters(rt, map.entrySet()); + } + return rt; + } + + private RegressionTreeNode createRegressionTreeNode(Map map) { + final RegressionTreeNode rtn = new RegressionTreeNode(); + if (map != null) { + SolrPluginUtils.invokeSetters(rtn, map.entrySet()); + } + return rtn; + } + + public class RegressionTreeNode { + private static final float NODE_SPLIT_SLACK = 1E-6f; + + private float value = 0f; + private String feature; + private int featureIndex = -1; + private Float threshold; + private RegressionTreeNode left; + private RegressionTreeNode right; + + public void setValue(float value) { + this.value = value; + } + + public void setValue(String value) { + this.value = Float.parseFloat(value); + } + + public void setFeature(String feature) { + this.feature = feature; + final Integer idx = fname2index.get(this.feature); + // this happens if the tree specifies a feature that does not exist + // this could be due to lambdaSmart building off of pre-existing trees + // that use a feature that is no longer output during feature extraction + featureIndex = (idx == null) ? 
-1 : idx; + } + + public void setThreshold(float threshold) { + this.threshold = threshold + NODE_SPLIT_SLACK; + } + + public void setThreshold(String threshold) { + this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK; + } + + public void setLeft(Object left) { + this.left = createRegressionTreeNode((Map) left); + } + + public void setRight(Object right) { + this.right = createRegressionTreeNode((Map) right); + } + + public boolean isLeaf() { + return feature == null; + } + + public float score(float[] featureVector) { + if (isLeaf()) { + return value; + } + + // unsupported feature (tree is looking for a feature that does not exist) + if ((featureIndex < 0) || (featureIndex >= featureVector.length)) { + return 0f; + } + + if (featureVector[featureIndex] <= threshold) { + return left.score(featureVector); + } else { + return right.score(featureVector); + } + } + + public String explain(float[] featureVector) { + if (isLeaf()) { + return "val: " + value; + } + + // unsupported feature (tree is looking for a feature that does not exist) + if ((featureIndex < 0) || (featureIndex >= featureVector.length)) { + return "'" + feature + "' does not exist in FV, Return Zero"; + } + + // could store extra information about how much training data supported + // each branch and report + // that here + + if (featureVector[featureIndex] <= threshold) { + String rval = "'" + feature + "':" + featureVector[featureIndex] + " <= " + + threshold + ", Go Left | "; + return rval + left.explain(featureVector); + } else { + String rval = "'" + feature + "':" + featureVector[featureIndex] + " > " + + threshold + ", Go Right | "; + return rval + right.explain(featureVector); + } + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + if (isLeaf()) { + sb.append(value); + } else { + sb.append("(feature=").append(feature); + sb.append(",threshold=").append(threshold.floatValue()-NODE_SPLIT_SLACK); + sb.append(",left=").append(left); + 
sb.append(",right=").append(right); + sb.append(')'); + } + return sb.toString(); + } + + public RegressionTreeNode() { + } + + public void validate() throws ModelException { + if (isLeaf()) { + if (left != null || right != null) { + throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left="+left+" and right="+right); + } + return; + } + if (null == threshold) { + throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold"); + } + if (null == left) { + throw new ModelException("MultipleAdditiveTreesModel tree node is missing left"); + } else { + left.validate(); + } + if (null == right) { + throw new ModelException("MultipleAdditiveTreesModel tree node is missing right"); + } else { + right.validate(); + } + } + + } + + public class RegressionTree { + + private Float weight; + private RegressionTreeNode root; + + public void setWeight(float weight) { + this.weight = new Float(weight); + } + + public void setWeight(String weight) { + this.weight = new Float(weight); + } + + public void setRoot(Object root) { + this.root = createRegressionTreeNode((Map)root); + } + + public float score(float[] featureVector) { + return weight.floatValue() * root.score(featureVector); + } + + public String explain(float[] featureVector) { + return root.explain(featureVector); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + sb.append("(weight=").append(weight); + sb.append(",root=").append(root); + sb.append(")"); + return sb.toString(); + } + + public RegressionTree() { + } + + public void validate() throws ModelException { + if (weight == null) { + throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight"); + } + if (root == null) { + throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree"); + } else { + root.validate(); + } + } + } + + public void setTrees(Object trees) { + this.trees = new ArrayList(); + for (final Object o : (List) 
trees) { + final RegressionTree rt = createRegressionTree((Map) o); + this.trees.add(rt); + } + } + + public MultipleAdditiveTreesModel(String name, List features, + List norms, + String featureStoreName, List allFeatures, + Map params) { + super(name, features, norms, featureStoreName, allFeatures, params); + + fname2index = new HashMap(); + for (int i = 0; i < features.size(); ++i) { + final String key = features.get(i).getName(); + fname2index.put(key, i); + } + } + + @Override + protected void validate() throws ModelException { + super.validate(); + if (trees == null) { + throw new ModelException("no trees declared for model "+name); + } + for (RegressionTree tree : trees) { + tree.validate(); + } + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + float score = 0; + for (final RegressionTree t : trees) { + score += t.score(modelFeatureValuesNormalized); + } + return score; + } + + // ///////////////////////////////////////// + // produces a string that looks like: + // 40.0 = multipleadditivetreesmodel [ org.apache.solr.ltr.model.MultipleAdditiveTreesModel ] + // model applied to + // features, sum of: + // 50.0 = tree 0 | 'matchedTitle':1.0 > 0.500001, Go Right | + // 'this_feature_doesnt_exist' does not + // exist in FV, Go Left | val: 50.0 + // -10.0 = tree 1 | val: -10.0 + @Override + public Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations) { + final float[] fv = new float[featureExplanations.size()]; + int index = 0; + for (final Explanation featureExplain : featureExplanations) { + fv[index] = featureExplain.getValue(); + index++; + } + + final List details = new ArrayList<>(); + index = 0; + + for (final RegressionTree t : trees) { + final float score = t.score(fv); + final Explanation p = Explanation.match(score, "tree " + index + " | " + + t.explain(fv)); + details.add(p); + index++; + } + + return Explanation.match(finalScore, toString() + + " model applied to features, 
sum of:", details); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(getClass().getSimpleName()); + sb.append("(name=").append(getName()); + sb.append(",trees=["); + for (int ii = 0; ii < trees.size(); ++ii) { + if (ii>0) { + sb.append(','); + } + sb.append(trees.get(ii)); + } + sb.append("])"); + return sb.toString(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/package-info.java new file mode 100644 index 00000000000..32bd626f95b --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/model/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Contains Model related classes + */ +package org.apache.solr.ltr.model; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/IdentityNormalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/IdentityNormalizer.java new file mode 100644 index 00000000000..a3d1026064f --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/IdentityNormalizer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.norm; + +import java.util.LinkedHashMap; + +/** + * A Normalizer that normalizes a feature value to itself. This is the + * default normalizer class, if no normalizer is configured then the + * IdentityNormalizer will be used. 
+ */ +public class IdentityNormalizer extends Normalizer { + + public static final IdentityNormalizer INSTANCE = new IdentityNormalizer(); + + public IdentityNormalizer() { + + } + + @Override + public float normalize(float value) { + return value; + } + + @Override + public LinkedHashMap paramsToMap() { + return null; + } + + @Override + protected void validate() throws NormalizerException { + } + + @Override + public String toString() { + return getClass().getSimpleName(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/MinMaxNormalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/MinMaxNormalizer.java new file mode 100644 index 00000000000..92e233c95ca --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/MinMaxNormalizer.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.norm; + +import java.util.LinkedHashMap; + +/** + * A Normalizer to scale a feature value using a (min,max) range. + *

+ * Example configuration: +

+"norm" : {
+    "class" : "org.apache.solr.ltr.norm.MinMaxNormalizer",
+    "params" : { "min":"0", "max":"50" }
+}
+
+ * Example normalizations: + *
    + *
  • -5 will be normalized to -0.1 + *
  • 55 will be normalized to 1.1 + *
  • +5 will be normalized to +0.1 + *
+ */ +public class MinMaxNormalizer extends Normalizer { + + private float min = Float.NEGATIVE_INFINITY; + private float max = Float.POSITIVE_INFINITY; + private float delta = max - min; + + private void updateDelta() { + delta = max - min; + } + + public float getMin() { + return min; + } + + public void setMin(float min) { + this.min = min; + updateDelta(); + } + + public void setMin(String min) { + this.min = Float.parseFloat(min); + updateDelta(); + } + + public float getMax() { + return max; + } + + public void setMax(float max) { + this.max = max; + updateDelta(); + } + + public void setMax(String max) { + this.max = Float.parseFloat(max); + updateDelta(); + } + + @Override + protected void validate() throws NormalizerException { + if (delta == 0f) { + throw + new NormalizerException("MinMax Normalizer delta must not be zero " + + "| min = " + min + ",max = " + max + ",delta = " + delta); + } + } + + @Override + public float normalize(float value) { + return (value - min) / delta; + } + + @Override + public LinkedHashMap paramsToMap() { + final LinkedHashMap params = new LinkedHashMap<>(2, 1.0f); + params.put("min", min); + params.put("max", max); + return params; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(64); // default initialCapacity of 16 won't be enough + sb.append(getClass().getSimpleName()).append('('); + sb.append("min=").append(min); + sb.append(",max=").append(max).append(')'); + return sb.toString(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/Normalizer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/Normalizer.java new file mode 100644 index 00000000000..2b311f8cc7b --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/Normalizer.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.norm; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.lucene.search.Explanation; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.util.SolrPluginUtils; + +/** + * A normalizer normalizes the value of a feature. After the feature values + * have been computed, the {@link Normalizer#normalize(float)} methods will + * be called and the resulting values will be used by the model. + */ +public abstract class Normalizer { + + + public abstract float normalize(float value); + + public abstract LinkedHashMap paramsToMap(); + + public Explanation explain(Explanation explain) { + final float normalized = normalize(explain.getValue()); + final String explainDesc = "normalized using " + toString(); + + return Explanation.match(normalized, explainDesc, explain); + } + + public static Normalizer getInstance(SolrResourceLoader solrResourceLoader, + String className, Map params) { + final Normalizer f = solrResourceLoader.newInstance(className, Normalizer.class); + if (params != null) { + SolrPluginUtils.invokeSetters(f, params.entrySet()); + } + f.validate(); + return f; + } + + /** + * As part of creation of a normalizer instance, this function confirms + * that the normalizer parameters are valid. 
+ * + * @throws NormalizerException + * Normalizer Exception + */ + protected abstract void validate() throws NormalizerException; + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/NormalizerException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/NormalizerException.java new file mode 100644 index 00000000000..5b33f05a7a9 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/NormalizerException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Unchecked exception raised when a normalizer is configured with invalid
 * parameters.
 */
public class NormalizerException extends RuntimeException {

  private static final long serialVersionUID = 1L;

  /**
   * @param message description of the invalid configuration
   */
  public NormalizerException(String message) {
    super(message);
  }

  /**
   * @param message description of the invalid configuration
   * @param cause the underlying failure
   */
  public NormalizerException(String message, Exception cause) {
    super(message, cause);
  }

}

+ * Example configuration: +

+"norm" : {
+    "class" : "org.apache.solr.ltr.norm.StandardNormalizer",
+    "params" : { "avg":"42", "std":"6" }
+}
+
+ *

+ * Example normalizations: + *

    + *
  • 39 will be normalized to -0.5 + *
  • 42 will be normalized to 0 + *
  • 45 will be normalized to +0.5 + *
+ */ +public class StandardNormalizer extends Normalizer { + + private float avg = 0f; + private float std = 1f; + + public float getAvg() { + return avg; + } + + public void setAvg(float avg) { + this.avg = avg; + } + + public float getStd() { + return std; + } + + public void setStd(float std) { + this.std = std; + } + + public void setAvg(String avg) { + this.avg = Float.parseFloat(avg); + } + + public void setStd(String std) { + this.std = Float.parseFloat(std); + } + + @Override + public float normalize(float value) { + return (value - avg) / std; + } + + @Override + protected void validate() throws NormalizerException { + if (std <= 0f) { + throw + new NormalizerException("Standard Normalizer standard deviation must " + + "be positive | avg = " + avg + ",std = " + std); + } + } + + @Override + public LinkedHashMap paramsToMap() { + final LinkedHashMap params = new LinkedHashMap<>(2, 1.0f); + params.put("avg", avg); + params.put("std", std); + return params; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(64); // default initialCapacity of 16 won't be enough + sb.append(getClass().getSimpleName()).append('('); + sb.append("avg=").append(avg); + sb.append(",std=").append(avg).append(')'); + return sb.toString(); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/package-info.java new file mode 100644 index 00000000000..164b425df52 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/norm/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * A normalizer normalizes the value of a feature. Once the feature values + * have been computed, the normalizer is applied and the resulting values + * are passed to the model. + */ +package org.apache.solr.ltr.norm; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/package-info.java new file mode 100644 index 00000000000..59aebe83f2e --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/package-info.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + *

+ * This package contains the main logic for performing the reranking using + * a Learning to Rank model. + *

+ *

+ * A model will be applied on each document through a {@link org.apache.solr.ltr.LTRScoringQuery}, a + * subclass of {@link org.apache.lucene.search.Query}. As a normal query, + * the learned model will produce a new score + * for each document reranked. + *

+ *

+ * A {@link org.apache.solr.ltr.LTRScoringQuery} is created by providing an instance of + * {@link org.apache.solr.ltr.model.LTRScoringModel}. An instance of + * {@link org.apache.solr.ltr.model.LTRScoringModel} + * defines how to combine the features in order to create a new + * score for a document. A new Learning to Rank model is plugged + * into the framework by extending {@link org.apache.solr.ltr.model.LTRScoringModel}, + * (see for example {@link org.apache.solr.ltr.model.MultipleAdditiveTreesModel} and {@link org.apache.solr.ltr.model.LinearModel}). + *

+ *

+ * The {@link org.apache.solr.ltr.LTRScoringQuery} will take care of computing the values of + * all the features (see {@link org.apache.solr.ltr.feature.Feature}) and then will delegate the final score + * generation to the {@link org.apache.solr.ltr.model.LTRScoringModel}, by calling the method + * {@link org.apache.solr.ltr.model.LTRScoringModel#score(float[] modelFeatureValuesNormalized) score(float[] modelFeatureValuesNormalized)}. + *

+ */ +package org.apache.solr.ltr; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/FeatureStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/FeatureStore.java new file mode 100644 index 00000000000..ab2595ff38c --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/FeatureStore.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.store; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; + +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.FeatureException; + +public class FeatureStore { + + /** the name of the default feature store **/ + public static final String DEFAULT_FEATURE_STORE_NAME = "_DEFAULT_"; + + private final LinkedHashMap store = new LinkedHashMap<>(); // LinkedHashMap because we need predictable iteration order + private final String name; + + public FeatureStore(String name) { + this.name = name; + } + + public String getName() { + return name; + } + + public Feature get(String name) { + return store.get(name); + } + + public void add(Feature feature) { + final String name = feature.getName(); + if (store.containsKey(name)) { + throw new FeatureException(name + + " already contained in the store, please use a different name"); + } + feature.setIndex(store.size()); + store.put(name, feature); + } + + public List getFeatures() { + final List storeValues = new ArrayList(store.values()); + return Collections.unmodifiableList(storeValues); + } + + @Override + public String toString() { + return "FeatureStore [features=" + store.keySet() + "]"; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/ModelStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/ModelStore.java new file mode 100644 index 00000000000..dbb065f6403 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/ModelStore.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.store; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.model.ModelException; + +/** + * Contains the model and features declared. + */ +public class ModelStore { + + private final Map availableModels; + + public ModelStore() { + availableModels = new HashMap<>(); + } + + public synchronized LTRScoringModel getModel(String name) { + return availableModels.get(name); + } + + public void clear() { + availableModels.clear(); + } + + public List getModels() { + final List availableModelsValues = + new ArrayList(availableModels.values()); + return Collections.unmodifiableList(availableModelsValues); + } + + @Override + public String toString() { + return "ModelStore [availableModels=" + availableModels.keySet() + "]"; + } + + public LTRScoringModel delete(String modelName) { + return availableModels.remove(modelName); + } + + public synchronized void addModel(LTRScoringModel modeldata) + throws ModelException { + final String name = modeldata.getName(); + + if (availableModels.containsKey(name)) { + throw new ModelException("model '" + name + + "' already exists. 
Please use a different name"); + } + + availableModels.put(modeldata.getName(), modeldata); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/package-info.java new file mode 100644 index 00000000000..1ed9bffc86e --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains feature and model store related classes. + */ +package org.apache.solr.ltr.store; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedFeatureStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedFeatureStore.java new file mode 100644 index 00000000000..beb217c5c37 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedFeatureStore.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.store.rest; + +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.solr.common.SolrException; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.store.FeatureStore; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.rest.BaseSolrResource; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceObserver; +import org.apache.solr.rest.ManagedResourceStorage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Managed resource for storing a feature.
+ */ +public class ManagedFeatureStore extends ManagedResource implements ManagedResource.ChildResourceSupport { + + public static void registerManagedFeatureStore(SolrResourceLoader solrResourceLoader, + ManagedResourceObserver managedResourceObserver) { + solrResourceLoader.getManagedResourceRegistry().registerManagedResource( + REST_END_POINT, + ManagedFeatureStore.class, + managedResourceObserver); + } + + public static ManagedFeatureStore getManagedFeatureStore(SolrCore core) { + return (ManagedFeatureStore) core.getRestManager() + .getManagedResource(REST_END_POINT); + } + + /** the feature store rest endpoint **/ + public static final String REST_END_POINT = "/schema/feature-store"; + // TODO: reduce from public to package visibility (once tests no longer need public access) + + /** name of the attribute containing the feature class **/ + static final String CLASS_KEY = "class"; + /** name of the attribute containing the feature name **/ + static final String NAME_KEY = "name"; + /** name of the attribute containing the feature params **/ + static final String PARAMS_KEY = "params"; + /** name of the attribute containing the feature store used **/ + static final String FEATURE_STORE_NAME_KEY = "store"; + + private final Map stores = new HashMap<>(); + + /** + * Managed feature store: the name of the attribute containing all the feature + * stores + **/ + private static final String FEATURE_STORE_JSON_FIELD = "featureStores"; + + /** + * Managed feature store: the name of the attribute containing all the + * features of a feature store + **/ + private static final String FEATURES_JSON_FIELD = "features"; + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + public ManagedFeatureStore(String resourceId, SolrResourceLoader loader, + ManagedResourceStorage.StorageIO storageIO) throws SolrException { + super(resourceId, loader, storageIO); + + } + + public synchronized FeatureStore getFeatureStore(String name) { 
+ if (name == null) { + name = FeatureStore.DEFAULT_FEATURE_STORE_NAME; + } + if (!stores.containsKey(name)) { + stores.put(name, new FeatureStore(name)); + } + return stores.get(name); + } + + @Override + protected void onManagedDataLoadedFromStorage(NamedList managedInitArgs, + Object managedData) throws SolrException { + + stores.clear(); + log.info("------ managed feature ~ loading ------"); + if (managedData instanceof List) { + @SuppressWarnings("unchecked") + final List> up = (List>) managedData; + for (final Map u : up) { + final String featureStore = (String) u.get(FEATURE_STORE_NAME_KEY); + addFeature(u, featureStore); + } + } + } + + public synchronized void addFeature(Map map, String featureStore) { + log.info("register feature based on {}", map); + final FeatureStore fstore = getFeatureStore(featureStore); + final Feature feature = fromFeatureMap(solrResourceLoader, map); + fstore.add(feature); + } + + @SuppressWarnings("unchecked") + @Override + public Object applyUpdatesToManagedData(Object updates) { + if (updates instanceof List) { + final List> up = (List>) updates; + for (final Map u : up) { + final String featureStore = (String) u.get(FEATURE_STORE_NAME_KEY); + addFeature(u, featureStore); + } + } + + if (updates instanceof Map) { + // a unique feature + Map updatesMap = (Map) updates; + final String featureStore = (String) updatesMap.get(FEATURE_STORE_NAME_KEY); + addFeature(updatesMap, featureStore); + } + + final List features = new ArrayList<>(); + for (final FeatureStore fs : stores.values()) { + features.addAll(featuresAsManagedResources(fs)); + } + return features; + } + + @Override + public synchronized void doDeleteChild(BaseSolrResource endpoint, String childId) { + if (childId.equals("*")) { + stores.clear(); + } + if (stores.containsKey(childId)) { + stores.remove(childId); + } + storeManagedData(applyUpdatesToManagedData(null)); + } + + /** + * Called to retrieve a named part (the given childId) of the resource at the + * given 
endpoint. Note: since we have a unique child feature store we ignore + * the childId. + */ + @Override + public void doGet(BaseSolrResource endpoint, String childId) { + final SolrQueryResponse response = endpoint.getSolrResponse(); + + // If no feature store specified, show all the feature stores available + if (childId == null) { + response.add(FEATURE_STORE_JSON_FIELD, stores.keySet()); + } else { + final FeatureStore store = getFeatureStore(childId); + if (store == null) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "missing feature store [" + childId + "]"); + } + response.add(FEATURES_JSON_FIELD, + featuresAsManagedResources(store)); + } + } + + private static List featuresAsManagedResources(FeatureStore store) { + final List storedFeatures = store.getFeatures(); + final List features = new ArrayList(storedFeatures.size()); + for (final Feature f : storedFeatures) { + final LinkedHashMap m = toFeatureMap(f); + m.put(FEATURE_STORE_NAME_KEY, store.getName()); + features.add(m); + } + return features; + } + + private static LinkedHashMap toFeatureMap(Feature feat) { + final LinkedHashMap o = new LinkedHashMap<>(4, 1.0f); // 1 extra for caller to add store + o.put(NAME_KEY, feat.getName()); + o.put(CLASS_KEY, feat.getClass().getCanonicalName()); + o.put(PARAMS_KEY, feat.paramsToMap()); + return o; + } + + private static Feature fromFeatureMap(SolrResourceLoader solrResourceLoader, + Map featureMap) { + final String className = (String) featureMap.get(CLASS_KEY); + + final String name = (String) featureMap.get(NAME_KEY); + + @SuppressWarnings("unchecked") + final Map params = (Map) featureMap.get(PARAMS_KEY); + + return Feature.getInstance(solrResourceLoader, className, name, params); + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java new file mode 100644 index 00000000000..97aaa4004ad --- /dev/null +++ 
b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/ManagedModelStore.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.store.rest; + +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.solr.common.SolrException; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrCore; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.model.ModelException; +import org.apache.solr.ltr.norm.IdentityNormalizer; +import org.apache.solr.ltr.norm.Normalizer; +import org.apache.solr.ltr.store.FeatureStore; +import org.apache.solr.ltr.store.ModelStore; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.rest.BaseSolrResource; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceObserver; +import org.apache.solr.rest.ManagedResourceStorage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Managed resource for storing a model + */ +public
class ManagedModelStore extends ManagedResource implements ManagedResource.ChildResourceSupport { + + public static void registerManagedModelStore(SolrResourceLoader solrResourceLoader, + ManagedResourceObserver managedResourceObserver) { + solrResourceLoader.getManagedResourceRegistry().registerManagedResource( + REST_END_POINT, + ManagedModelStore.class, + managedResourceObserver); + } + + public static ManagedModelStore getManagedModelStore(SolrCore core) { + return (ManagedModelStore) core.getRestManager() + .getManagedResource(REST_END_POINT); + } + + /** the model store rest endpoint **/ + public static final String REST_END_POINT = "/schema/model-store"; + // TODO: reduce from public to package visibility (once tests no longer need public access) + + /** + * Managed model store: the name of the attribute containing all the models of + * a model store + **/ + private static final String MODELS_JSON_FIELD = "models"; + + /** name of the attribute containing a class **/ + static final String CLASS_KEY = "class"; + /** name of the attribute containing the features **/ + static final String FEATURES_KEY = "features"; + /** name of the attribute containing a name **/ + static final String NAME_KEY = "name"; + /** name of the attribute containing a normalizer **/ + static final String NORM_KEY = "norm"; + /** name of the attribute containing parameters **/ + static final String PARAMS_KEY = "params"; + /** name of the attribute containing a store **/ + static final String STORE_KEY = "store"; + + private final ModelStore store; + private ManagedFeatureStore managedFeatureStore; + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + public ManagedModelStore(String resourceId, SolrResourceLoader loader, + ManagedResourceStorage.StorageIO storageIO) throws SolrException { + super(resourceId, loader, storageIO); + store = new ModelStore(); + } + + public void setManagedFeatureStore(ManagedFeatureStore 
managedFeatureStore) { + log.info("INIT model store"); + this.managedFeatureStore = managedFeatureStore; + } + + public ManagedFeatureStore getManagedFeatureStore() { + return managedFeatureStore; + } + + private Object managedData; + + @SuppressWarnings("unchecked") + @Override + protected void onManagedDataLoadedFromStorage(NamedList managedInitArgs, + Object managedData) throws SolrException { + store.clear(); + // the managed models on the disk or on zookeeper will be loaded in a lazy + // way, since we need to set the managed features first (unfortunately + // managed resources do not + // decouple the creation of a managed resource with the reading of the data + // from the storage) + this.managedData = managedData; + + } + + public void loadStoredModels() { + log.info("------ managed models ~ loading ------"); + + if ((managedData != null) && (managedData instanceof List)) { + final List> up = (List>) managedData; + for (final Map u : up) { + try { + final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, u, managedFeatureStore); + addModel(algo); + } catch (final ModelException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + } + } + + public synchronized void addModel(LTRScoringModel ltrScoringModel) throws ModelException { + try { + log.info("adding model {}", ltrScoringModel.getName()); + store.addModel(ltrScoringModel); + } catch (final ModelException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + + @SuppressWarnings("unchecked") + @Override + protected Object applyUpdatesToManagedData(Object updates) { + if (updates instanceof List) { + final List> up = (List>) updates; + for (final Map u : up) { + try { + final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, u, managedFeatureStore); + addModel(algo); + } catch (final ModelException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + } + + if (updates instanceof Map) { + 
final Map map = (Map) updates; + try { + final LTRScoringModel algo = fromLTRScoringModelMap(solrResourceLoader, map, managedFeatureStore); + addModel(algo); + } catch (final ModelException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e); + } + } + + return modelsAsManagedResources(store.getModels()); + } + + @Override + public synchronized void doDeleteChild(BaseSolrResource endpoint, String childId) { + if (childId.equals("*")) { + store.clear(); + } else { + store.delete(childId); + } + storeManagedData(applyUpdatesToManagedData(null)); + } + + /** + * Called to retrieve a named part (the given childId) of the resource at the + * given endpoint. Note: since we have a unique child managed store we ignore + * the childId. + */ + @Override + public void doGet(BaseSolrResource endpoint, String childId) { + + final SolrQueryResponse response = endpoint.getSolrResponse(); + response.add(MODELS_JSON_FIELD, + modelsAsManagedResources(store.getModels())); + } + + public LTRScoringModel getModel(String modelName) { + // this function replicates getModelStore().getModel(modelName), but + // it simplifies the testing (we can avoid to mock also a ModelStore). + return store.getModel(modelName); + } + + @Override + public String toString() { + return "ManagedModelStore [store=" + store + ", featureStores=" + + managedFeatureStore + "]"; + } + + /** + * Returns the available models as a list of Maps objects. After an update the + * managed resources needs to return the resources in this format in order to + * store in json somewhere (zookeeper, disk...) 
+ * + * + * @return the available models as a list of Maps objects + */ + private static List modelsAsManagedResources(List models) { + final List list = new ArrayList<>(models.size()); + for (final LTRScoringModel model : models) { + list.add(toLTRScoringModelMap(model)); + } + return list; + } + + @SuppressWarnings("unchecked") + public static LTRScoringModel fromLTRScoringModelMap(SolrResourceLoader solrResourceLoader, + Map modelMap, ManagedFeatureStore managedFeatureStore) { + + final FeatureStore featureStore = + managedFeatureStore.getFeatureStore((String) modelMap.get(STORE_KEY)); + + final List features = new ArrayList<>(); + final List norms = new ArrayList<>(); + + final List featureList = (List) modelMap.get(FEATURES_KEY); + if (featureList != null) { + for (final Object feature : featureList) { + final Map featureMap = (Map) feature; + features.add(lookupFeatureFromFeatureMap(featureMap, featureStore)); + norms.add(createNormalizerFromFeatureMap(solrResourceLoader, featureMap)); + } + } + + return LTRScoringModel.getInstance(solrResourceLoader, + (String) modelMap.get(CLASS_KEY), // modelClassName + (String) modelMap.get(NAME_KEY), // modelName + features, + norms, + featureStore.getName(), + featureStore.getFeatures(), + (Map) modelMap.get(PARAMS_KEY)); + } + + private static LinkedHashMap toLTRScoringModelMap(LTRScoringModel model) { + final LinkedHashMap modelMap = new LinkedHashMap<>(5, 1.0f); + + modelMap.put(NAME_KEY, model.getName()); + modelMap.put(CLASS_KEY, model.getClass().getCanonicalName()); + modelMap.put(STORE_KEY, model.getFeatureStoreName()); + + final List> features = new ArrayList<>(); + final List featuresList = model.getFeatures(); + final List normsList = model.getNorms(); + for (int ii=0; ii featureMap, + FeatureStore featureStore) { + final String featureName = (String)featureMap.get(NAME_KEY); + return (featureName == null ? 
null + : featureStore.get(featureName)); + } + + @SuppressWarnings("unchecked") + private static Normalizer createNormalizerFromFeatureMap(SolrResourceLoader solrResourceLoader, + Map featureMap) { + final Map normMap = (Map)featureMap.get(NORM_KEY); + return (normMap == null ? IdentityNormalizer.INSTANCE + : fromNormalizerMap(solrResourceLoader, normMap)); + } + + private static LinkedHashMap toFeatureMap(Feature feature, Normalizer norm) { + final LinkedHashMap map = new LinkedHashMap(2, 1.0f); + map.put(NAME_KEY, feature.getName()); + map.put(NORM_KEY, toNormalizerMap(norm)); + return map; + } + + private static Normalizer fromNormalizerMap(SolrResourceLoader solrResourceLoader, + Map normMap) { + final String className = (String) normMap.get(CLASS_KEY); + + @SuppressWarnings("unchecked") + final Map params = (Map) normMap.get(PARAMS_KEY); + + return Normalizer.getInstance(solrResourceLoader, className, params); + } + + private static LinkedHashMap toNormalizerMap(Normalizer norm) { + final LinkedHashMap normalizer = new LinkedHashMap<>(2, 1.0f); + + normalizer.put(CLASS_KEY, norm.getClass().getCanonicalName()); + + final LinkedHashMap params = norm.paramsToMap(); + if (params != null) { + normalizer.put(PARAMS_KEY, params); + } + + return normalizer; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/package-info.java new file mode 100644 index 00000000000..fbf702999d4 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/store/rest/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains the {@link org.apache.solr.rest.ManagedResource} that encapsulate + * the feature and the model stores. + */ +package org.apache.solr.ltr.store.rest; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/contrib/ltr/src/java/org/apache/solr/response/transform/LTRFeatureLoggerTransformerFactory.java new file mode 100644 index 00000000000..d1442929fb5 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.response.transform; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.ltr.FeatureLogger; +import org.apache.solr.ltr.LTRRescorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.LTRThreadModule; +import org.apache.solr.ltr.SolrQueryRequestContextUtils; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.norm.Normalizer; +import org.apache.solr.ltr.store.FeatureStore; +import org.apache.solr.ltr.store.rest.ManagedFeatureStore; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.ResultContext; +import org.apache.solr.search.LTRQParserPlugin; +import org.apache.solr.search.SolrIndexSearcher; +import org.apache.solr.util.SolrPluginUtils; + +/** + * This transformer generates the features declared in the feature store of the + * current model and appends them to each document of the response. The class is + * useful if you are not interested in the reranking (e.g., bootstrapping a + * machine learning framework).
+ */ +public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { + + // used inside fl to specify the output format (csv/json) of the extracted features + private static final String FV_RESPONSE_WRITER = "fvwt"; + + // used inside fl to specify the format (dense|sparse) of the extracted features + private static final String FV_FORMAT = "format"; + + // used inside fl to specify the feature store to use for the feature extraction + private static final String FV_STORE = "store"; + + private static String DEFAULT_LOGGING_MODEL_NAME = "logging-model"; + + private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME; + private String defaultFvStore; + private String defaultFvwt; + private String defaultFvFormat; + + private LTRThreadModule threadManager = null; + + public void setLoggingModelName(String loggingModelName) { + this.loggingModelName = loggingModelName; + } + + public void setStore(String defaultFvStore) { + this.defaultFvStore = defaultFvStore; + } + + public void setFvwt(String defaultFvwt) { + this.defaultFvwt = defaultFvwt; + } + + public void setFormat(String defaultFvFormat) { + this.defaultFvFormat = defaultFvFormat; + } + + @Override + public void init(@SuppressWarnings("rawtypes") NamedList args) { + super.init(args); + threadManager = LTRThreadModule.getInstance(args); + SolrPluginUtils.invokeSetters(this, args); + } + + @Override + public DocTransformer create(String name, SolrParams params, + SolrQueryRequest req) { + + // Hint to enable feature vector cache since we are requesting features + SolrQueryRequestContextUtils.setIsExtractingFeatures(req); + + // Communicate which feature store we are requesting features for + SolrQueryRequestContextUtils.setFvStoreName(req, params.get(FV_STORE, defaultFvStore)); + + // Create and supply the feature logger to be used + SolrQueryRequestContextUtils.setFeatureLogger(req, + FeatureLogger.createFeatureLogger( + params.get(FV_RESPONSE_WRITER, defaultFvwt), + params.get(FV_FORMAT,
defaultFvFormat))); + + return new FeatureTransformer(name, params, req); + } + + class FeatureTransformer extends DocTransformer { + + final private String name; + final private SolrParams params; + final private SolrQueryRequest req; + + private List leafContexts; + private SolrIndexSearcher searcher; + private LTRScoringQuery scoringQuery; + private LTRScoringQuery.ModelWeight modelWeight; + private FeatureLogger featureLogger; + private boolean docsWereNotReranked; + + /** + * @param name + * Name of the field to be added in a document representing the + * feature vectors + */ + public FeatureTransformer(String name, SolrParams params, + SolrQueryRequest req) { + this.name = name; + this.params = params; + this.req = req; + } + + @Override + public String getName() { + return name; + } + + @Override + public void setContext(ResultContext context) { + super.setContext(context); + if (context == null) { + return; + } + if (context.getRequest() == null) { + return; + } + + searcher = context.getSearcher(); + if (searcher == null) { + throw new SolrException( + SolrException.ErrorCode.BAD_REQUEST, + "searcher is null"); + } + leafContexts = searcher.getTopReaderContext().leaves(); + + // Setup LTRScoringQuery + scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req); + docsWereNotReranked = (scoringQuery == null); + String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req); + if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) { + // if store is set in the transformer we should overwrite the logger + + final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore()); + + final FeatureStore store = fr.getFeatureStore(featureStoreName); + featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name + + try { + final LoggingModel lm = new LoggingModel(loggingModelName, + featureStoreName,
store.getFeatures()); + + scoringQuery = new LTRScoringQuery(lm, + LTRQParserPlugin.extractEFIParams(params), + true, + threadManager); // request feature weights to be created for all features + + // Local transformer efi (if provided) were passed to the constructor above + scoringQuery.setOriginalQuery(context.getQuery()); + + }catch (final Exception e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "retrieving the feature store "+featureStoreName, e); + } + } + + if (scoringQuery.getFeatureLogger() == null){ + scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); + } + scoringQuery.setRequest(req); + + featureLogger = scoringQuery.getFeatureLogger(); + + try { + modelWeight = scoringQuery.createWeight(searcher, true, 1f); + } catch (final IOException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e); + } + if (modelWeight == null) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "error logging the features, model weight is null"); + } + } + + @Override + public void transform(SolrDocument doc, int docid, float score) + throws IOException { + Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher); + if (fv == null) { // FV for this document was not in the cache + fv = featureLogger.makeFeatureVector( + LTRRescorer.extractFeaturesInfo( + modelWeight, + docid, + (docsWereNotReranked ?
new Float(score) : null), + leafContexts)); + } + + doc.addField(name, fv); + } + + } + + private static class LoggingModel extends LTRScoringModel { + + public LoggingModel(String name, String featureStoreName, List allFeatures){ + this(name, Collections.emptyList(), Collections.emptyList(), + featureStoreName, allFeatures, Collections.emptyMap()); + } + + protected LoggingModel(String name, List features, + List norms, String featureStoreName, + List allFeatures, Map params) { + super(name, features, norms, featureStoreName, allFeatures, params); + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + return 0; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc, float finalScore, List featureExplanations) { + return Explanation.match(finalScore, toString() + + " logging model, used only for logging the features"); + } + + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/response/transform/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/response/transform/package-info.java new file mode 100644 index 00000000000..bab3ebfc65d --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/response/transform/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * APIs and implementations of {@link org.apache.solr.response.transform.DocTransformer} for modifying documents in Solr request responses + */ +package org.apache.solr.response.transform; + + diff --git a/solr/contrib/ltr/src/java/org/apache/solr/search/LTRQParserPlugin.java b/solr/contrib/ltr/src/java/org/apache/solr/search/LTRQParserPlugin.java new file mode 100644 index 00000000000..40cbaa90982 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/search/LTRQParserPlugin.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.search; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +import org.apache.lucene.analysis.util.ResourceLoader; +import org.apache.lucene.analysis.util.ResourceLoaderAware; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.solr.common.SolrException; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.LTRRescorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.LTRThreadModule; +import org.apache.solr.ltr.SolrQueryRequestContextUtils; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.store.rest.ManagedFeatureStore; +import org.apache.solr.ltr.store.rest.ManagedModelStore; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.rest.ManagedResource; +import org.apache.solr.rest.ManagedResourceObserver; +import org.apache.solr.util.SolrPluginUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Plug into solr a rerank model. 
+ * + * Learning to Rank Query Parser Syntax: rq={!ltr model=6029760550880411648 reRankDocs=300 + * efi.myCompanyQueryIntent=0.98} + * + */ +public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver { + public static final String NAME = "ltr"; + private static Query defaultQuery = new MatchAllDocsQuery(); + + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + // params for setting custom external info that features can use, like query + // intent + static final String EXTERNAL_FEATURE_INFO = "efi."; + + private ManagedFeatureStore fr = null; + private ManagedModelStore mr = null; + + private LTRThreadModule threadManager = null; + + /** query parser plugin: the name of the attribute for setting the model **/ + public static final String MODEL = "model"; + + /** query parser plugin: default number of documents to rerank **/ + public static final int DEFAULT_RERANK_DOCS = 200; + + /** + * query parser plugin: the param that selects the number of documents + * to rerank + **/ + public static final String RERANK_DOCS = "reRankDocs"; + + @Override + public void init(@SuppressWarnings("rawtypes") NamedList args) { + super.init(args); + threadManager = LTRThreadModule.getInstance(args); + SolrPluginUtils.invokeSetters(this, args); + } + + @Override + public QParser createParser(String qstr, SolrParams localParams, + SolrParams params, SolrQueryRequest req) { + return new LTRQParser(qstr, localParams, params, req); + } + + /** + * Given a set of local SolrParams, extract all of the efi.key=value params into a map + * @param localParams Local request parameters that might contain efi params + * @return Map of efi params, where the key is the name of the efi param, and the + * value is the value of the efi param + */ + public static Map extractEFIParams(SolrParams localParams) { + final Map externalFeatureInfo = new HashMap<>(); + for (final Iterator it =
localParams.getParameterNamesIterator(); it + .hasNext();) { + final String name = it.next(); + if (name.startsWith(EXTERNAL_FEATURE_INFO)) { + externalFeatureInfo.put( + name.substring(EXTERNAL_FEATURE_INFO.length()), + new String[] {localParams.get(name)}); + } + } + return externalFeatureInfo; + } + + + @Override + public void inform(ResourceLoader loader) throws IOException { + final SolrResourceLoader solrResourceLoader = (SolrResourceLoader) loader; + ManagedFeatureStore.registerManagedFeatureStore(solrResourceLoader, this); + ManagedModelStore.registerManagedModelStore(solrResourceLoader, this); + } + + @Override + public void onManagedResourceInitialized(NamedList args, ManagedResource res) throws SolrException { + if (res instanceof ManagedFeatureStore) { + fr = (ManagedFeatureStore)res; + } + if (res instanceof ManagedModelStore){ + mr = (ManagedModelStore)res; + } + if (mr != null && fr != null){ + mr.setManagedFeatureStore(fr); + // now we can safely load the models + mr.loadStoredModels(); + + } + } + + public class LTRQParser extends QParser { + + public LTRQParser(String qstr, SolrParams localParams, SolrParams params, + SolrQueryRequest req) { + super(qstr, localParams, params, req); + } + + @Override + public Query parse() throws SyntaxError { + // ReRanking Model + final String modelName = localParams.get(LTRQParserPlugin.MODEL); + if ((modelName == null) || modelName.isEmpty()) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "Must provide model in the request"); + } + + final LTRScoringModel ltrScoringModel = mr.getModel(modelName); + if (ltrScoringModel == null) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "cannot find " + LTRQParserPlugin.MODEL + " " + modelName); + } + + final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName(); + final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req); + final String fvStoreName =
SolrQueryRequestContextUtils.getFvStoreName(req); + // Check if features are requested and if the model feature store and feature-transform feature store are the same + final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(fvStoreName) || fvStoreName == null) ? extractFeatures:false; + + final LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel, + extractEFIParams(localParams), + featuresRequestedFromSameStore, threadManager); + + // Enable the feature vector caching if we are extracting features, and the features + // we requested are the same ones we are reranking with + if (featuresRequestedFromSameStore) { + scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); + } + SolrQueryRequestContextUtils.setScoringQuery(req, scoringQuery); + + int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS); + reRankDocs = Math.max(1, reRankDocs); + + // External features + scoringQuery.setRequest(req); + + return new LTRQuery(scoringQuery, reRankDocs); + } + } + + /** + * A learning to rank Query that encapsulates a learning to rank model and delegates to it the rescoring + * of the documents.
+ **/ + public class LTRQuery extends AbstractReRankQuery { + private final LTRScoringQuery scoringQuery; + + public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) { + super(defaultQuery, reRankDocs, new LTRRescorer(scoringQuery)); + this.scoringQuery = scoringQuery; + } + + @Override + public int hashCode() { + return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs); + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) && equalsTo(getClass().cast(o)); + } + + private boolean equalsTo(LTRQuery other) { + return (mainQuery.equals(other.mainQuery) + && scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs)); + } + + @Override + public RankQuery wrap(Query _mainQuery) { + super.wrap(_mainQuery); + scoringQuery.setOriginalQuery(_mainQuery); + return this; + } + + @Override + public String toString(String field) { + return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='" + + scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}"; + } + + @Override + protected Query rewrite(Query rewrittenMainQuery) throws IOException { + return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery); + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/search/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/search/package-info.java new file mode 100644 index 00000000000..2286a9355aa --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/search/package-info.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * APIs and classes for {@linkplain org.apache.solr.search.QParserPlugin parsing} and {@linkplain org.apache.solr.search.SolrIndexSearcher processing} search requests + */ +package org.apache.solr.search; + + diff --git a/solr/contrib/ltr/src/java/overview.html b/solr/contrib/ltr/src/java/overview.html new file mode 100644 index 00000000000..ccae361310a --- /dev/null +++ b/solr/contrib/ltr/src/java/overview.html @@ -0,0 +1,91 @@ + + + +Apache Solr Search Server: Learning to Rank Contrib + +

+This module contains the logic to plug machine learned ranking modules into Solr. +

+

+In information retrieval systems, Learning to Rank is used to re-rank the top X +retrieved documents using trained machine learning models. The hope is +that sophisticated models can make more nuanced ranking decisions than standard ranking +functions like TF-IDF or BM25. +

+

+This module allows a reranking component to be plugged directly into Solr, enabling users +to easily build their own learning to rank systems and access the rich +matching features readily available in Solr. It also provides tools to perform +feature engineering and feature extraction. +

+

Code structure

+

+A Learning to Rank model is plugged into the ranking through the {@link org.apache.solr.search.LTRQParserPlugin}, +a {@link org.apache.solr.search.QParserPlugin}. The plugin will +read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel}) +used to perform the request plus other +parameters. The plugin will generate a {@link org.apache.solr.search.LTRQParserPlugin.LTRQuery LTRQuery}: +a particular {@link org.apache.solr.search.RankQuery} +that will encapsulate the given model and use it to +rescore and rerank the documents (by using an {@link org.apache.solr.ltr.LTRRescorer}). +

+

+A model will be applied to each document through a {@link org.apache.solr.ltr.LTRScoringQuery}, a +subclass of {@link org.apache.lucene.search.Query}. Like a normal query, +the learned model will produce a new score +for each reranked document. +

+

+A {@link org.apache.solr.ltr.LTRScoringQuery} is created by providing an instance of +{@link org.apache.solr.ltr.model.LTRScoringModel}. An instance of +{@link org.apache.solr.ltr.model.LTRScoringModel} +defines how to combine the features in order to create a new +score for a document. A new learning to rank model is plugged +into the framework by extending {@link org.apache.solr.ltr.model.LTRScoringModel} +(see for example {@link org.apache.solr.ltr.model.MultipleAdditiveTreesModel} and {@link org.apache.solr.ltr.model.LinearModel}). +

+

+The {@link org.apache.solr.ltr.LTRScoringQuery} will take care of computing the values of +all the features (see {@link org.apache.solr.ltr.feature.Feature}) and then will delegate the final score +generation to the {@link org.apache.solr.ltr.model.LTRScoringModel}, by calling the method +{@link org.apache.solr.ltr.model.LTRScoringModel#score(float[] modelFeatureValuesNormalized)}. +

+

+A {@link org.apache.solr.ltr.feature.Feature} will produce a particular value for each document, so +it is modeled as a {@link org.apache.lucene.search.Query}. The package +{@link org.apache.solr.ltr.feature} contains several examples +of features. One benefit of extending the Query object is that we can reuse +Query as a feature, see for example {@link org.apache.solr.ltr.feature.SolrFeature}. +Features for a document can also be returned in the response by +using the FeatureTransformer (a {@link org.apache.solr.response.transform.DocTransformer DocTransformer}) +provided by {@link org.apache.solr.response.transform.LTRFeatureLoggerTransformerFactory}. +

+

+{@link org.apache.solr.ltr.store} contains all the logic to store all the features and the models. +Models are registered into a unique {@link org.apache.solr.ltr.store.ModelStore ModelStore}, +and each model specifies a particular {@link org.apache.solr.ltr.store.FeatureStore FeatureStore} that +will contain a particular subset of features. +

+

+Features and models can be managed through a REST API, provided by the +{@link org.apache.solr.rest.ManagedResource Managed Resources} +{@link org.apache.solr.ltr.store.rest.ManagedFeatureStore ManagedFeatureStore} +and {@link org.apache.solr.ltr.store.rest.ManagedModelStore ManagedModelStore}. +

+ + diff --git a/solr/contrib/ltr/src/test-files/featureExamples/comp_features.json b/solr/contrib/ltr/src/test-files/featureExamples/comp_features.json new file mode 100644 index 00000000000..8d757395186 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/comp_features.json @@ -0,0 +1,37 @@ +[ +{ "name":"origScore", + "class":"org.apache.solr.ltr.feature.OriginalScoreFeature", + "params":{}, + "store": "feature-store-6" +}, +{ + "name": "descriptionTermFreq", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": { "q" : "{!func}termfreq(description,${user_text})" }, + "store": "feature-store-6" +}, +{ + "name": "popularity", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": { "q" : "{!func}normHits"}, + "store": "feature-store-6" +}, +{ + "name": "isPopular", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": {"fq" : ["{!field f=popularity}201"] }, + "store": "feature-store-6" +}, +{ + "name": "queryPartialMatch2", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": {"q": "{!dismax qf=description mm=2}${user_text}" }, + "store": "feature-store-6" +}, +{ + "name": "queryPartialMatch2.1", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": {"q": "{!dismax qf=description mm=2}${user_text}" }, + "store": "feature-store-6" +} +] \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/featureExamples/external_features.json b/solr/contrib/ltr/src/test-files/featureExamples/external_features.json new file mode 100644 index 00000000000..6c0cfa63452 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/external_features.json @@ -0,0 +1,51 @@ +[ { + "name" : "matchedTitle", + "class" : "org.apache.solr.ltr.feature.SolrFeature", + "params" : { + "q" : "{!terms f=title}${user_query}" + } +}, { + "name" : "confidence", + "class" : "org.apache.solr.ltr.feature.ValueFeature", + "store": "fstore2", + "params" : { + "value" : "${myconf}" + } +}, { + 
"name":"originalScore", + "class":"org.apache.solr.ltr.feature.OriginalScoreFeature", + "store": "fstore2", + "params":{} +}, { + "name" : "occurrences", + "class" : "org.apache.solr.ltr.feature.ValueFeature", + "store": "fstore3", + "params" : { + "value" : "${myOcc}", + "required" : false + } +}, { + "name":"originalScore", + "class":"org.apache.solr.ltr.feature.OriginalScoreFeature", + "store": "fstore3", + "params":{} +}, { + "name" : "popularity", + "class" : "org.apache.solr.ltr.feature.ValueFeature", + "store": "fstore4", + "params" : { + "value" : "${myPop}", + "required" : true + } +}, { + "name":"originalScore", + "class":"org.apache.solr.ltr.feature.OriginalScoreFeature", + "store": "fstore4", + "params":{} +}, { + "name" : "titlePhraseMatch", + "class" : "org.apache.solr.ltr.feature.SolrFeature", + "params" : { + "q" : "{!field f=title}${user_query}" + } +} ] diff --git a/solr/contrib/ltr/src/test-files/featureExamples/external_features_for_sparse_processing.json b/solr/contrib/ltr/src/test-files/featureExamples/external_features_for_sparse_processing.json new file mode 100644 index 00000000000..52bab275ab2 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/external_features_for_sparse_processing.json @@ -0,0 +1,18 @@ +[{ + "name" : "user_device_smartphone", + "class":"org.apache.solr.ltr.feature.ValueFeature", + "params" : { + "value": "${user_device_smartphone}" + } +}, + { + "name" : "user_device_tablet", + "class":"org.apache.solr.ltr.feature.ValueFeature", + "params" : { + "value": "${user_device_tablet}" + } + } + + + +] diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features-linear-efi.json b/solr/contrib/ltr/src/test-files/featureExamples/features-linear-efi.json new file mode 100644 index 00000000000..e05542ac773 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features-linear-efi.json @@ -0,0 +1,17 @@ +[ + { + "name": "sampleConstant", + "class": "org.apache.solr.ltr.feature.ValueFeature", + 
"params": { + "value": 5 + } + }, + { + "name" : "search_number_of_nights", + "class":"org.apache.solr.ltr.feature.ValueFeature", + "params" : { + "value": "${search_number_of_nights}" + } + } + +] diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features-linear.json b/solr/contrib/ltr/src/test-files/featureExamples/features-linear.json new file mode 100644 index 00000000000..8cc29969aa3 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features-linear.json @@ -0,0 +1,51 @@ +[ + { + "name": "title", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 1 + } + }, + { + "name": "description", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 2 + } + }, + { + "name": "keywords", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 2 + } + }, + { + "name": "popularity", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 3 + } + }, + { + "name": "text", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 4 + } + }, + { + "name": "queryIntentPerson", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 5 + } + }, + { + "name": "queryIntentCompany", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 5 + } + } +] diff --git a/solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json b/solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json new file mode 100644 index 00000000000..69aad84aca2 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/features-store-test-model.json @@ -0,0 +1,51 @@ +[ + { + "name": "constant1", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "store":"test", + "params": { + "value": 1 + } + }, + { + "name": "constant2", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "store":"test", + "params": { + "value": 2 + } + }, + { + "name": "constant3", + 
"class": "org.apache.solr.ltr.feature.ValueFeature", + "store":"test", + "params": { + "value": 3 + } + }, + { + "name": "constant4", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "store":"test", + "params": { + "value": 4 + } + }, + { + "name": "constant5", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "store":"test", + "params": { + "value": 5 + } + }, + { + "name": "pop", + "class": "org.apache.solr.ltr.feature.FieldValueFeature", + "store":"test", + "params": { + "field": "popularity" + } + } + +] diff --git a/solr/contrib/ltr/src/test-files/featureExamples/fq_features.json b/solr/contrib/ltr/src/test-files/featureExamples/fq_features.json new file mode 100644 index 00000000000..13968f9f474 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/fq_features.json @@ -0,0 +1,16 @@ +[ + { + "name": "matchedTitle", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": { + "q": "{!terms f=title}${user_query}" + } + }, + { + "name": "popularity", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 3 + } + } +] diff --git a/solr/contrib/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features.json b/solr/contrib/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features.json new file mode 100644 index 00000000000..92f38616228 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features.json @@ -0,0 +1,16 @@ +[ + { + "name": "matchedTitle", + "class": "org.apache.solr.ltr.feature.SolrFeature", + "params": { + "q": "{!terms f=title}${user_query}" + } + }, + { + "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs", + "class": "org.apache.solr.ltr.feature.ValueFeature", + "params": { + "value": 1 + } + } +] diff --git a/solr/contrib/ltr/src/test-files/log4j.properties b/solr/contrib/ltr/src/test-files/log4j.properties new file mode 100644 index 00000000000..d86c6988d5e --- /dev/null +++ 
b/solr/contrib/ltr/src/test-files/log4j.properties @@ -0,0 +1,32 @@ +# Logging level +log4j.rootLogger=INFO, CONSOLE + +log4j.appender.CONSOLE=org.apache.log4j.ConsoleAppender +log4j.appender.CONSOLE.Target=System.err +log4j.appender.CONSOLE.layout=org.apache.log4j.EnhancedPatternLayout +log4j.appender.CONSOLE.layout.ConversionPattern=%-4r %-5p (%t) [%X{node_name} %X{collection} %X{shard} %X{replica} %X{core}] %c{1.} %m%n +log4j.logger.org.apache.zookeeper=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.apache.directory=WARN +log4j.logger.org.apache.solr.hadoop=INFO +log4j.logger.org.apache.solr.client.solrj.embedded.JettySolrRunner=DEBUG +org.apache.solr.client.solrj.embedded.JettySolrRunner=DEBUG + +#log4j.logger.org.apache.solr.update.processor.LogUpdateProcessor=DEBUG +#log4j.logger.org.apache.solr.update.processor.DistributedUpdateProcessor=DEBUG +#log4j.logger.org.apache.solr.update.PeerSync=DEBUG +#log4j.logger.org.apache.solr.core.CoreContainer=DEBUG +#log4j.logger.org.apache.solr.cloud.RecoveryStrategy=DEBUG +#log4j.logger.org.apache.solr.cloud.SyncStrategy=DEBUG +#log4j.logger.org.apache.solr.handler.admin.CoreAdminHandler=DEBUG +#log4j.logger.org.apache.solr.cloud.ZkController=DEBUG +#log4j.logger.org.apache.solr.update.DefaultSolrCoreState=DEBUG +#log4j.logger.org.apache.solr.common.cloud.ConnectionManager=DEBUG +#log4j.logger.org.apache.solr.update.UpdateLog=DEBUG +#log4j.logger.org.apache.solr.cloud.ChaosMonkey=DEBUG +#log4j.logger.org.apache.solr.update.TransactionLog=DEBUG +#log4j.logger.org.apache.solr.handler.ReplicationHandler=DEBUG +#log4j.logger.org.apache.solr.handler.IndexFetcher=DEBUG + +#log4j.logger.org.apache.solr.common.cloud.ClusterStateUtil=DEBUG +#log4j.logger.org.apache.solr.cloud.OverseerAutoReplicaFailoverThread=DEBUG diff --git a/solr/contrib/ltr/src/test-files/modelExamples/external_model.json b/solr/contrib/ltr/src/test-files/modelExamples/external_model.json new file mode 100644 index 00000000000..04ab22968b0 --- 
/dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/external_model.json @@ -0,0 +1,12 @@ +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"externalmodel", + "features":[ + { "name": "matchedTitle"} + ], + "params":{ + "weights": { + "matchedTitle": 0.999 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/external_model_store.json b/solr/contrib/ltr/src/test-files/modelExamples/external_model_store.json new file mode 100644 index 00000000000..f8e664855d3 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/external_model_store.json @@ -0,0 +1,13 @@ +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"externalmodelstore", + "store": "fstore2", + "features":[ + { "name": "confidence"} + ], + "params":{ + "weights": { + "confidence": 0.999 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/fq-model.json b/solr/contrib/ltr/src/test-files/modelExamples/fq-model.json new file mode 100644 index 00000000000..b5d631fdefa --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/fq-model.json @@ -0,0 +1,20 @@ +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"fqmodel", + "features":[ + { + "name":"matchedTitle", + "norm": { + "class":"org.apache.solr.ltr.norm.MinMaxNormalizer", + "params":{ "min":"0.0f", "max":"10.0f" } + } + }, + { "name":"popularity"} + ], + "params":{ + "weights": { + "matchedTitle": 0.5, + "popularity": 0.5 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/linear-model-efi.json b/solr/contrib/ltr/src/test-files/modelExamples/linear-model-efi.json new file mode 100644 index 00000000000..018466e0b2e --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/linear-model-efi.json @@ -0,0 +1,14 @@ +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"linear-efi", + "features":[ + {"name":"sampleConstant"}, + {"name":"search_number_of_nights"} + ], + "params":{ + "weights":{ + "sampleConstant":1.0, + 
"search_number_of_nights":2.0 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/linear-model.json b/solr/contrib/ltr/src/test-files/modelExamples/linear-model.json new file mode 100644 index 00000000000..6b46dca1ae6 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/linear-model.json @@ -0,0 +1,30 @@ +{ + "class":"org.apache.solr.ltr.model.LinearModel", + "name":"6029760550880411648", + "features":[ + {"name":"title"}, + {"name":"description"}, + {"name":"keywords"}, + { + "name":"popularity", + "norm": { + "class":"org.apache.solr.ltr.norm.MinMaxNormalizer", + "params":{ "min":"0.0f", "max":"10.0f" } + } + }, + {"name":"text"}, + {"name":"queryIntentPerson"}, + {"name":"queryIntentCompany"} + ], + "params":{ + "weights": { + "title": 0.0000000000, + "description": 0.1000000000, + "keywords": 0.2000000000, + "popularity": 0.3000000000, + "text": 0.4000000000, + "queryIntentPerson":0.1231231, + "queryIntentCompany":0.12121211 + } + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel.json new file mode 100644 index 00000000000..37551a07dca --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel.json @@ -0,0 +1,38 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : "1f", + "root": { + "feature": "matchedTitle", + "threshold": "0.5f", + "left" : { + "value" : "-100" + }, + "right": { + "feature" : "this_feature_doesnt_exist", + "threshold": "10.0f", + "left" : { + "value" : "50" + }, + "right" : { + "value" : "75" + } + } + } + }, + { + "weight" : "2f", + "root": { + "value" : "-10" + } + } + ] + } +} diff --git 
a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_external_binary_features.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_external_binary_features.json new file mode 100644 index 00000000000..cb8996ea562 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_external_binary_features.json @@ -0,0 +1,38 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"external_model_binary_feature", + "features":[ + { "name": "user_device_smartphone"}, + { "name": "user_device_tablet"} + ], + "params":{ + "trees": [ + { + "weight" : "1f", + "root": { + "feature": "user_device_smartphone", + "threshold": "0.5f", + "left" : { + "value" : "0" + }, + "right" : { + "value" : "50" + } + + }}, + { + "weight" : "1f", + "root": { + "feature": "user_device_tablet", + "threshold": "0.5f", + "left" : { + "value" : "0" + }, + "right" : { + "value" : "65" + } + + }} + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_feature.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_feature.json new file mode 100644 index 00000000000..2919f078652 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_feature.json @@ -0,0 +1,24 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_feature", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : "1f", + "root": { + "threshold": "0.5f", + "left" : { + "value" : "-100" + }, + "right": { + "value" : "75" + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_features.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_features.json new file mode 100644 index 
00000000000..ec4c37f0ee4 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_features.json @@ -0,0 +1,14 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_features", + "params":{ + "trees": [ + { + "weight" : "2f", + "root": { + "value" : "-10" + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_left.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_left.json new file mode 100644 index 00000000000..653d2fff8b3 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_left.json @@ -0,0 +1,22 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_left", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : "1f", + "root": { + "feature": "matchedTitle", + "threshold": "0.5f", + "right": { + "value" : "75" + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_params.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_params.json new file mode 100644 index 00000000000..4d50c4e1711 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_params.json @@ -0,0 +1,8 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_params", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ] +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_right.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_right.json new file mode 100644 index 00000000000..acd2d83c546 --- /dev/null +++ 
b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_right.json @@ -0,0 +1,22 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_right", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : "1f", + "root": { + "feature": "matchedTitle", + "threshold": "0.5f", + "left" : { + "value" : "-100" + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_threshold.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_threshold.json new file mode 100644 index 00000000000..d0fc3816ee0 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_threshold.json @@ -0,0 +1,24 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_threshold", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : "1f", + "root": { + "feature": "matchedTitle", + "left" : { + "value" : "-100" + }, + "right": { + "value" : "75" + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_tree.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_tree.json new file mode 100644 index 00000000000..507def321cd --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_tree.json @@ -0,0 +1,15 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_tree", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "weight" : "2f" + } + ] + } +} diff --git 
a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_trees.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_trees.json new file mode 100644 index 00000000000..bb360dd548c --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_trees.json @@ -0,0 +1,10 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_trees", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + } +} diff --git a/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_weight.json b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_weight.json new file mode 100644 index 00000000000..9048e6ce145 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_no_weight.json @@ -0,0 +1,24 @@ +{ + "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel", + "name":"multipleadditivetreesmodel_no_weight", + "features":[ + { "name": "matchedTitle"}, + { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"} + ], + "params":{ + "trees": [ + { + "root": { + "feature": "matchedTitle", + "threshold": "0.5f", + "left" : { + "value" : "-100" + }, + "right": { + "value" : "75" + } + } + } + ] + } +} diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml new file mode 100644 index 00000000000..15cf140cc09 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml @@ -0,0 +1,88 @@ + + + + + + + + + + + + + + + + + + + id + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml new file 
mode 100644 index 00000000000..1a184718379 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -0,0 +1,65 @@ + + + + + 6.0.0 + ${solr.data.dir:} + + + + + + + + + + + + + + + + + + + 15000 + false + + + 1000 + + + ${solr.data.dir:} + + + + + + + + explicit + json + true + id + + + + diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml new file mode 100644 index 00000000000..fd0940ae272 --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr_Th10_10.xml @@ -0,0 +1,69 @@ + + + + + 6.0.0 + ${solr.data.dir:} + + + + + + + + 10 + 10 + + + + + + + + + + + + + + + 15000 + false + + + 1000 + + + ${solr.data.dir:} + + + + + + + + explicit + json + true + id + + + + diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml new file mode 100644 index 00000000000..a36c1df36cb --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-multiseg.xml @@ -0,0 +1,62 @@ + + + + + 6.0.0 + ${solr.data.dir:} + + + + + + + + + 1 + + 10 + 1000 + + + + + + + 15000 + false + + + 1000 + + + ${solr.data.dir:} + + + + + + + + explicit + json + true + id + + + + \ No newline at end of file diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt new file mode 100644 index 00000000000..eabae3b7c0d --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/stopwords.txt @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +a diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt b/solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt new file mode 100644 index 00000000000..0ef0e8daaba --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/synonyms.txt @@ -0,0 +1,28 @@ +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#----------------------------------------------------------------------- +#some test synonym mappings unlikely to appear in real input text +aaafoo => aaabar +bbbfoo => bbbfoo bbbbar +cccfoo => cccbar cccbaz +fooaaa,baraaa,bazaaa + +# Some synonym groups specific to this example +GB,gib,gigabyte,gigabytes +MB,mib,megabyte,megabytes +Television, Televisions, TV, TVs +#notice we use "gib" instead of "GiB" so any WordDelimiterFilter coming +#after us won't split it into two words. 
+ +# Synonym mappings can be used for spelling correction too +pixima => pixma diff --git a/solr/contrib/ltr/src/test-files/solr/solr.xml b/solr/contrib/ltr/src/test-files/solr/solr.xml new file mode 100644 index 00000000000..c8c3ebeb30a --- /dev/null +++ b/solr/contrib/ltr/src/test-files/solr/solr.xml @@ -0,0 +1,42 @@ + + + + + + ${shareSchema:false} + ${configSetBaseDir:configsets} + ${coreRootDirectory:.} + + + ${urlScheme:} + ${socketTimeout:90000} + ${connTimeout:15000} + + + + 127.0.0.1 + ${hostPort:8983} + ${hostContext:solr} + ${solr.zkclienttimeout:30000} + ${genericCoreNodeNames:true} + ${leaderVoteWait:10000} + ${distribUpdateConnTimeout:45000} + ${distribUpdateSoTimeout:340000} + + + diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java new file mode 100644 index 00000000000..2e01a644f15 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java @@ -0,0 +1,211 @@ +/* * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr; + +import java.io.File; +import java.util.SortedMap; + +import org.apache.commons.io.FileUtils; +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.client.solrj.embedded.JettyConfig; +import org.apache.solr.client.solrj.request.CollectionAdminRequest; +import org.apache.solr.client.solrj.response.CollectionAdminResponse; +import org.apache.solr.client.solrj.response.QueryResponse; +import org.apache.solr.cloud.AbstractDistribZkTestBase; +import org.apache.solr.cloud.MiniSolrCloudCluster; +import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.cloud.ZkStateReader; +import org.apache.solr.ltr.feature.SolrFeature; +import org.apache.solr.ltr.feature.ValueFeature; +import org.apache.solr.ltr.model.LinearModel; +import org.eclipse.jetty.servlet.ServletHolder; +import org.junit.AfterClass; +import org.junit.Test; + +public class TestLTROnSolrCloud extends TestRerankBase { + + private MiniSolrCloudCluster solrCluster; + String solrconfig = "solrconfig-ltr.xml"; + String schema = "schema.xml"; + + SortedMap extraServlets = null; + + @Override + public void setUp() throws Exception { + super.setUp(); + extraServlets = setupTestInit(solrconfig, schema, true); + System.setProperty("enable.update.log", "true"); + + int numberOfShards = random().nextInt(4)+1; + int numberOfReplicas = random().nextInt(2)+1; + int maxShardsPerNode = numberOfShards+random().nextInt(4)+1; + + int numberOfNodes = numberOfShards * maxShardsPerNode; + + setupSolrCluster(numberOfShards, numberOfReplicas, numberOfNodes, maxShardsPerNode); + + + } + + + @Override + public void tearDown() throws Exception { + restTestHarness.close(); + restTestHarness = null; + jetty.stop(); + jetty = null; + solrCluster.shutdown(); + super.tearDown(); + } + + @Test + public void testSimpleQuery() throws Exception { + // will randomly pick a configuration with [1..5] shards and [1..3] replicas + + // Test regular query, it will sort the 
documents by inverse + // popularity (the less popular, docid == 1, will be in the first + // position + SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))"); + + query.setRequestHandler("/query"); + query.setFields("*,score"); + query.setParam("rows", "8"); + + QueryResponse queryResponse = + solrCluster.getSolrClient().query(COLLECTION,query); + assertEquals(8, queryResponse.getResults().getNumFound()); + assertEquals("1", queryResponse.getResults().get(0).get("id").toString()); + assertEquals("2", queryResponse.getResults().get(1).get("id").toString()); + assertEquals("3", queryResponse.getResults().get(2).get("id").toString()); + assertEquals("4", queryResponse.getResults().get(3).get("id").toString()); + + // Test re-rank and feature vectors returned + query.setFields("*,score,features:[fv]"); + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=8}"); + queryResponse = + solrCluster.getSolrClient().query(COLLECTION,query); + assertEquals(8, queryResponse.getResults().getNumFound()); + assertEquals("8", queryResponse.getResults().get(0).get("id").toString()); + assertEquals("powpularityS:64.0;c3:2.0", + queryResponse.getResults().get(0).get("features").toString()); + assertEquals("7", queryResponse.getResults().get(1).get("id").toString()); + assertEquals("powpularityS:49.0;c3:2.0", + queryResponse.getResults().get(1).get("features").toString()); + assertEquals("6", queryResponse.getResults().get(2).get("id").toString()); + assertEquals("powpularityS:36.0;c3:2.0", + queryResponse.getResults().get(2).get("features").toString()); + assertEquals("5", queryResponse.getResults().get(3).get("id").toString()); + assertEquals("powpularityS:25.0;c3:2.0", + queryResponse.getResults().get(3).get("features").toString()); + } + + private void setupSolrCluster(int numShards, int numReplicas, int numServers, int maxShardsPerNode) throws Exception { + JettyConfig jc = buildJettyConfig("/solr"); + jc = 
JettyConfig.builder(jc).withServlets(extraServlets).build(); + solrCluster = new MiniSolrCloudCluster(numServers, tmpSolrHome.toPath(), jc); + File configDir = tmpSolrHome.toPath().resolve("collection1/conf").toFile(); + solrCluster.uploadConfigSet(configDir.toPath(), "conf1"); + + solrCluster.getSolrClient().setDefaultCollection(COLLECTION); + + createCollection(COLLECTION, "conf1", numShards, numReplicas, maxShardsPerNode); + indexDocuments(COLLECTION); + + createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema, + "/solr", true, extraServlets); + loadModelsAndFeatures(); + } + + + private void createCollection(String name, String config, int numShards, int numReplicas, int maxShardsPerNode) + throws Exception { + CollectionAdminResponse response; + CollectionAdminRequest.Create create = + CollectionAdminRequest.createCollection(name, config, numShards, numReplicas); + create.setMaxShardsPerNode(maxShardsPerNode); + response = create.process(solrCluster.getSolrClient()); + + if (response.getStatus() != 0 || response.getErrorMessages() != null) { + fail("Could not create collection. 
Response" + response.toString()); + } + ZkStateReader zkStateReader = solrCluster.getSolrClient().getZkStateReader(); + AbstractDistribZkTestBase.waitForRecoveriesToFinish(name, zkStateReader, false, true, 100); + } + + + void indexDocument(String collection, String id, String title, String description, int popularity) + throws Exception{ + SolrInputDocument doc = new SolrInputDocument(); + doc.setField("id", id); + doc.setField("title", title); + doc.setField("description", description); + doc.setField("popularity", popularity); + solrCluster.getSolrClient().add(collection, doc); + } + + private void indexDocuments(final String collection) + throws Exception { + final int collectionSize = 8; + for (int docId = 1; docId <= collectionSize; docId++) { + final int popularity = docId; + indexDocument(collection, String.valueOf(docId), "a1", "bloom", popularity); + } + solrCluster.getSolrClient().commit(collection); + } + + + private void loadModelsAndFeatures() throws Exception { + final String featureStore = "test"; + final String[] featureNames = new String[] {"powpularityS","c3"}; + final String jsonModelParams = "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0}}"; + + loadFeature( + featureNames[0], + SolrFeature.class.getCanonicalName(), + featureStore, + "{\"q\":\"{!func}pow(popularity,2)\"}" + ); + loadFeature( + featureNames[1], + ValueFeature.class.getCanonicalName(), + featureStore, + "{\"value\":2}" + ); + + loadModel( + "powpularityS-model", + LinearModel.class.getCanonicalName(), + featureNames, + featureStore, + jsonModelParams + ); + reloadCollection(COLLECTION); + } + + private void reloadCollection(String collection) throws Exception { + CollectionAdminRequest.Reload reloadRequest = CollectionAdminRequest.reloadCollection(collection); + CollectionAdminResponse response = reloadRequest.process(solrCluster.getSolrClient()); + assertEquals(0, response.getStatus()); + assertTrue(response.isSuccess()); + } + + @AfterClass + public static void after() throws 
Exception { + FileUtils.deleteDirectory(tmpSolrHome); + System.clearProperty("managed.schema.mutable"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java new file mode 100644 index 00000000000..2f90df841f9 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestLTRQParserExplain extends TestRerankBase { + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + loadFeatures("features-store-test-model.json"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + + @Test + public void testRerankedExplain() throws Exception { + loadModel("linear2", LinearModel.class.getCanonicalName(), new String[] { + "constant1", "constant2", "pop"}, + "{\"weights\":{\"pop\":1.0,\"constant1\":1.5,\"constant2\":3.5}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "2"); + query.add("rq", "{!ltr reRankDocs=2 model=linear2}"); + query.add("fl", "*,score"); + + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/9=='\n13.5 = LinearModel(name=linear2,featureWeights=[constant1=1.5,constant2=3.5,pop=1.0]) model applied to features, sum of:\n 1.5 = prod of:\n 1.5 = weight on feature\n 1.0 = ValueFeature [name=constant1, params={value=1}]\n 7.0 = prod of:\n 3.5 = weight on feature\n 2.0 = ValueFeature [name=constant2, params={value=2}]\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = FieldValueFeature [name=pop, params={field=popularity}]\n'"); + } + + @Test + public void testRerankedExplainSameBetweenDifferentDocsWithSameFeatures() throws Exception { + loadFeatures("features-linear.json"); + loadModels("linear-model.json"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "4"); + query.add("rq", "{!ltr reRankDocs=4 model=6029760550880411648}"); + query.add("fl", "*,score"); + query.add("wt", "json"); + final String expectedExplainNormalizer = "normalized using 
MinMaxNormalizer(min=0.0,max=10.0)"; + final String expectedExplain = "\n3.5116758 = LinearModel(name=6029760550880411648,featureWeights=[" + + "title=0.0," + + "description=0.1," + + "keywords=0.2," + + "popularity=0.3," + + "text=0.4," + + "queryIntentPerson=0.1231231," + + "queryIntentCompany=0.12121211" + + "]) model applied to features, sum of:\n 0.0 = prod of:\n 0.0 = weight on feature\n 1.0 = ValueFeature [name=title, params={value=1}]\n 0.2 = prod of:\n 0.1 = weight on feature\n 2.0 = ValueFeature [name=description, params={value=2}]\n 0.4 = prod of:\n 0.2 = weight on feature\n 2.0 = ValueFeature [name=keywords, params={value=2}]\n 0.09 = prod of:\n 0.3 = weight on feature\n 0.3 = "+expectedExplainNormalizer+"\n 3.0 = ValueFeature [name=popularity, params={value=3}]\n 1.6 = prod of:\n 0.4 = weight on feature\n 4.0 = ValueFeature [name=text, params={value=4}]\n 0.6156155 = prod of:\n 0.1231231 = weight on feature\n 5.0 = ValueFeature [name=queryIntentPerson, params={value=5}]\n 0.60606056 = prod of:\n 0.12121211 = weight on feature\n 5.0 = ValueFeature [name=queryIntentCompany, params={value=5}]\n"; + + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/7=='"+expectedExplain+"'}"); + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/9=='"+expectedExplain+"'}"); + } + + @Test + public void LinearScoreExplainMissingEfiFeatureShouldReturnDefaultScore() throws Exception { + loadFeatures("features-linear-efi.json"); + loadModels("linear-model-efi.json"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "4"); + query.add("rq", "{!ltr reRankDocs=4 model=linear-efi}"); + query.add("fl", "*,score"); + query.add("wt", "xml"); + + final String linearModelEfiString = "LinearModel(name=linear-efi,featureWeights=[" + + "sampleConstant=1.0," + + "search_number_of_nights=2.0])"; + + query.remove("wt"); + query.add("wt", "json"); + assertJQ( + "/query" + 
query.toQueryString(), + "/debug/explain/7=='\n5.0 = "+linearModelEfiString+" model applied to features, sum of:\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = ValueFeature [name=sampleConstant, params={value=5}]\n" + + " 0.0 = prod of:\n" + + " 2.0 = weight on feature\n" + + " 0.0 = The feature has no value\n'}"); + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/9=='\n5.0 = "+linearModelEfiString+" model applied to features, sum of:\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = ValueFeature [name=sampleConstant, params={value=5}]\n" + + " 0.0 = prod of:\n" + + " 2.0 = weight on feature\n" + + " 0.0 = The feature has no value\n'}"); + } + + @Test + public void multipleAdditiveTreesScoreExplainMissingEfiFeatureShouldReturnDefaultScore() throws Exception { + loadFeatures("external_features_for_sparse_processing.json"); + loadModels("multipleadditivetreesmodel_external_binary_features.json"); + + SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "4"); + query.add("rq", "{!ltr reRankDocs=4 model=external_model_binary_feature efi.user_device_tablet=1}"); + query.add("fl", "*,score"); + + final String tree1 = "(weight=1.0,root=(feature=user_device_smartphone,threshold=0.5,left=0.0,right=50.0))"; + final String tree2 = "(weight=1.0,root=(feature=user_device_tablet,threshold=0.5,left=0.0,right=65.0))"; + final String trees = "["+tree1+","+tree2+"]"; + + query.add("wt", "json"); + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/7=='\n" + + "65.0 = MultipleAdditiveTreesModel(name=external_model_binary_feature,trees="+trees+") model applied to features, sum of:\n" + + " 0.0 = tree 0 | \\'user_device_smartphone\\':0.0 <= 0.500001, Go Left | val: 0.0\n" + + " 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}"); + assertJQ( + "/query" + query.toQueryString(), + "/debug/explain/9=='\n" + + "65.0 = 
MultipleAdditiveTreesModel(name=external_model_binary_feature,trees="+trees+") model applied to features, sum of:\n" + + " 0.0 = tree 0 | \\'user_device_smartphone\\':0.0 <= 0.500001, Go Left | val: 0.0\n" + + " 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java new file mode 100644 index 00000000000..f28ab0d9297 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr; + +import org.apache.solr.client.solrj.SolrQuery; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestLTRQParserPlugin extends TestRerankBase { + + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + // store = getModelStore(); + bulkIndex(); + + loadFeatures("features-linear.json"); + loadModels("linear-model.json"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + // store.clear(); + } + + @Test + public void ltrModelIdMissingTest() throws Exception { + final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr reRankDocs=100}"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + assert (res.contains("Must provide model in the request")); + } + + @Test + public void ltrModelIdDoesNotExistTest() throws Exception { + final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=-1 reRankDocs=100}"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + assert (res.contains("cannot find model")); + } + + @Test + public void ltrMoreResultsThanReRankedTest() throws Exception { + final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + + String nonRerankedScore = "0.09271725"; + + // Normal solr order + assertJQ("/query" + 
query.toQueryString(), + "/response/docs/[0]/id=='9'", + "/response/docs/[1]/id=='8'", + "/response/docs/[2]/id=='7'", + "/response/docs/[3]/id=='6'", + "/response/docs/[3]/score=="+nonRerankedScore + ); + + query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3}"); + + // Different order for top 3 reranked, but last one is the same top nonreranked doc + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/id=='7'", + "/response/docs/[1]/id=='8'", + "/response/docs/[2]/id=='9'", + "/response/docs/[3]/id=='6'", + "/response/docs/[3]/score=="+nonRerankedScore + ); + } + + @Test + public void ltrNoResultsTest() throws Exception { + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg23"); + query.add("fl", "*,[fv]"); + query.add("rows", "3"); + query.add("debugQuery", "on"); + query.add("rq", "{!ltr reRankDocs=3 model=6029760550880411648}"); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==0"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java new file mode 100644 index 00000000000..a98fc4f5e34 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.solr.core.SolrResourceLoader; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.feature.FieldValueFeature; +import org.apache.solr.ltr.model.LTRScoringModel; +import org.apache.solr.ltr.model.TestLinearModel; +import org.apache.solr.ltr.norm.IdentityNormalizer; +import org.apache.solr.ltr.norm.Normalizer; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class TestLTRReRankingPipeline extends LuceneTestCase { + + private static final Logger log 
= LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + private static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(); + + private IndexSearcher getSearcher(IndexReader r) { + final IndexSearcher searcher = newSearcher(r); + + return searcher; + } + + private static List makeFieldValueFeatures(int[] featureIds, + String field) { + final List features = new ArrayList<>(); + for (final int i : featureIds) { + final Map params = new HashMap(); + params.put("field", field); + final Feature f = Feature.getInstance(solrResourceLoader, + FieldValueFeature.class.getCanonicalName(), + "f" + i, params); + f.setIndex(i); + features.add(f); + } + return features; + } + + private class MockModel extends LTRScoringModel { + + public MockModel(String name, List features, + List norms, + String featureStoreName, List allFeatures, + Map params) { + super(name, features, norms, featureStoreName, allFeatures, params); + } + + @Override + public float score(float[] modelFeatureValuesNormalized) { + return modelFeatureValuesNormalized[2]; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc, + float finalScore, List featureExplanations) { + return null; + } + + } + + @Ignore + @Test + public void testRescorer() throws IOException { + final Directory dir = newDirectory(); + final RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "0", Field.Store.YES)); + doc.add(newTextField("field", "wizard the the the the the oz", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 1.0f)); + + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + // 1 extra token, but wizard and oz are close; + doc.add(newTextField("field", "wizard oz the the the the the the", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 2.0f)); + w.addDocument(doc); + + final IndexReader r = w.getReader(); + 
w.close(); + + // Do ordinary BooleanQuery: + final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder(); + bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD); + bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD); + final IndexSearcher searcher = getSearcher(r); + // first run the standard query + TopDocs hits = searcher.search(bqBuilder.build(), 10); + assertEquals(2, hits.totalHits); + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); + + final List features = makeFieldValueFeatures(new int[] {0, 1, 2}, + "final-score"); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + final List allFeatures = makeFieldValueFeatures(new int[] {0, 1, + 2, 3, 4, 5, 6, 7, 8, 9}, "final-score"); + final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test", + features, norms, "test", allFeatures, null); + + final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel)); + hits = rescorer.rescore(searcher, hits, 2); + + // rerank using the field final-score + assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id")); + + r.close(); + dir.close(); + + } + + @Ignore + @Test + public void testDifferentTopN() throws IOException { + final Directory dir = newDirectory(); + final RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "0", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 1.0f)); + w.addDocument(doc); + + doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO)); + doc.add(new 
FloatDocValuesField("final-score", 2.0f)); + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "2", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 3.0f)); + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "3", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz oz the the the the ", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 4.0f)); + w.addDocument(doc); + doc = new Document(); + doc.add(newStringField("id", "4", Field.Store.YES)); + doc.add(newTextField("field", "wizard oz the the the the the the", + Field.Store.NO)); + doc.add(new FloatDocValuesField("final-score", 5.0f)); + w.addDocument(doc); + + final IndexReader r = w.getReader(); + w.close(); + + // Do ordinary BooleanQuery: + final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder(); + bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD); + bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD); + final IndexSearcher searcher = getSearcher(r); + + // first run the standard query + TopDocs hits = searcher.search(bqBuilder.build(), 10); + assertEquals(5, hits.totalHits); + + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); + assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id")); + assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id")); + assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id")); + + final List features = makeFieldValueFeatures(new int[] {0, 1, 2}, + "final-score"); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + final List allFeatures = makeFieldValueFeatures(new int[] {0, 1, + 2, 3, 4, 5, 6, 7, 8, 9}, "final-score"); + final LTRScoringModel ltrScoringModel = 
TestLinearModel.createLinearModel("test", + features, norms, "test", allFeatures, null); + + final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel)); + + // rerank @ 0 should not change the order + hits = rescorer.rescore(searcher, hits, 0); + assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id")); + assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id")); + assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id")); + assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id")); + + // test rerank with different topN cuts + + for (int topN = 1; topN <= 5; topN++) { + log.info("rerank {} documents ", topN); + hits = searcher.search(bqBuilder.build(), 10); + + final ScoreDoc[] slice = new ScoreDoc[topN]; + System.arraycopy(hits.scoreDocs, 0, slice, 0, topN); + hits = new TopDocs(hits.totalHits, slice, hits.getMaxScore()); + hits = rescorer.rescore(searcher, hits, topN); + for (int i = topN - 1, j = 0; i >= 0; i--, j++) { + log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc) + .get("id"), j); + + assertEquals(i, + Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id"))); + assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001); + + } + } + + r.close(); + dir.close(); + + } + + @Test + public void testDocParam() throws Exception { + final Map test = new HashMap(); + test.put("fake", 2); + List features = makeFieldValueFeatures(new int[] {0}, + "final-score"); + List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + List allFeatures = makeFieldValueFeatures(new int[] {0}, + "final-score"); + MockModel ltrScoringModel = new MockModel("test", + features, norms, "test", allFeatures, null); + LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel); + LTRScoringQuery.ModelWeight wgt = query.createWeight(null, true, 1f); + LTRScoringQuery.ModelWeight.ModelScorer modelScr = 
wgt.scorer(null); + modelScr.getDocInfo().setOriginalDocScore(new Float(1f)); + for (final Scorer.ChildScorer feat : modelScr.getChildren()) { + assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore()); + } + + features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score"); + norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, + 9}, "final-score"); + ltrScoringModel = new MockModel("test", features, norms, + "test", allFeatures, null); + query = new LTRScoringQuery(ltrScoringModel); + wgt = query.createWeight(null, true, 1f); + modelScr = wgt.scorer(null); + modelScr.getDocInfo().setOriginalDocScore(new Float(1f)); + for (final Scorer.ChildScorer feat : modelScr.getChildren()) { + assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore()); + } + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java new file mode 100644 index 00000000000..0576c999771 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.FloatDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.ReaderUtil;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.model.ModelException;
+import org.apache.solr.ltr.model.TestLinearModel;
+import org.apache.solr.ltr.norm.IdentityNormalizer;
+import org.apache.solr.ltr.norm.Normalizer;
+import org.apache.solr.ltr.norm.NormalizerException;
+import org.junit.Test;
+
+/**
+ * Tests {@link LTRScoringQuery}: equals/hashCode across models and external
+ * feature info, and the feature values exposed by its
+ * {@link LTRScoringQuery.ModelWeight} after scoring a document.
+ */
+public class TestLTRScoringQuery extends LuceneTestCase {
+
+  public final static SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
+
+  /**
+   * Creates a searcher over {@code r} with both boolean options of
+   * {@code newSearcher} switched off.
+   */
+  private IndexSearcher getSearcher(IndexReader r) {
+    // NOTE(review): presumably this keeps the searcher from wrapping the
+    // Weight, which performQuery casts to LTRScoringQuery.ModelWeight - confirm
+    return newSearcher(r, false, false);
+  }
+
+  /**
+   * Builds one {@link ValueFeature} per entry of {@code featureIds}. The
+   * feature for id {@code i} is named {@code "f" + i}, gets the parameter
+   * {@code value=i} and is placed at index {@code i}.
+   */
+  private static List<Feature> makeFeatures(int[] featureIds) {
+    final List<Feature> features = new ArrayList<>();
+    for (final int i : featureIds) {
+      final Map<String,Object> params = new HashMap<>();
+      params.put("value", i);
+      final Feature f = Feature.getInstance(solrResourceLoader,
+          ValueFeature.class.getCanonicalName(),
+          "f" + i, params);
+      f.setIndex(i);
+      features.add(f);
+    }
+    return features;
+  }
+
+  /** Same features as {@link #makeFeatures(int[])}; kept as a separate name for the normalizer test. */
+  private static List<Feature> makeFilterFeatures(int[] featureIds) {
+    return makeFeatures(featureIds);
+  }
+
+  /** Returns linear model parameters that give every feature the weight 0.1. */
+  private static Map<String,Object> makeFeatureWeights(List<Feature> features) {
+    final Map<String,Double> modelWeights = new HashMap<>();
+    for (final Feature feat : features) {
+      modelWeights.put(feat.getName(), 0.1);
+    }
+    final Map<String,Object> nameParams = new HashMap<>();
+    nameParams.put("weights", modelWeights);
+    return nameParams;
+  }
+
+  /** Returns a mutable list holding {@code count} references to {@link IdentityNormalizer#INSTANCE}. */
+  private static List<Normalizer> identityNormalizers(int count) {
+    return new ArrayList<Normalizer>(
+        Collections.nCopies(count, IdentityNormalizer.INSTANCE));
+  }
+
+  /**
+   * Creates the weight for {@code model}, advances its scorer to
+   * {@code docid} and scores that document once.
+   *
+   * @param hits not read; kept so that existing call sites stay unchanged
+   * @param docid top-level (not per-segment) id of the document to score
+   * @return the created weight, asserted to be a {@link LTRScoringQuery.ModelWeight}
+   */
+  private LTRScoringQuery.ModelWeight performQuery(TopDocs hits,
+      IndexSearcher searcher, int docid, LTRScoringQuery model) throws IOException,
+      ModelException {
+    final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
+        .leaves();
+    // locate the segment holding docid and translate to a segment-local id
+    final int n = ReaderUtil.subIndex(docid, leafContexts);
+    final LeafReaderContext context = leafContexts.get(n);
+    final int deBasedDoc = docid - context.docBase;
+
+    final Weight weight = searcher.createNormalizedWeight(model, true);
+    final Scorer scorer = weight.scorer(context);
+
+    // score the document once; the callers read the feature values from
+    // the returned weight afterwards
+    scorer.iterator().advance(deBasedDoc);
+    scorer.score();
+
+    assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
+    return (LTRScoringQuery.ModelWeight) weight;
+  }
+
+  /**
+   * Queries are equal (with equal hash codes) when model and external
+   * feature info match, whatever the insertion order of the info and
+   * whether or not a thread module is supplied; a different model name,
+   * feature store name or missing external feature info makes them differ.
+   */
+  @Test
+  public void testLTRScoringQueryEquality() throws ModelException {
+    final List<Feature> features = makeFeatures(new int[] {0, 1, 2});
+    final List<Normalizer> norms = identityNormalizers(features.size());
+    final List<Feature> allFeatures = makeFeatures(
+        new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+    final Map<String,Object> modelParams = makeFeatureWeights(features);
+
+    final LTRScoringModel algorithm1 = TestLinearModel.createLinearModel(
+        "testModelName",
+        features, norms, "testStoreName", allFeatures, modelParams);
+
+    final LTRScoringQuery m0 = new LTRScoringQuery(algorithm1);
+
+    final Map<String,String[]> externalFeatureInfo = new HashMap<>();
+    externalFeatureInfo.put("queryIntent", new String[] {"company"});
+    externalFeatureInfo.put("user_query", new String[] {"abc"});
+    final LTRScoringQuery m1 = new LTRScoringQuery(algorithm1, externalFeatureInfo, false, null);
+
+    final Map<String,String[]> externalFeatureInfo2 = new HashMap<>();
+    externalFeatureInfo2.put("user_query", new String[] {"abc"});
+    externalFeatureInfo2.put("queryIntent", new String[] {"company"});
+    final int totalPoolThreads = 10, numThreadsPerRequest = 10;
+    final LTRThreadModule threadManager = new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);
+    final LTRScoringQuery m2 = new LTRScoringQuery(algorithm1, externalFeatureInfo2, false, threadManager);
+
+    // Models with same algorithm and efis, just in different order should be the same
+    assertEquals(m1, m2);
+    assertEquals(m1.hashCode(), m2.hashCode());
+
+    // Models with same algorithm, but different efi content should not match
+    assertFalse(m1.equals(m0));
+    assertFalse(m1.hashCode() == m0.hashCode());
+
+    // same features and weights, different model name
+    final LTRScoringModel algorithm2 = TestLinearModel.createLinearModel(
+        "testModelName2",
+        features, norms, "testStoreName", allFeatures, modelParams);
+    final LTRScoringQuery m3 = new LTRScoringQuery(algorithm2);
+
+    assertFalse(m1.equals(m3));
+    assertFalse(m1.hashCode() == m3.hashCode());
+
+    // same model name, different feature store name
+    final LTRScoringModel algorithm3 = TestLinearModel.createLinearModel(
+        "testModelName",
+        features, norms, "testStoreName3", allFeatures, modelParams);
+    final LTRScoringQuery m4 = new LTRScoringQuery(algorithm3);
+
+    assertFalse(m1.equals(m4));
+    assertFalse(m1.hashCode() == m4.hashCode());
+  }
+
+  /**
+   * Scores the top hit of a two-document index with several linear models
+   * and checks the feature values recorded by the model weight: features in
+   * index order, features in mixed order, a model without features (must be
+   * rejected) and a model whose normalizer maps every value to 42.42.
+   */
+  @Test
+  public void testLTRScoringQuery() throws IOException, ModelException {
+    final Directory dir = newDirectory();
+    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+
+    Document doc = new Document();
+    doc.add(newStringField("id", "0", Field.Store.YES));
+    doc.add(newTextField("field", "wizard the the the the the oz",
+        Field.Store.NO));
+    doc.add(new FloatDocValuesField("final-score", 1.0f));
+
+    w.addDocument(doc);
+    doc = new Document();
+    doc.add(newStringField("id", "1", Field.Store.YES));
+    // 1 extra token, but wizard and oz are close;
+    doc.add(newTextField("field", "wizard oz the the the the the the",
+        Field.Store.NO));
+    doc.add(new FloatDocValuesField("final-score", 2.0f));
+    w.addDocument(doc);
+
+    final IndexReader r = w.getReader();
+    w.close();
+
+    // Do ordinary BooleanQuery:
+    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+    final IndexSearcher searcher = getSearcher(r);
+    // first run the standard query
+    final TopDocs hits = searcher.search(bqBuilder.build(), 10);
+    assertEquals(2, hits.totalHits);
+    assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+    assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+    final int topDoc = hits.scoreDocs[0].doc;
+
+    List<Feature> features = makeFeatures(new int[] {0, 1, 2});
+    final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5,
+        6, 7, 8, 9});
+    List<Normalizer> norms = identityNormalizers(features.size());
+    LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+        features, norms, "test", allFeatures,
+        makeFeatureWeights(features));
+
+    LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher,
+        topDoc, new LTRScoringQuery(ltrScoringModel));
+    assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);
+
+    for (int i = 0; i < 3; i++) {
+      assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
+    }
+    final int[] posVals = new int[] {0, 1, 2};
+    int pos = 0;
+    for (final LTRScoringQuery.FeatureInfo fInfo : modelWeight.getFeaturesInfo()) {
+      if (fInfo == null) {
+        // slots of features that the model does not use are skipped
+        continue;
+      }
+      assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
+      assertEquals("f"+posVals[pos], fInfo.getName());
+      pos++;
+    }
+
+    // features listed in an order that differs from their store index
+    final int[] mixPositions = new int[] {8, 2, 4, 9, 0};
+    features = makeFeatures(mixPositions);
+    norms = identityNormalizers(features.size());
+    ltrScoringModel = TestLinearModel.createLinearModel("test",
+        features, norms, "test", allFeatures, makeFeatureWeights(features));
+
+    modelWeight = performQuery(hits, searcher, topDoc,
+        new LTRScoringQuery(ltrScoringModel));
+    assertEquals(mixPositions.length,
+        modelWeight.getModelFeatureWeights().length);
+
+    for (int i = 0; i < mixPositions.length; i++) {
+      assertEquals(mixPositions[i],
+          modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
+    }
+
+    // a model without any feature must be rejected at creation time
+    final ModelException expectedModelException = new ModelException("no features declared for model test");
+    final int[] noPositions = new int[] {};
+    features = makeFeatures(noPositions);
+    norms = identityNormalizers(features.size());
+    try {
+      ltrScoringModel = TestLinearModel.createLinearModel("test",
+          features, norms, "test", allFeatures, makeFeatureWeights(features));
+      fail("unexpectedly got here instead of catching "+expectedModelException);
+    } catch (ModelException actualModelException) {
+      assertEquals(expectedModelException.toString(), actualModelException.toString());
+    }
+
+    // test normalizers
+    features = makeFilterFeatures(mixPositions);
+    final Normalizer norm = new Normalizer() {
+
+      @Override
+      public float normalize(float value) {
+        return 42.42f;
+      }
+
+      @Override
+      public LinkedHashMap<String,Object> paramsToMap() {
+        return null;
+      }
+
+      @Override
+      protected void validate() throws NormalizerException {
+      }
+
+    };
+    norms = new ArrayList<Normalizer>(
+        Collections.nCopies(features.size(), norm));
+    final LTRScoringModel normMeta = TestLinearModel.createLinearModel("test",
+        features, norms, "test", allFeatures,
+        makeFeatureWeights(features));
+
+    modelWeight = performQuery(hits, searcher, topDoc,
+        new LTRScoringQuery(normMeta));
+    normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
+    assertEquals(mixPositions.length,
+        modelWeight.getModelFeatureWeights().length);
+    for (int i = 0; i < mixPositions.length; i++) {
+      assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
+    }
+    r.close();
+    dir.close();
+
+  }
+
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java
new file mode 100644
index 00000000000..ab519ec24ab
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.feature.SolrFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/** Checks that facet counts stay the same when results are reranked by an LTR model. */
+public class TestLTRWithFacet extends TestRerankBase {
+
+  /** Starts the test core and indexes eight documents (ids and popularities 1..8). */
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+
+    final String[] terms = {"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1"};
+    final String[] descriptions = {"E", "B", "B", "B", "E", "B", "C", "D"};
+    // the title of document N holds the first N terms, so all titles contain a1
+    final StringBuilder title = new StringBuilder();
+    for (int i = 0; i < terms.length; i++) {
+      title.append(i == 0 ? "" : " ").append(terms[i]);
+      final String ordinal = Integer.toString(i + 1);
+      assertU(adoc("id", ordinal, "title", title.toString(),
+          "description", descriptions[i], "popularity", ordinal));
+    }
+    assertU(commit());
+  }
+
+  /**
+   * Runs a faceted query first without and then with reranking by the
+   * square of popularity; the reranked top four come back as 4, 3, 2, 1
+   * while the facet counts on description are the same in both responses.
+   */
+  @Test
+  public void testRankingSolrFacet() throws Exception {
+    loadFeature("powpularityS", SolrFeature.class.getCanonicalName(),
+        "{\"q\":\"{!func}pow(popularity,2)\"}");
+
+    loadModel("powpularityS-model", LinearModel.class.getCanonicalName(),
+        new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
+
+    final String facetCounts = "/facet_counts/facet_fields/description=="
+        + "['b', 4, 'e', 2, 'c', 1, 'd', 1]";
+
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("title:a1");
+    query.add("fl", "*, score");
+    query.add("rows", "4");
+    query.add("facet", "true");
+    query.add("facet.field", "description");
+
+    // normal term match, before any reranking
+    final String plain = "/query" + query.toQueryString();
+    assertJQ(plain, "/response/numFound/==8");
+    assertJQ(plain, "/response/docs/[0]/id=='1'");
+    assertJQ(plain, "/response/docs/[1]/id=='2'");
+    assertJQ(plain, "/response/docs/[2]/id=='3'");
+    assertJQ(plain, "/response/docs/[3]/id=='4'");
+    assertJQ(plain, facetCounts);
+
+    query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}");
+    query.set("debugQuery", "on");
+
+    // reranked: the score of each document is its popularity squared
+    final String reranked = "/query" + query.toQueryString();
+    assertJQ(reranked, "/response/numFound/==8");
+    assertJQ(reranked, "/response/docs/[0]/id=='4'");
+    assertJQ(reranked, "/response/docs/[0]/score==16.0");
+    assertJQ(reranked, "/response/docs/[1]/id=='3'");
+    assertJQ(reranked, "/response/docs/[1]/score==9.0");
+    assertJQ(reranked, "/response/docs/[2]/id=='2'");
+    assertJQ(reranked, "/response/docs/[2]/score==4.0");
+    assertJQ(reranked, "/response/docs/[3]/id=='1'");
+    assertJQ(reranked, "/response/docs/[3]/score==1.0");
+    assertJQ(reranked, facetCounts);
+  }
+
+  /** Shuts down what {@link #before()} started. */
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java
new file mode 100644
index 00000000000..1fbe1d5fe58
--- /dev/null
+++
b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.feature.SolrFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/** Checks LTR reranking of a result page that was produced with an explicit sort. */
+public class TestLTRWithSort extends TestRerankBase {
+
+  /** Starts the test core and indexes eight documents (ids and popularities 1..8). */
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+    final String[] terms = {"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1"};
+    final String[] descriptions = {"E", "B", "B", "B", "E", "B", "C", "D"};
+    // the title of document N holds the first N terms, so all titles contain a1
+    final StringBuilder title = new StringBuilder();
+    for (int i = 0; i < terms.length; i++) {
+      title.append(i == 0 ? "" : " ").append(terms[i]);
+      final String ordinal = Integer.toString(i + 1);
+      assertU(adoc("id", ordinal, "title", title.toString(),
+          "description", descriptions[i], "popularity", ordinal));
+    }
+    assertU(commit());
+  }
+
+  /**
+   * Queries unsorted, then sorted by description descending, then reranks the
+   * sorted top four (1, 5, 8, 7) by popularity squared, expecting 8, 7, 5, 1.
+   */
+  @Test
+  public void testRankingSolrSort() throws Exception {
+    loadFeature("powpularityS", SolrFeature.class.getCanonicalName(),
+        "{\"q\":\"{!func}pow(popularity,2)\"}");
+    loadModel("powpularityS-model", LinearModel.class.getCanonicalName(),
+        new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
+
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("title:a1");
+    query.add("fl", "*, score");
+    query.add("rows", "4");
+
+    // normal term match
+    final String plain = "/query" + query.toQueryString();
+    assertJQ(plain, "/response/numFound/==8");
+    assertJQ(plain, "/response/docs/[0]/id=='1'");
+    assertJQ(plain, "/response/docs/[1]/id=='2'");
+    assertJQ(plain, "/response/docs/[2]/id=='3'");
+    assertJQ(plain, "/response/docs/[3]/id=='4'");
+
+    // add the sort
+    query.add("sort", "description desc");
+    final String sorted = "/query" + query.toQueryString();
+    assertJQ(sorted, "/response/numFound/==8");
+    assertJQ(sorted, "/response/docs/[0]/id=='1'");
+    assertJQ(sorted, "/response/docs/[1]/id=='5'");
+    assertJQ(sorted, "/response/docs/[2]/id=='8'");
+    assertJQ(sorted, "/response/docs/[3]/id=='7'");
+
+    query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}");
+    query.set("debugQuery", "on");
+
+    // reranked: the score of each document is its popularity squared
+    final String reranked = "/query" + query.toQueryString();
+    assertJQ(reranked, "/response/numFound/==8");
+    assertJQ(reranked, "/response/docs/[0]/id=='8'");
+    assertJQ(reranked, "/response/docs/[0]/score==64.0");
+    assertJQ(reranked, "/response/docs/[1]/id=='7'");
+    assertJQ(reranked, "/response/docs/[1]/score==49.0");
+    assertJQ(reranked, "/response/docs/[2]/id=='5'");
+    assertJQ(reranked, "/response/docs/[2]/score==25.0");
+    assertJQ(reranked, "/response/docs/[3]/id=='1'");
+    assertJQ(reranked, "/response/docs/[3]/score==1.0");
+  }
+
+  /** Shuts down what {@link #before()} started. */
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java
new file mode 100644
index 00000000000..f4c21fd6607
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.junit.Test;
+
+/** Tests reranking with parallel weight creation and {@link LTRThreadModule} validation. */
+public class TestParallelWeightCreation extends TestRerankBase{
+
+  /** Reranks five documents with the thread-enabled config and checks the top three ids. */
+  @Test
+  public void testLTRScoringQueryParallelWeightCreationResultOrder() throws Exception {
+    setuptest("solrconfig-ltr_Th10_10.xml", "schema.xml");
+    try {
+      assertU(adoc("id", "1", "title", "w1 w3", "description", "w1", "popularity", "1"));
+      assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity", "2"));
+      assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", "3"));
+      assertU(adoc("id", "4", "title", "w4 w3", "description", "w4", "popularity", "4"));
+      assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", "5"));
+      assertU(commit());
+
+      loadFeatures("external_features.json");
+      loadModels("external_model.json");
+      loadModels("external_model_store.json");
+
+      // check to make sure that the order of results will be the same when using parallel weight creation
+      final SolrQuery query = new SolrQuery();
+      query.setQuery("*:*");
+      query.add("fl", "*,score");
+      query.add("rows", "4");
+      query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
+      assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
+      assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
+      assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='4'");
+    } finally {
+      // release the harness even when an assertion above fails
+      aftertest();
+    }
+  }
+
+  /** The {@link LTRThreadModule} constructor must reject inconsistent thread counts. */
+  @Test
+  public void testLTRQParserThreadInitialization() throws Exception {
+    // setting the value of number of threads to -ve should throw an exception
+    assertEquals("numThreadsPerRequest cannot be less than 1",
+        threadModuleFailure(1, -1));
+
+    // set totalPoolThreads to 1 and numThreadsPerRequest to 2 and verify that an exception is thrown
+    assertEquals("numThreadsPerRequest cannot be greater than totalPoolThreads",
+        threadModuleFailure(1, 2));
+  }
+
+  /** Returns the IllegalArgumentException message of the constructor call, or null if it succeeded. */
+  private static String threadModuleFailure(int totalPoolThreads, int numThreadsPerRequest) {
+    try {
+      new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);
+      return null;
+    } catch (IllegalArgumentException iae) {
+      return iae.getMessage();
+    }
+  }
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
new file mode 100644
index 00000000000..4914d28cb96
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.lang.invoke.MethodHandles;
+import java.net.URL;
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang.StringUtils;
+import org.apache.solr.common.params.CommonParams;
+import org.apache.solr.common.util.ContentStream;
+import org.apache.solr.common.util.ContentStreamBase;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.FeatureException;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.model.LinearModel;
+import org.apache.solr.ltr.model.ModelException;
+import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
+import org.apache.solr.ltr.store.rest.ManagedModelStore;
+import org.apache.solr.request.SolrQueryRequestBase;
+import org.apache.solr.response.SolrQueryResponse;
+import org.apache.solr.rest.ManagedResourceStorage;
+import org.apache.solr.rest.SolrSchemaRestApi;
+import org.apache.solr.util.RestTestBase;
+import org.eclipse.jetty.servlet.ServletHolder;
+import org.noggit.ObjectBuilder;
+import org.restlet.ext.servlet.ServerServlet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Shared scaffolding for the LTR tests: copies the test Solr home into a
+ * temporary directory, starts Jetty with the schema REST API servlet, and
+ * offers helpers to upload features and models and to index documents.
+ */
+public class TestRerankBase extends RestTestBase {
+
+  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+  protected static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
+
+  protected static File tmpSolrHome;
+  protected static File tmpConfDir;
+
+  public static final String FEATURE_FILE_NAME = "_schema_feature-store.json";
+  public static final String MODEL_FILE_NAME = "_schema_model-store.json";
+  public static final String PARENT_ENDPOINT = "/schema/*";
+
+  protected static final String COLLECTION = "collection1";
+  protected static final String CONF_DIR = COLLECTION + "/conf";
+
+  // only assigned by setupTestInit when it is called with isPersistent == true
+  protected static File fstorefile = null;
+  protected static File mstorefile = null;
+
+  /** Sets up the default LTR config and indexes the {@link #bulkIndex()} documents. */
+  public static void setuptest() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+    bulkIndex();
+  }
+
+  /** Like {@link #setuptest()} but goes through {@link #setupPersistentTest(String, String)}. */
+  public static void setupPersistenttest() throws Exception {
+    setupPersistentTest("solrconfig-ltr.xml", "schema.xml");
+    bulkIndex();
+  }
+
+  /** Returns the managed feature store of the test harness core. */
+  public static ManagedFeatureStore getManagedFeatureStore() {
+    return ManagedFeatureStore.getManagedFeatureStore(h.getCore());
+  }
+
+  /** Returns the managed model store of the test harness core. */
+  public static ManagedModelStore getManagedModelStore() {
+    return ManagedModelStore.getManagedModelStore(h.getCore());
+  }
+
+  /**
+   * Prepares a temporary Solr home: copies {@code TEST_HOME()} into it,
+   * deletes existing feature/model store files, installs the requested
+   * solrconfig and schema under their default names, and sets the system
+   * property {@code managed.schema.mutable}.
+   *
+   * @param isPersistent when true the store file locations are remembered
+   *        in {@link #fstorefile} and {@link #mstorefile}
+   * @return the extra servlets (schema REST API) to register with Jetty
+   */
+  protected static SortedMap<ServletHolder,String> setupTestInit(
+      String solrconfig, String schema,
+      boolean isPersistent) throws Exception {
+    tmpSolrHome = createTempDir().toFile();
+    tmpConfDir = new File(tmpSolrHome, CONF_DIR);
+    tmpConfDir.deleteOnExit();
+    FileUtils.copyDirectory(new File(TEST_HOME()),
+        tmpSolrHome.getAbsoluteFile());
+
+    final File fstore = new File(tmpConfDir, FEATURE_FILE_NAME);
+    final File mstore = new File(tmpConfDir, MODEL_FILE_NAME);
+
+    if (isPersistent) {
+      fstorefile = fstore;
+      mstorefile = mstore;
+    }
+
+    if (fstore.exists()) {
+      log.info("remove feature store config file in {}",
+          fstore.getAbsolutePath());
+      Files.delete(fstore.toPath());
+    }
+    if (mstore.exists()) {
+      log.info("remove model store config file in {}",
+          mstore.getAbsolutePath());
+      Files.delete(mstore.toPath());
+    }
+    // the core always loads solrconfig.xml / schema.xml, so a differently
+    // named test config is copied over the default name
+    if (!solrconfig.equals("solrconfig.xml")) {
+      FileUtils.copyFile(new File(tmpConfDir, solrconfig),
+          new File(tmpConfDir, "solrconfig.xml"));
+    }
+    if (!schema.equals("schema.xml")) {
+      FileUtils.copyFile(new File(tmpConfDir, schema),
+          new File(tmpConfDir, "schema.xml"));
+    }
+
+    final SortedMap<ServletHolder,String> extraServlets = new TreeMap<>();
+    final ServletHolder solrRestApi = new ServletHolder("SolrSchemaRestApi",
+        ServerServlet.class);
+    solrRestApi.setInitParameter("org.restlet.application",
+        SolrSchemaRestApi.class.getCanonicalName());
+    solrRestApi.setInitParameter("storageIO",
+        ManagedResourceStorage.InMemoryStorageIO.class.getCanonicalName());
+    extraServlets.put(solrRestApi, PARENT_ENDPOINT);
+
+    System.setProperty("managed.schema.mutable", "true");
+
+    return extraServlets;
+  }
+
+  /**
+   * Initialises the core, prepares the temporary Solr home, disables the
+   * update log via the {@code enable.update.log} system property and starts
+   * Jetty plus the REST harness.
+   */
+  public static void setuptest(String solrconfig, String schema)
+      throws Exception {
+    initCore(solrconfig, schema);
+
+    final SortedMap<ServletHolder,String> extraServlets =
+        setupTestInit(solrconfig,schema,false);
+    System.setProperty("enable.update.log", "false");
+
+    createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
+        "/solr", true, extraServlets);
+  }
+
+  /**
+   * Same as {@link #setuptest(String, String)} except that the store file
+   * locations are remembered and {@code enable.update.log} is left untouched.
+   */
+  public static void setupPersistentTest(String solrconfig, String schema)
+      throws Exception {
+    initCore(solrconfig, schema);
+
+    final SortedMap<ServletHolder,String> extraServlets =
+        setupTestInit(solrconfig,schema,true);
+
+    createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
+        "/solr", true, extraServlets);
+  }
+
+  /** Closes the REST harness, stops Jetty and deletes the temporary Solr home. */
+  protected static void aftertest() throws Exception {
+    restTestHarness.close();
+    restTestHarness = null;
+    jetty.stop();
+    jetty = null;
+    FileUtils.deleteDirectory(tmpSolrHome);
+    System.clearProperty("managed.schema.mutable");
+    // NOTE(review): "enable.update.log", set by setuptest(String, String), is
+    // not cleared here (the call was commented out) - confirm this is intended
+  }
+
+  public static void makeRestTestHarnessNull() {
+    restTestHarness = null;
+  }
+
+  /**
+   * Produces a model encoded in json.
+   *
+   * @param features names of the model features; may be empty, which yields
+   *        an empty json array
+   * @param params raw json for the "params" entry, or null to omit it
+   */
+  public static String getModelInJson(String name, String type,
+      String[] features, String fstore, String params) {
+    final StringBuilder sb = new StringBuilder();
+    sb.append("{\n");
+    sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
+    sb.append("\"store\":").append('"').append(fstore).append('"')
+        .append(",\n");
+    sb.append("\"class\":").append('"').append(type).append('"').append(",\n");
+    sb.append("\"features\":").append('[');
+    // a separator in front of every entry but the first keeps the opening
+    // bracket intact when there are no features at all
+    String separator = "";
+    for (final String feature : features) {
+      sb.append(separator);
+      sb.append("\n\t{ ");
+      sb.append("\"name\":").append('"').append(feature).append('"')
+          .append("}");
+      separator = ",";
+    }
+    sb.append("\n]\n");
+    if (params != null) {
+      sb.append(",\n");
+      sb.append("\"params\":").append(params);
+    }
+    sb.append("\n}\n");
+    return sb.toString();
+  }
+
+  /**
+   * Produces a feature encoded in json.
+   *
+   * @param params raw json for the "params" entry, or null to omit it
+   */
+  public static String getFeatureInJson(String name, String type,
+      String fstore, String params) {
+    final StringBuilder sb = new StringBuilder();
+    sb.append("{\n");
+    sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
+    sb.append("\"store\":").append('"').append(fstore).append('"')
+        .append(",\n");
+    sb.append("\"class\":").append('"').append(type).append('"');
+    if (params != null) {
+      sb.append(",\n");
+      sb.append("\"params\":").append(params);
+    }
+    sb.append("\n}\n");
+    return sb.toString();
+  }
+
+  /** Uploads a feature to the feature store named "test". */
+  protected static void loadFeature(String name, String type, String params)
+      throws Exception {
+    loadFeature(name, type, "test", params);
+  }
+
+  /** Uploads a feature to the given feature store and expects status 0. */
+  protected static void loadFeature(String name, String type, String fstore,
+      String params) throws Exception {
+    final String feature = getFeatureInJson(name, type, fstore, params);
+    log.info("loading feauture \n{} ", feature);
+    assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
+        "/responseHeader/status==0");
+  }
+
+  /** Uploads a model whose features come from the feature store named "test". */
+  protected static void loadModel(String name, String type, String[] features,
+      String params) throws Exception {
+    loadModel(name, type, features, "test", params);
+  }
+
+  /** Uploads a model that uses the given feature store and expects status 0. */
+  protected static void loadModel(String name, String type, String[] features,
+      String fstore, String params) throws Exception {
+    final String model = getModelInJson(name, type, features, fstore, params);
+    log.info("loading model \n{} ", model);
+    assertJPut(ManagedModelStore.REST_END_POINT, model,
+        "/responseHeader/status==0");
+  }
+
+  /** Reads a classpath test resource as UTF-8 text, failing the test if it is missing. */
+  private static String readTestResource(String resourcePath) throws Exception {
+    final URL url = TestRerankBase.class.getResource(resourcePath);
+    assertNotNull("test resource not found: " + resourcePath, url);
+    return FileUtils.readFileToString(new File(url.toURI()), "UTF-8");
+  }
+
+  /** Uploads the content of {@code /modelExamples/<fileName>} to the model store. */
+  public static void loadModels(String fileName) throws Exception {
+    final String multipleModels = readTestResource("/modelExamples/" + fileName);
+
+    assertJPut(ManagedModelStore.REST_END_POINT, multipleModels,
+        "/responseHeader/status==0");
+  }
+
+  /**
+   * Replaces the content of the managed feature store with the features of
+   * {@code /featureExamples/<featureFileName>}, then builds the model
+   * described by {@code /modelExamples/<modelFileName>} against that store
+   * and adds it to the managed model store.
+   */
+  public static LTRScoringModel createModelFromFiles(String modelFileName,
+      String featureFileName) throws ModelException, Exception {
+    final String modelJson = readTestResource("/modelExamples/" + modelFileName);
+    final ManagedModelStore ms = getManagedModelStore();
+
+    final String featureJson = readTestResource("/featureExamples/" + featureFileName);
+    final Object parsedFeatureJson = parseJson(featureJson);
+
+    final ManagedFeatureStore fs = getManagedFeatureStore();
+    // NOTE(review): the store is emptied first because the managed feature
+    // store does not appear to be a fresh instance per call - is this safe?
+    fs.doDeleteChild(null, "*");
+    fs.applyUpdatesToManagedData(parsedFeatureJson);
+    ms.setManagedFeatureStore(fs); // can we skip this and just use fs directly below?
+
+    final LTRScoringModel ltrScoringModel = ManagedModelStore.fromLTRScoringModelMap(
+        solrResourceLoader, mapFromJson(modelJson), ms.getManagedFeatureStore());
+    ms.addModel(ltrScoringModel);
+    return ltrScoringModel;
+  }
+
+  /** Parses json, wrapping a parser IOException in a {@link ModelException}. */
+  private static Object parseJson(String json) throws ModelException {
+    try {
+      return ObjectBuilder.fromJSON(json);
+    } catch (final IOException ioExc) {
+      throw new ModelException("ObjectBuilder failed parsing json", ioExc);
+    }
+  }
+
+  /** Parses json whose top-level value is expected to be an object. */
+  @SuppressWarnings("unchecked")
+  static private Map<String,Object> mapFromJson(String json) throws ModelException {
+    return (Map<String,Object>) parseJson(json);
+  }
+
+  /** Uploads the content of {@code /featureExamples/<fileName>} to the feature store. */
+  public static void loadFeatures(String fileName) throws Exception {
+    final String multipleFeatures = readTestResource("/featureExamples/" + fileName);
+    log.info("send \n{}", multipleFeatures);
+
+    assertJPut(ManagedFeatureStore.REST_END_POINT, multipleFeatures,
+        "/responseHeader/status==0");
+  }
+
+  /**
+   * Creates one {@link ValueFeature} per name, each with the parameter
+   * {@code value=10}; the feature index is the position of its name.
+   */
+  protected List<Feature> getFeatures(List<String> names)
+      throws FeatureException {
+    final List<Feature> features = new ArrayList<>();
+    int pos = 0;
+    for (final String name : names) {
+      final Map<String,Object> params = new HashMap<>();
+      params.put("value", 10);
+      final Feature f = Feature.getInstance(solrResourceLoader,
+          ValueFeature.class.getCanonicalName(),
+          name, params);
+      f.setIndex(pos);
+      features.add(f);
+      ++pos;
+    }
+    return features;
+  }
+
+  /** Array variant of {@link #getFeatures(List)}. */
+  protected List<Feature> getFeatures(String[] names) throws FeatureException {
+    return getFeatures(Arrays.asList(names));
+  }
+
+  /**
+   * Uploads the features {@code c0 .. c<allFeatureCount-1>} (feature
+   * {@code ci} has the parameter {@code value=i}) and a linear model named
+   * {@code name} over the first {@code modelFeatureCount} of them, each
+   * with weight 1.0.
+   */
+  protected static void loadModelAndFeatures(String name, int allFeatureCount,
+      int modelFeatureCount) throws Exception {
+    final String[] features = new String[modelFeatureCount];
+    final String[] weights = new String[modelFeatureCount];
+    for (int i = 0; i < allFeatureCount; i++) {
+      final String featureName = "c" + i;
+      if (i < modelFeatureCount) {
+        features[i] = featureName;
+        weights[i] = "\"" + featureName + "\":1.0";
+      }
+      loadFeature(featureName, ValueFeature.ValueFeatureWeight.class.getCanonicalName(),
+          "{\"value\":" + i + "}");
+    }
+
+    loadModel(name, LinearModel.class.getCanonicalName(), features,
+        "{\"weights\":{" + StringUtils.join(weights, ",") + "}}");
+  }
+
+  /** Indexes four "bloomberg" documents (ids 6 to 9) and commits. */
+  protected static void bulkIndex() throws Exception {
+    assertU(adoc("title", "bloomberg different bla", "description",
+        "bloomberg", "id", "6", "popularity", "1"));
+    assertU(adoc("title", "bloomberg bloomberg ", "description", "bloomberg",
+        "id", "7", "popularity", "2"));
+    assertU(adoc("title", "bloomberg bloomberg bloomberg", "description",
+        "bloomberg", "id", "8", "popularity", "3"));
+    assertU(adoc("title", "bloomberg bloomberg bloomberg bloomberg",
+        "description", "bloomberg", "id", "9", "popularity", "5"));
+    assertU(commit());
+  }
+
+  /**
+   * Sends the XML file at {@code filePath} to the update handler and
+   * commits. A failure of the update handler is logged and otherwise
+   * ignored.
+   */
+  protected static void bulkIndex(String filePath) throws Exception {
+    final SolrQueryRequestBase req = lrf.makeRequest(
+        CommonParams.STREAM_CONTENTTYPE, "application/xml");
+    try {
+      final List<ContentStream> streams = new ArrayList<>();
+      streams.add(new ContentStreamBase.FileStream(new File(filePath)));
+      req.setContentStreams(streams);
+
+      try {
+        h.updater.handleRequest(req, new SolrQueryResponse());
+      } catch (final Throwable ex) {
+        // Ignore. Just log the exception and go to the next file
+        log.error(ex.getMessage(), ex);
+      }
+    } finally {
+      // the request was created here, so it has to be released here
+      req.close();
+    }
+    assertU(commit());
+  }
+
+  /**
+   * Reads an XML update file, skips its first three lines, and sends every
+   * document as an update message of its own, followed by a commit.
+   */
+  protected static void buildIndexUsingAdoc(String filepath)
+      throws FileNotFoundException {
+    final List<String> docsToAdd = new ArrayList<>();
+    try (Scanner scn = new Scanner(new File(filepath), "UTF-8")) {
+      StringBuilder buff = new StringBuilder();
+      scn.nextLine();
+      scn.nextLine();
+      scn.nextLine(); // Skip the first 3 lines then add everything else
+      // NOTE(review): the tag literals below were empty strings in the text
+      // under review (anything in angle brackets was lost); they are restored
+      // to what the logic requires - confirm against the committed file
+      while (scn.hasNext()) {
+        String curLine = scn.nextLine();
+        if (curLine.contains("</doc>")) {
+          buff.append(curLine + "\n");
+          // wrap the finished document in an <add> element of its own
+          docsToAdd.add(buff.toString().replace("</add>", "")
+              .replace("<doc>", "<add>\n<doc>")
+              .replace("</doc>", "</doc>\n</add>"));
+          if (!scn.hasNext()) {
+            break;
+          } else {
+            curLine = scn.nextLine();
+          }
+          buff = new StringBuilder();
+        }
+        buff.append(curLine + "\n");
+      }
+    }
+    for (final String doc : docsToAdd) {
+      assertU(doc.trim());
+    }
+    assertU(commit());
+  }
+
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java
new file mode 100644
index 00000000000..68961d22473
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.FloatDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.ReaderUtil;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.model.ModelException;
+import org.apache.solr.ltr.model.TestLinearModel;
+import org.apache.solr.ltr.norm.IdentityNormalizer;
+import org.apache.solr.ltr.norm.Normalizer;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Checks how many feature weights an {@link LTRScoringQuery} creates: only
+ * the model's own features when features are not requested in the response,
+ * and every feature of the store when they are.
+ */
+public class TestSelectiveWeightCreation extends TestRerankBase {
+  private IndexSearcher getSearcher(IndexReader r) {
+    return newSearcher(r, false, false);
+  }
+
+  /**
+   * Creates one {@link ValueFeature} per id; the feature is named
+   * {@code "f" + id} and both its value and its index are the id.
+   */
+  private static List<Feature> makeFeatures(int[] featureIds) {
+    final List<Feature> features = new ArrayList<>();
+    for (final int i : featureIds) {
+      final Map<String,Object> params = new HashMap<String,Object>();
+      params.put("value", i);
+      final Feature f = Feature.getInstance(solrResourceLoader,
+          ValueFeature.class.getCanonicalName(),
+          "f" + i, params);
+      f.setIndex(i);
+      features.add(f);
+    }
+    return features;
+  }
+
+  /** Builds linear-model params giving every feature the weight 0.1. */
+  private static Map<String,Object> makeFeatureWeights(List<Feature> features) {
+    final Map<String,Object> nameParams = new HashMap<String,Object>();
+    final HashMap<String,Double> modelWeights = new HashMap<String,Double>();
+    for (final Feature feat : features) {
+      modelWeights.put(feat.getName(), 0.1);
+    }
+    nameParams.put("weights", modelWeights);
+    return nameParams;
+  }
+
+  /** Counts the entries that are non-null and flagged as used. */
+  private static int countUsedFeatures(LTRScoringQuery.FeatureInfo[] featuresInfo) {
+    int used = 0;
+    for (final LTRScoringQuery.FeatureInfo info : featuresInfo) {
+      if (info != null && info.isUsed()) {
+        ++used;
+      }
+    }
+    return used;
+  }
+
+  /**
+   * Creates the weight for {@code model}, scores document {@code docid}
+   * (a top-level doc id) with it, and returns the weight so the caller can
+   * inspect the feature state it built.
+   */
+  private LTRScoringQuery.ModelWeight performQuery(TopDocs hits,
+      IndexSearcher searcher, int docid, LTRScoringQuery model) throws IOException,
+      ModelException {
+    final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
+        .leaves();
+    final int n = ReaderUtil.subIndex(docid, leafContexts);
+    final LeafReaderContext context = leafContexts.get(n);
+    final int deBasedDoc = docid - context.docBase;
+
+    final Weight weight = searcher.createNormalizedWeight(model, true);
+    final Scorer scorer = weight.scorer(context);
+
+    // rerank using the field final-score
+    scorer.iterator().advance(deBasedDoc);
+    scorer.score();
+    assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
+    return (LTRScoringQuery.ModelWeight) weight;
+  }
+
+
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+
+    assertU(adoc("id", "1", "title", "w1 w3", "description", "w1", "popularity",
+        "1"));
+    assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
+        "2"));
+    assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
+        "3"));
+    assertU(adoc("id", "4", "title", "w4 w3", "description", "w4", "popularity",
+        "4"));
+    assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
+        "5"));
+    assertU(commit());
+
+    loadFeatures("external_features.json");
+    loadModels("external_model.json");
+    loadModels("external_model_store.json");
+  }
+
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+  @Test
+  public void testScoringQueryWeightCreation() throws IOException, ModelException {
+    final Directory dir = newDirectory();
+    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+
+    Document doc = new Document();
+    doc.add(newStringField("id", "0", Field.Store.YES));
+    doc.add(newTextField("field", "wizard the the the the the oz",
+        Field.Store.NO));
+    doc.add(new FloatDocValuesField("final-score", 1.0f));
+
+    w.addDocument(doc);
+    doc = new Document();
+    doc.add(newStringField("id", "1", Field.Store.YES));
+    // 1 extra token, but wizard and oz are close;
+    doc.add(newTextField("field", "wizard oz the the the the the the",
+        Field.Store.NO));
+    doc.add(new FloatDocValuesField("final-score", 2.0f));
+    w.addDocument(doc);
+
+    final IndexReader r = w.getReader();
+    w.close();
+
+    // close reader and directory even when an assertion below fails
+    try {
+      // Do ordinary BooleanQuery:
+      final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+      bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+      bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+      final IndexSearcher searcher = getSearcher(r);
+      // first run the standard query
+      final TopDocs hits = searcher.search(bqBuilder.build(), 10);
+      assertEquals(2, hits.totalHits);
+      assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+      assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+      final List<Feature> features = makeFeatures(new int[] {0, 1, 2});
+      final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5,
+          6, 7, 8, 9});
+      final List<Normalizer> norms = new ArrayList<>();
+      for (int k=0; k < features.size(); ++k){
+        norms.add(IdentityNormalizer.INSTANCE);
+      }
+
+      // when features are NOT requested in the response, only the modelFeature weights should be created
+      final LTRScoringModel ltrScoringModel1 = TestLinearModel.createLinearModel("test",
+          features, norms, "test", allFeatures,
+          makeFeatureWeights(features));
+      LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher,
+          hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel1, false)); // features not requested in response
+
+      assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
+      assertEquals(features.size(), countUsedFeatures(modelWeight.getFeaturesInfo()));
+
+      // when features are requested in the response, weights should be created for all features
+      final LTRScoringModel ltrScoringModel2 = TestLinearModel.createLinearModel("test",
+          features, norms, "test", allFeatures,
+          makeFeatureWeights(features));
+      modelWeight = performQuery(hits, searcher,
+          hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel2, true)); // features requested in response
+
+      assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length);
+      assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length);
+      assertEquals(allFeatures.size(), countUsedFeatures(modelWeight.getFeaturesInfo()));
+
+      assertU(delI("0"));
+      assertU(delI("1"));
+    } finally {
+      r.close();
+      dir.close();
+    }
+  }
+
+
+  @Test
+  public void testSelectiveWeightsRequestFeaturesFromDifferentStore() throws Exception {
+
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("fl", "*,score");
+    query.add("rows", "4");
+    query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
+    query.add("fl", "fv:[fv]");
+
+    // one request per scenario; every JSON path is checked on that response
+    assertJQ("/query" + query.toQueryString(),
+        "/response/docs/[0]/id=='1'",
+        "/response/docs/[1]/id=='3'",
+        "/response/docs/[2]/id=='4'",
+        "/response/docs/[0]/fv=='matchedTitle:1.0;titlePhraseMatch:0.40254828'"); // extract all features in default store
+
+    query.remove("fl");
+    query.remove("rq");
+    query.add("fl", "*,score");
+    query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
+    query.add("fl", "fv:[fv store=fstore4 efi.myPop=3]");
+
+    assertJQ("/query" + query.toQueryString(),
+        "/response/docs/[0]/id=='1'",
+        "/response/docs/[0]/score==0.999",
+        "/response/docs/[0]/fv=='popularity:3.0;originalScore:1.0'"); // extract all features from fstore4
+
+
+    query.remove("fl");
+    query.remove("rq");
+    query.add("fl", "*,score");
+    query.add("rq", "{!ltr reRankDocs=4 model=externalmodelstore efi.user_query=w3 efi.myconf=0.8}");
+    query.add("fl", "fv:[fv store=fstore4 efi.myPop=3]");
+    assertJQ("/query" + query.toQueryString(),
+        "/response/docs/[0]/id=='1'", // score using fstore2 used by externalmodelstore
+        "/response/docs/[0]/score==0.7992",
+        "/response/docs/[0]/fv=='popularity:3.0;originalScore:1.0'"); // extract all features from fstore4
+  }
+}
+
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestEdisMaxSolrFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestEdisMaxSolrFeature.java
new file mode 100644
index 00000000000..cd63b5c17e2
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestEdisMaxSolrFeature.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.TestRerankBase;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Smoke test: a {@link SolrFeature} whose query is an edismax query can be
+ * loaded, used by a linear model and applied while reranking.
+ */
+public class TestEdisMaxSolrFeature extends TestRerankBase {
+
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+
+    // one row per document: id, title, description, popularity
+    final String[][] fixture = {
+        {"1", "w1", "w1", "1"},
+        {"2", "w2 2asd asdd didid", "w2 2asd asdd didid", "2"},
+        {"3", "w3", "w3", "3"},
+        {"4", "w4", "w4", "4"},
+        {"5", "w5", "w5", "5"},
+        {"6", "w1 w2", "w1 w2", "6"},
+        {"7", "w1 w2 w3 w4 w5", "w1 w2 w3 w4 w5 w8", "7"},
+        {"8", "w1 w1 w1 w2 w2 w8", "w1 w1 w1 w2 w2", "8"},
+    };
+    for (final String[] row : fixture) {
+      assertU(adoc("id", row[0], "title", row[1], "description", row[2],
+          "popularity", row[3]));
+    }
+    assertU(commit());
+  }
+
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+  @Test
+  public void testEdisMaxSolrFeature() throws Exception {
+    final String featureName = "SomeEdisMax";
+    loadFeature(
+        featureName,
+        SolrFeature.class.getCanonicalName(),
+        "{\"q\":\"{!edismax qf='title description' pf='description' mm=100% boost='pow(popularity, 0.1)' v='w1' tie=0.1}\"}");
+
+    loadModel("EdisMax-model", LinearModel.class.getCanonicalName(),
+        new String[] {featureName}, "{\"weights\":{\"SomeEdisMax\":1.0}}");
+
+    // rerank the top four title matches with the edismax-backed model
+    final SolrQuery rerank = new SolrQuery();
+    rerank.setQuery("title:w1");
+    rerank.add("fl", "*, score");
+    rerank.add("rows", "4");
+    rerank.add("rq", "{!ltr model=EdisMax-model reRankDocs=4}");
+    rerank.set("debugQuery", "on");
+
+    final String request = "/query" + rerank.toQueryString();
+    restTestHarness.query(request);
+    assertJQ(request, "/response/numFound/==4");
+  }
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalFeatures.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalFeatures.java
new file mode 100644
index 00000000000..8c00758b03e
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalFeatures.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.TestRerankBase;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Tests features whose definitions reference external feature information
+ * ({@code efi.*} request parameters), both when reranking ({@code rq}) and
+ * when extracting feature vectors with the {@code [fv]} transformer.
+ */
+public class TestExternalFeatures extends TestRerankBase {
+
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+
+    assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity",
+        "1"));
+    assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
+        "2"));
+    assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
+        "3"));
+    assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity",
+        "4"));
+    assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
+        "5"));
+    assertU(commit());
+
+    loadFeatures("external_features.json");
+    loadModels("external_model.json");
+  }
+
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+  @Test
+  public void testEfiInTransformerShouldNotChangeOrderOfRerankedResults() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("fl", "*,score");
+    query.add("rows", "3");
+
+    // Regular scores (one request, all paths checked on the same response)
+    assertJQ("/query" + query.toQueryString(),
+        "/response/docs/[0]/id=='1'",
+        "/response/docs/[0]/score==1.0",
+        "/response/docs/[1]/id=='2'",
+        "/response/docs/[1]/score==1.0",
+        "/response/docs/[2]/id=='3'",
+        "/response/docs/[2]/score==1.0");
+
+    query.add("fl", "[fv]");
+    query.add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query=w3}");
+
+    assertJQ("/query" + query.toQueryString(),
+        "/response/docs/[0]/id=='3'",
+        "/response/docs/[0]/score==0.999",
+        "/response/docs/[1]/id=='1'",
+        "/response/docs/[1]/score==0.0",
+        "/response/docs/[2]/id=='2'",
+        "/response/docs/[2]/score==0.0");
+
+    // Adding an efi in the transformer should not affect the rq ranking with a
+    // different value for efi of the same parameter
+    query.remove("fl");
+    query.add("fl", "id,[fv efi.user_query=w2]");
+
+    assertJQ("/query" + query.toQueryString(),
+        "/response/docs/[0]/id=='3'",
+        "/response/docs/[1]/id=='1'",
+        "/response/docs/[2]/id=='2'");
+  }
+
+  @Test
+  public void testFeaturesUseStopwordQueryReturnEmptyFeatureVector() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("fl", "*,score,fv:[fv]");
+    query.add("rows", "1");
+    // Stopword only query passed in
+    query.add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query='a'}");
+
+    // Features are query title matches, which remove stopwords, leaving blank query, so no matches
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''");
+  }
+
+  @Test
+  public void testEfiFeatureExtraction() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("rows", "1");
+
+    // Features we're extracting depend on external feature info not passed in
+    query.add("fl", "[fv]");
+    assertJQ("/query" + query.toQueryString(), "/error/msg=='Exception from createWeight for SolrFeature [name=matchedTitle, params={q={!terms f=title}${user_query}}] SolrFeatureWeight requires efi parameter that was not passed in request.'");
+
+    // Adding efi in features section should make it work
+    query.remove("fl");
+    query.add("fl", "score,fvalias:[fv store=fstore2 efi.myconf=2.3]");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='confidence:2.3;originalScore:1.0'");
+
+    // Adding efi in transformer + rq should still use the transformer's params for feature extraction
+    query.remove("fl");
+    query.add("fl", "score,fvalias:[fv store=fstore2 efi.myconf=2.3]");
+    query.add("rq", "{!ltr reRankDocs=3 model=externalmodel efi.user_query=w3}");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='confidence:2.3;originalScore:1.0'");
+  }
+
+  @Test
+  public void featureExtraction_valueFeatureImplicitlyNotRequired_shouldNotScoreFeature() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("rows", "1");
+
+    // Efi is implicitly not required, so we do not score the feature
+    query.add("fl", "fvalias:[fv store=fstore2]");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='originalScore:0.0'");
+  }
+
+  @Test
+  public void featureExtraction_valueFeatureExplicitlyNotRequired_shouldNotScoreFeature() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("rows", "1");
+
+    // Efi is explicitly not required, so we do not score the feature
+    query.add("fl", "fvalias:[fv store=fstore3]");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fvalias=='originalScore:0.0'");
+  }
+
+  @Test
+  public void featureExtraction_valueFeatureRequired_shouldThrowException() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("rows", "1");
+
+    // Using nondefault store should still result in error with no efi when it is required (myPop)
+    query.add("fl", "fvalias:[fv store=fstore4]");
+    assertJQ("/query" + query.toQueryString(), "/error/msg=='Exception from createWeight for ValueFeature [name=popularity, params={value=${myPop}, required=true}] ValueFeatureWeight requires efi parameter that was not passed in request.'");
+  }
+}
diff --git
a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalValueFeatures.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalValueFeatures.java
new file mode 100644
index 00000000000..bc073cbf2c6
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestExternalValueFeatures.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.TestRerankBase;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Reranks with the {@code external_model_binary_feature} model while some or
+ * all of the {@code efi.*} parameters its features read are absent.
+ */
+public class TestExternalValueFeatures extends TestRerankBase {
+
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+
+    // documents 1..5: title and description are "w<n>", popularity is <n>
+    for (int n = 1; n <= 5; n++) {
+      final String num = String.valueOf(n);
+      final String word = "w" + n;
+      assertU(adoc("id", num, "title", word, "description", word,
+          "popularity", num));
+    }
+    assertU(commit());
+
+    loadFeatures("external_features_for_sparse_processing.json");
+    loadModels("multipleadditivetreesmodel_external_binary_features.json");
+  }
+
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+  /** Builds the request path shared by both tests for the given rq value. */
+  private static String rerankRequest(String rq) {
+    final SolrQuery q = new SolrQuery();
+    q.setQuery("*:*");
+    q.add("fl", "*,score,features:[fv]");
+    q.add("rows", "3");
+    q.add("fl", "[fv]");
+    q.add("rq", rq);
+    return "/query" + q.toQueryString();
+  }
+
+  @Test
+  public void efiFeatureProcessing_oneEfiMissing_shouldNotCalculateMissingFeature() throws Exception {
+    final String request = rerankRequest(
+        "{!ltr reRankDocs=3 model=external_model_binary_feature efi.user_device_tablet=1}");
+
+    assertJQ(request, "/response/docs/[0]/id=='1'");
+    assertJQ(request, "/response/docs/[0]/features=='user_device_tablet:1.0'");
+    assertJQ(request, "/response/docs/[0]/score==65.0");
+  }
+
+  @Test
+  public void efiFeatureProcessing_allEfisMissing_shouldReturnZeroScore() throws Exception {
+    final String request = rerankRequest(
+        "{!ltr reRankDocs=3 model=external_model_binary_feature}");
+
+    assertJQ(request, "/response/docs/[0]/id=='1'");
+    assertJQ(request, "/response/docs/[0]/features==''");
+    assertJQ(request, "/response/docs/[0]/score==0.0");
+  }
+
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java
new file mode 100644
index 00000000000..7658f62262f
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureExtractionFromMultipleSegments.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import java.util.List;
+import java.util.Map;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.TestRerankBase;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.noggit.ObjectBuilder;
+
+
+/**
+ * Requests feature vectors for enough rows that results come from more than
+ * one index segment, and checks that every returned document carries a
+ * non-empty feature vector.
+ */
+public class TestFeatureExtractionFromMultipleSegments extends TestRerankBase {
+  static final String AB = "abcdefghijklmnopqrstuvwxyz";
+
+  /** Returns {@code len} random lowercase letters drawn from the test's random source. */
+  static String randomString( int len ){
+    StringBuilder sb = new StringBuilder( len );
+    for( int i = 0; i < len; i++ ) {
+      sb.append( AB.charAt( random().nextInt(AB.length()) ) );
+    }
+    return sb.toString();
+  }
+
+  /** Indexes one document with the four fields this test uses. */
+  private static void indexDoc(int id, String popularity, String description,
+      String normHits) {
+    assertU(adoc("id", Integer.toString(id), "popularity", popularity,
+        "description", description, "normHits", normHits));
+  }
+
+  @BeforeClass
+  public static void before() throws Exception {
+    // solrconfig-multiseg.xml contains the merge policy to restrict merging
+    setuptest("solrconfig-multiseg.xml", "schema.xml");
+    // index 20 batches of 17 documents (340 documents, ids between 0 and 398)
+    for(int i = 0; i<400;i=i+20) {
+      indexDoc(i, "201", "apple is a company " + randomString(i%6+3), "0.1");
+      indexDoc(i+1, "201", "d " + randomString(i%6+3), "0.11");
+
+      indexDoc(i+2, "201", "apple is a company too " + randomString(i%6+3), "0.1");
+      indexDoc(i+3, "201", "new york city is big apple " + randomString(i%6+3), "0.11");
+
+      indexDoc(i+6, "301", "function name " + randomString(i%6+3), "0.1");
+      indexDoc(i+7, "301", "function " + randomString(i%6+3), "0.1");
+
+      indexDoc(i+8, "301", "This is a sample function for testing " + randomString(i%6+3), "0.1");
+      indexDoc(i+9, "301", "Function to check out stock prices "+randomString(i%6+3), "0.1");
+      indexDoc(i+10, "301", "Some descriptions "+randomString(i%6+3), "0.1");
+
+      indexDoc(i+11, "201", "apple apple is a company " + randomString(i%6+3), "0.1");
+      indexDoc(i+12, "201", "Big Apple is New York.", "0.01");
+      indexDoc(i+13, "201", "New some York is Big. "+ randomString(i%6+3), "0.1");
+
+      indexDoc(i+14, "201", "apple apple is a company " + randomString(i%6+3), "0.1");
+      indexDoc(i+15, "201", "Big Apple is New York.", "0.01");
+      indexDoc(i+16, "401", "barack h", "0.0");
+      indexDoc(i+17, "201", "red delicious apple " + randomString(i%6+3), "0.1");
+      indexDoc(i+18, "201", "nyc " + randomString(i%6+3), "0.11");
+    }
+
+    assertU(commit());
+
+    loadFeatures("comp_features.json");
+  }
+
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+  @Test
+  public void testFeatureExtractionFromMultipleSegments() throws Exception {
+
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("{!edismax qf='description^1' boost='sum(product(pow(normHits, 0.7), 1600), .1)' v='apple'}");
+    // request 100 rows, if any rows are fetched from the second or subsequent segments the tests should succeed if LTRRescorer::extractFeaturesInfo() advances the doc iterator properly
+    final int numRows = 100;
+    query.add("rows", Integer.toString(numRows));
+    query.add("wt", "json");
+    query.add("fq", "popularity:201");
+    query.add("fl", "*, score,id,normHits,description,fv:[features store='feature-store-6' format='dense' efi.user_text='apple']");
+    final String res = restTestHarness.query("/query" + query.toQueryString());
+
+    @SuppressWarnings("unchecked")
+    final Map<String,Object> resultJson = (Map<String,Object>) ObjectBuilder.fromJSON(res);
+
+    @SuppressWarnings("unchecked")
+    final List<Map<String,Object>> docs =
+        (List<Map<String,Object>>)((Map<String,Object>)resultJson.get("response")).get("docs");
+    // JUnit assertions rather than the Java assert keyword, so the checks
+    // run regardless of the JVM's -ea setting
+    int passCount = 0;
+    for (final Map<String,Object> doc : docs) {
+      final String features = (String)doc.get("fv");
+      assertNotNull("missing fv for doc " + doc.get("id"), features);
+      assertTrue("empty fv for doc " + doc.get("id"), features.length() > 0);
+      ++passCount;
+    }
+    assertEquals(numRows, passCount);
+  }
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java
new file mode 100644
index 00000000000..14e29031d55
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLogging.java
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.apache.solr.ltr.store.FeatureStore; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +/** Exercises feature logging requested through the [fv] entry of the fl parameter: plain results, grouped results, store selection and sparse/dense output. */ +public class TestFeatureLogging extends TestRerankBase { + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + /** Loads six features into store "test1" and a linear model over c1..c3, then checks that the vector logged for the first document lists c1, c2, c3, pop and yesmatch (the "nomatch" feature is absent from the expected string). */ + @Test + public void testGeneratedFeatures() throws Exception { + loadFeature("c1", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":1.0}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":2.0}"); + loadFeature("c3", ValueFeature.class.getCanonicalName(), "test1", + "{\"value\":3.0}"); + loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "test1", + "{\"field\":\"popularity\"}"); + loadFeature("nomatch", SolrFeature.class.getCanonicalName(), "test1", + "{\"q\":\"{!terms f=title}foobarbat\"}"); + loadFeature("yesmatch", SolrFeature.class.getCanonicalName(), "test1", + "{\"q\":\"{!terms f=popularity}2\"}"); + + loadModel("sum1", LinearModel.class.getCanonicalName(), new String[] { + "c1", "c2", "c3"}, "test1", + "{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.add("fl", "title,description,id,popularity,[fv]"); + query.add("rows", "3"); + query.add("debugQuery", "on"); + query.add("rq", "{!ltr reRankDocs=3 model=sum1}"); + + restTestHarness.query("/query" + query.toQueryString()); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/=={'title':'bloomberg bloomberg ', 'description':'bloomberg','id':'7', 'popularity':2, '[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}"); + + query.remove("fl"); + query.add("fl", "[fv]"); + query.add("rows", "3"); + query.add("rq", 
"{!ltr reRankDocs=3 model=sum1}"); + + restTestHarness.query("/query" + query.toQueryString()); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}"); + query.remove("rq"); + + // NOTE(review): no logging switch is touched here; the same rq is re-added and the same [fv] expectation is re-checked + query.add("rq", "{!ltr reRankDocs=3 model=sum1}"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'c1:1.0;c2:2.0;c3:3.0;pop:2.0;yesmatch:1.0'}"); + + + } + /** Checks which store [fv] extracts from: the default store when neither store nor model is given, the store named in [fv store=...] when given (also when a model is present), and the model's store when only a model is given. */ + @Test + public void testDefaultStoreFeatureExtraction() throws Exception { + loadFeature("defaultf1", ValueFeature.class.getCanonicalName(), + FeatureStore.DEFAULT_FEATURE_STORE_NAME, + "{\"value\":1.0}"); + loadFeature("store8f1", ValueFeature.class.getCanonicalName(), + "store8", + "{\"value\":2.0}"); + loadFeature("store9f1", ValueFeature.class.getCanonicalName(), + "store9", + "{\"value\":3.0}"); + loadModel("store9m1", LinearModel.class.getCanonicalName(), + new String[] {"store9f1"}, + "store9", + "{\"weights\":{\"store9f1\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("id:7"); + query.add("rows", "1"); + + // No store specified, use default store for extraction + query.add("fl", "fv:[fv]"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'fv':'defaultf1:1.0'}"); + + // Store specified, use store for extraction + query.remove("fl"); + query.add("fl", "fv:[fv store=store8]"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'fv':'store8f1:2.0'}"); + + // Store specified + model specified, use store for extraction + query.add("rq", "{!ltr reRankDocs=3 model=store9m1}"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'fv':'store8f1:2.0'}"); + + // No store specified + model specified, use model store for extraction + query.remove("fl"); + query.add("fl", "fv:[fv]"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'fv':'store9f1:3.0'}"); + } + 
+ /** Checks feature logging on grouped results (group.field=title): first the string form, then the map form requested with fvwt=json; the logged vector includes pop although the model only uses c1..c3. */ + @Test + public void testGeneratedGroup() throws Exception { + loadFeature("c1", ValueFeature.class.getCanonicalName(), "testgroup", + "{\"value\":1.0}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "testgroup", + "{\"value\":2.0}"); + loadFeature("c3", ValueFeature.class.getCanonicalName(), "testgroup", + "{\"value\":3.0}"); + loadFeature("pop", FieldValueFeature.class.getCanonicalName(), "testgroup", + "{\"field\":\"popularity\"}"); + + loadModel("sumgroup", LinearModel.class.getCanonicalName(), new String[] { + "c1", "c2", "c3"}, "testgroup", + "{\"weights\":{\"c1\":1.0,\"c2\":1.0,\"c3\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.add("fl", "*,[fv]"); + query.add("debugQuery", "on"); + /* the fl set just above is discarded before any request is sent */ + query.remove("fl"); + query.add("fl", "fv:[fv]"); + query.add("rows", "3"); + query.add("group", "true"); + query.add("group.field", "title"); + + query.add("rq", "{!ltr reRankDocs=3 model=sumgroup}"); + + restTestHarness.query("/query" + query.toQueryString()); + assertJQ( + "/query" + query.toQueryString(), + "/grouped/title/groups/[0]/doclist/docs/[0]/=={'fv':'c1:1.0;c2:2.0;c3:3.0;pop:5.0'}"); + + query.remove("fl"); + query.add("fl", "fv:[fv fvwt=json]"); + restTestHarness.query("/query" + query.toQueryString()); + assertJQ( + "/query" + query.toQueryString(), + "/grouped/title/groups/[0]/doclist/docs/[0]/fv/=={'c1':1.0,'c2':2.0,'c3':3.0,'pop':5.0}"); + query.remove("fl"); + query.add("fl", "fv:[fv fvwt=json]"); + /* NOTE(review): same fl and same expectation as the check just above */ + assertJQ( + "/query" + query.toQueryString(), + "/grouped/title/groups/[0]/doclist/docs/[0]/fv/=={'c1':1.0,'c2':2.0,'c3':3.0,'pop':5.0}"); + } + /** Checks sparse versus dense logging for store "test4" in json and csv output: in sparse form (also used when no format is given) the second document omits "match", in dense form it reports match as 0.0. */ + @Test + public void testSparseDenseFeatures() throws Exception { + loadFeature("match", SolrFeature.class.getCanonicalName(), "test4", + "{\"q\":\"{!terms f=title}different\"}"); + loadFeature("c4", ValueFeature.class.getCanonicalName(), "test4", + "{\"value\":1.0}"); + + loadModel("sum4", LinearModel.class.getCanonicalName(), new String[] { + 
"match"}, "test4", + "{\"weights\":{\"match\":1.0}}"); + + //json - no feature format check (default to sparse) + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.add("rows", "10"); + query.add("fl", "*,score,fv:[fv store=test4 fvwt=json]"); + query.add("rq", "{!ltr reRankDocs=10 model=sum4}"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/fv/=={'match':1.0,'c4':1.0}"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[1]/fv/=={'c4':1.0}"); + + //json - sparse feature format check + query.remove("fl"); + query.add("fl", "*,score,fv:[fv store=test4 format=sparse fvwt=json]"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/fv/=={'match':1.0,'c4':1.0}"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[1]/fv/=={'c4':1.0}"); + + //json - dense feature format check + query.remove("fl"); + query.add("fl", "*,score,fv:[fv store=test4 format=dense fvwt=json]"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/fv/=={'match':1.0,'c4':1.0}"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[1]/fv/=={'match':0.0,'c4':1.0}"); + + //csv - no feature format check (default to sparse) + query.remove("fl"); + query.add("fl", "*,score,fv:[fv store=test4 fvwt=csv]"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/fv/=='match:1.0;c4:1.0'"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[1]/fv/=='c4:1.0'"); + + //csv - sparse feature format check + query.remove("fl"); + query.add("fl", "*,score,fv:[fv store=test4 format=sparse fvwt=csv]"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[0]/fv/=='match:1.0;c4:1.0'"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[1]/fv/=='c4:1.0'"); + + //csv - dense feature format check + query.remove("fl"); + query.add("fl", "*,score,fv:[fv store=test4 format=dense fvwt=csv]"); + assertJQ( + "/query" + 
query.toQueryString(), + "/response/docs/[0]/fv/=='match:1.0;c4:1.0'"); + assertJQ( + "/query" + query.toQueryString(), + "/response/docs/[1]/fv/=='match:0.0;c4:1.0'"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLtrScoringModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLtrScoringModel.java new file mode 100644 index 00000000000..5fcebad884b --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureLtrScoringModel.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.store.rest.ManagedFeatureStore; +import org.apache.solr.ltr.store.rest.TestManagedFeatureStore; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +/** Checks that the managed feature store instantiates a feature from its class name and fails for a class name that cannot be resolved. */ +public class TestFeatureLtrScoringModel extends TestRerankBase { + + static ManagedFeatureStore store = null; + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + store = getManagedFeatureStore(); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + /** Adds an OriginalScoreFeature named "test" to store "testFstore" and checks that the stored feature has that name and class. */ + @Test + public void getInstanceTest() throws FeatureException + { + store.addFeature(TestManagedFeatureStore.createMap("test", + OriginalScoreFeature.class.getCanonicalName(), null), + "testFstore"); + final Feature feature = store.getFeatureStore("testFstore").get("test"); + assertNotNull(feature); + assertEquals("test", feature.getName()); + assertEquals(OriginalScoreFeature.class.getCanonicalName(), feature + .getClass().getCanonicalName()); + } + /** Adding a feature whose class name does not exist must throw, and the root cause must render the same as a ClassNotFoundException for that name. */ + @Test + public void getInvalidInstanceTest() + { + final String nonExistingClassName = "org.apache.solr.ltr.feature.LOLFeature"; + final ClassNotFoundException expectedException = + new ClassNotFoundException(nonExistingClassName); + try { + store.addFeature(TestManagedFeatureStore.createMap("test", + nonExistingClassName, null), + "testFstore2"); + fail("getInvalidInstanceTest failed to throw exception: "+expectedException); + } catch (Exception actualException) { + Throwable rootError = getRootCause(actualException); + assertEquals(expectedException.toString(), rootError.toString()); + } + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java new file mode 100644 index 00000000000..0ed0cdac3a2 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFeatureStore.java @@ 
-0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.store.FeatureStore; +import org.apache.solr.ltr.store.rest.ManagedFeatureStore; +import org.apache.solr.ltr.store.rest.TestManagedFeatureStore; +import org.junit.BeforeClass; +import org.junit.Test; +/** Tests adding features to, and reading them back from, named stores of the managed feature store. */ +public class TestFeatureStore extends TestRerankBase { + + static ManagedFeatureStore fstore = null; + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + fstore = getManagedFeatureStore(); + } + /** The default store is named "_DEFAULT_" and is what getFeatureStore(null) returns. */ + @Test + public void testDefaultFeatureStoreName() + { + assertEquals("_DEFAULT_", FeatureStore.DEFAULT_FEATURE_STORE_NAME); + final FeatureStore expectedFeatureStore = fstore.getFeatureStore(FeatureStore.DEFAULT_FEATURE_STORE_NAME); + final FeatureStore actualFeatureStore = fstore.getFeatureStore(null); + assertEquals("getFeatureStore(null) should return the default feature store", expectedFeatureStore, actualFeatureStore); + } + /** Adds five features one at a time; each must be retrievable right after it is added and the store must end up holding five. */ + @Test + public void testFeatureStoreAdd() throws FeatureException + { + final FeatureStore fs = fstore.getFeatureStore("fstore-testFeature"); + 
for (int i = 0; i < 5; i++) { + final String name = "c" + i; + + fstore.addFeature(TestManagedFeatureStore.createMap(name, + OriginalScoreFeature.class.getCanonicalName(), null), + "fstore-testFeature"); + + final Feature f = fs.get(name); + assertNotNull(f); + + } + assertEquals(5, fs.getFeatures().size()); + + } + /** Adds five ValueFeatures with value i and checks that each comes back with its name, type and value. */ + @Test + public void testFeatureStoreGet() throws FeatureException + { + final FeatureStore fs = fstore.getFeatureStore("fstore-testFeature2"); + for (int i = 0; i < 5; i++) { + Map<String,Object> params = new HashMap<>(); + params.put("value", i); + final String name = "c" + i; + + fstore.addFeature(TestManagedFeatureStore.createMap(name, + ValueFeature.class.getCanonicalName(), params), + "fstore-testFeature2"); + + } + + for (int i = 0; i < 5; i++) { + final Feature f = fs.get("c" + i); + assertEquals("c" + i, f.getName()); + assertTrue(f instanceof ValueFeature); + final ValueFeature vf = (ValueFeature)f; + assertEquals(i, vf.getValue()); + } + } + /** Looking up a name that was never added to a populated store returns null. */ + @Test + public void testMissingFeatureReturnsNull() { + final FeatureStore fs = fstore.getFeatureStore("fstore-testFeature3"); + for (int i = 0; i < 5; i++) { + Map<String,Object> params = new HashMap<>(); + params.put("value", i); + final String name = "testc" + (float) i; + fstore.addFeature(TestManagedFeatureStore.createMap(name, + ValueFeature.class.getCanonicalName(), params), + "fstore-testFeature3"); + + } + assertNull(fs.get("missing_feature_name")); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldLengthFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldLengthFeature.java new file mode 100644 index 00000000000..4a0d4490a73 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldLengthFeature.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +/** Tests FieldLengthFeature: zero score for a missing or empty field, and reranking by title and description length. */ +public class TestFieldLengthFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid")); + assertU(adoc("id", "3", "title", "w3", "description", "w3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + /** A document indexed without a description must get score 0.0 from a model whose only feature is the description length. */ + @Test + public void testIfFieldIsMissingInDocumentLengthIsZero() throws Exception { + // add a document without the field 'description' + assertU(adoc("id", "42", "title", "w10")); + assertU(commit()); + + loadFeature("description-length2", 
FieldLengthFeature.class.getCanonicalName(), + "{\"field\":\"description\"}"); + + loadModel("description-model2", LinearModel.class.getCanonicalName(), + new String[] {"description-length2"}, "{\"weights\":{\"description-length2\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w10"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=description-model2 reRankDocs=8}"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + } + + /** A document indexed with an empty description must get score 0.0 from a model whose only feature is the description length. */ + @Test + public void testIfFieldIsEmptyLengthIsZero() throws Exception { + // add a document whose 'description' field is present but empty + assertU(adoc("id", "43", "title", "w11", "description", "")); + assertU(commit()); + + loadFeature("description-length3", FieldLengthFeature.class.getCanonicalName(), + "{\"field\":\"description\"}"); + + loadModel("description-model3", LinearModel.class.getCanonicalName(), + new String[] {"description-length3"}, "{\"weights\":{\"description-length3\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w11"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=description-model3 reRankDocs=8}"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + } + + /** Checks the hit order without reranking, then with title-model (for title:w1 and for *:*), then with description-model; the expected reranked orders put documents with more terms in the field first. */ + @Test + public void testRanking() throws Exception { + loadFeature("title-length", FieldLengthFeature.class.getCanonicalName(), + "{\"field\":\"title\"}"); + + loadModel("title-model", LinearModel.class.getCanonicalName(), + new String[] {"title-length"}, "{\"weights\":{\"title-length\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + // Normal term match + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + 
assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + // Rerank the same four hits with the title-length model + + query.add("rq", "{!ltr model=title-model reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + + query.setQuery("*:*"); + query.remove("rows"); + query.add("rows", "8"); + query.remove("rq"); + query.add("rq", "{!ltr model=title-model reRankDocs=8}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='6'"); + + loadFeature("description-length", + FieldLengthFeature.class.getCanonicalName(), + "{\"field\":\"description\"}"); + loadModel("description-model", LinearModel.class.getCanonicalName(), + new String[] {"description-length"}, + "{\"weights\":{\"description-length\":1.0}}"); + query.setQuery("title:w1"); + query.remove("rq"); + query.remove("rows"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=description-model reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + } + + + + + +} diff --git 
a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java new file mode 100644 index 00000000000..af150c060e4 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +/** Tests FieldValueFeature: reranking by the popularity field, and the value logged when a document lacks the field or the configured field name is not used by any document. */ +public class TestFieldValueFeature extends TestRerankBase { + + private static final float FIELD_VALUE_FEATURE_DEFAULT_VAL = 0.0f; + /** Indexes documents 1..8 with popularity 1..8 plus document 42 without popularity, then loads the "popularity" feature and "popularity-model". */ + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + + // a document without the popularity field + assertU(adoc("id", "42", "title", "NO popularity", "description", "NO popularity")); + + assertU(commit()); + + loadFeature("popularity", FieldValueFeature.class.getCanonicalName(), + "{\"field\":\"popularity\"}"); + + loadModel("popularity-model", LinearModel.class.getCanonicalName(), + new String[] {"popularity"}, "{\"weights\":{\"popularity\":1.0}}"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + /** Checks the hit order for title:w1 without reranking, then that reranking with popularity-model orders hits by descending popularity (for title:w1 and for *:*). */ + @Test + public void testRanking() throws Exception { + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + // Normal term match 
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + + query.add("rq", "{!ltr model=popularity-model reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + + query.setQuery("*:*"); + query.remove("rows"); + query.add("rows", "8"); + query.remove("rq"); + query.add("rq", "{!ltr model=popularity-model reRankDocs=8}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='5'"); + } + + /** For document 42, which has no popularity value, the logged popularity feature must equal FIELD_VALUE_FEATURE_DEFAULT_VAL. */ + @Test + public void testIfADocumentDoesntHaveAFieldDefaultValueIsReturned() throws Exception { + SolrQuery query = new SolrQuery(); + query.setQuery("id:42"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='42'"); + query = new SolrQuery(); + query.setQuery("id:42"); + query.add("rq", "{!ltr model=popularity-model reRankDocs=4}"); + query.add("fl", "[fv]"); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); + assertJQ("/query" + query.toQueryString(), + 
"/response/docs/[0]/=={'[fv]':'popularity:"+FIELD_VALUE_FEATURE_DEFAULT_VAL+"'}"); + + } + + /** With "defaultValue":"42.0" configured on the feature, document 42 (no popularity value) must log popularity42:42.0. */ + @Test + public void testIfADocumentDoesntHaveAFieldASetDefaultValueIsReturned() throws Exception { + + final String fstore = "testIfADocumentDoesntHaveAFieldASetDefaultValueIsReturned"; + + loadFeature("popularity42", FieldValueFeature.class.getCanonicalName(), fstore, + "{\"field\":\"popularity\",\"defaultValue\":\"42.0\"}"); + + SolrQuery query = new SolrQuery(); + query.setQuery("id:42"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + loadModel("popularity-model42", LinearModel.class.getCanonicalName(), + new String[] {"popularity42"}, fstore, "{\"weights\":{\"popularity42\":1.0}}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='42'"); + query = new SolrQuery(); + query.setQuery("id:42"); + query.add("rq", "{!ltr model=popularity-model42 reRankDocs=4}"); + query.add("fl", "[fv]"); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'popularity42:42.0'}"); + + } + /** A feature configured on the field name "cowabunga", which no indexed document here uses, must log FIELD_VALUE_FEATURE_DEFAULT_VAL. */ + @Test + public void testThatIfaFieldDoesNotExistDefaultValueIsReturned() throws Exception { + // using a different fstore to avoid a clash with the other tests + final String fstore = "testThatIfaFieldDoesNotExistDefaultValueIsReturned"; + loadFeature("not-existing-field", FieldValueFeature.class.getCanonicalName(), fstore, + "{\"field\":\"cowabunga\"}"); + + loadModel("not-existing-field-model", LinearModel.class.getCanonicalName(), + new String[] {"not-existing-field"}, fstore, "{\"weights\":{\"not-existing-field\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("id:42"); + query.add("rq", "{!ltr model=not-existing-field-model reRankDocs=4}"); + query.add("fl", "[fv]"); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); + assertJQ("/query" + 
query.toQueryString(), + "/response/docs/[0]/=={'[fv]':'not-existing-field:"+FIELD_VALUE_FEATURE_DEFAULT_VAL+"'}"); + + } + + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFilterSolrFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFilterSolrFeature.java new file mode 100644 index 00000000000..14baefaaf31 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestFilterSolrFeature.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.apache.solr.ltr.store.rest.ManagedFeatureStore; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +/** Tests SolrFeature definitions that use "fq" filter queries: scoring, rejection of a bad definition, and a model loaded from the fq JSON fixtures. */ +public class TestFilterSolrFeature extends TestRerankBase { + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w1", "description", "w1", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w6 w2 w3 w4 w5 w8", "popularity", "88888")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "88888")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + /** Two fq-based features (weights 1.6 and 2.0; the second takes its term from efi.user_query) feed a linear model; the top document is expected to score 3.6 (both weights) and the next 1.6. */ + @Test + public void testUserTermScoreWithFQ() throws Exception { + loadFeature("SomeTermFQ", SolrFeature.class.getCanonicalName(), + "{\"fq\":[\"{!terms f=popularity}88888\"]}"); + loadFeature("SomeEfiFQ", SolrFeature.class.getCanonicalName(), + "{\"fq\":[\"{!terms f=title}${user_query}\"]}"); + loadModel("Term-modelFQ", LinearModel.class.getCanonicalName(), + new String[] {"SomeTermFQ", "SomeEfiFQ"}, + "{\"weights\":{\"SomeTermFQ\":1.6, \"SomeEfiFQ\":2.0}}"); + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score"); + query.add("rows", "3"); + 
query.add("fq", "{!terms f=title}w1"); + query.add("rq", + "{!ltr model=Term-modelFQ reRankDocs=5 efi.user_query='w5'}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==5"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==3.6"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==1.6"); + } + /** PUTs a SolrFeature definition without q or fq and expects status 500. NOTE(review): the params literal ends in "]}" and so is not well-formed JSON; the 500 may come from the parse failure rather than from the missing q/fq check - confirm. */ + @Test + public void testBadFeature() throws Exception { + // Missing q/fq + final String feature = getFeatureInJson("badFeature", "test", + SolrFeature.class.getCanonicalName(), "{\"df\":\"foo\"]}"); + assertJPut(ManagedFeatureStore.REST_END_POINT, feature, + "/responseHeader/status==500"); + } + /** Loads fq_features.json and fq-model.json, reranks the top four hits with "fqmodel" and checks the resulting order and the feature vector of the first document. NOTE(review): nothing in this body compares features or normalizers, so the method name looks unrelated to what is checked - confirm. */ + @Test + public void testFeatureNotEqualWhenNormalizerDifferent() throws Exception { + loadFeatures("fq_features.json"); // features that use filter query + loadModels("fq-model.json"); // model that uses filter query features + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*,score"); + query.add("rows", "4"); + + query.add("rq", "{!ltr reRankDocs=4 model=fqmodel efi.user_query=w2}"); + query.add("fl", "fv:[fv]"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv=='matchedTitle:1.0;popularity:3.0'"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestNoMatchSolrFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestNoMatchSolrFeature.java new file mode 100644 index 00000000000..57126877ac6 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestNoMatchSolrFeature.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import java.util.ArrayList; +import java.util.Map; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.apache.solr.ltr.model.MultipleAdditiveTreesModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.noggit.ObjectBuilder; + +public class TestNoMatchSolrFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 
w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + + loadFeature("nomatchfeature", SolrFeature.class.getCanonicalName(), + "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadFeature("yesmatchfeature", SolrFeature.class.getCanonicalName(), + "{\"q\":\"w1\",\"df\":\"title\"}"); + loadFeature("nomatchfeature2", SolrFeature.class.getCanonicalName(), + "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadModel( + "nomatchmodel", + LinearModel.class.getCanonicalName(), + new String[] {"nomatchfeature", "yesmatchfeature", "nomatchfeature2"}, + "{\"weights\":{\"nomatchfeature\":1.0,\"yesmatchfeature\":1.1,\"nomatchfeature2\":1.1}}"); + + loadFeature("nomatchfeature3", SolrFeature.class.getCanonicalName(), + "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadModel("nomatchmodel2", LinearModel.class.getCanonicalName(), + new String[] {"nomatchfeature3"}, + "{\"weights\":{\"nomatchfeature3\":1.0}}"); + + loadFeature("nomatchfeature4", SolrFeature.class.getCanonicalName(), + "noMatchFeaturesStore", "{\"q\":\"foobarbat12345\",\"df\":\"title\"}"); + loadModel("nomatchmodel3", LinearModel.class.getCanonicalName(), + new String[] {"nomatchfeature4"}, "noMatchFeaturesStore", + "{\"weights\":{\"nomatchfeature4\":1.0}}"); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void test2NoMatch1YesMatchFeatureReturnsFvWith1FeatureAndDocScoreScaledByModel() throws Exception { + // Tests model with all no matching features but 1 + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=nomatchmodel reRankDocs=4}"); + + final SolrQuery yesMatchFeatureQuery = new SolrQuery(); + yesMatchFeatureQuery.setQuery("title:w1"); + yesMatchFeatureQuery.add("fl", "score"); + yesMatchFeatureQuery.add("rows", "4"); + String res = restTestHarness.query("/query" + + 
yesMatchFeatureQuery.toQueryString()); + + final Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + final Double doc0Score = (Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score"); + + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==" + + (doc0Score * 1.1)); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/fv=='yesmatchfeature:" + doc0Score + "'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='4'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/fv==''"); + } + + @Test + public void test1NoMatchFeatureReturnsFvWith1MatchingFeatureFromStoreAndDocWith0Score() throws Exception { + // Tests model with all no matching features, but 1 feature store feature matching for extraction + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=nomatchmodel2 reRankDocs=4}"); + + final SolrQuery yesMatchFeatureQuery = new SolrQuery(); + yesMatchFeatureQuery.setQuery("title:w1"); + yesMatchFeatureQuery.add("fl", "score"); + yesMatchFeatureQuery.add("rows", "4"); + String res = restTestHarness.query("/query" + + yesMatchFeatureQuery.toQueryString()); + + final Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + final Double doc0Score = 
(Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/fv=='yesmatchfeature:" + doc0Score + "'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/fv==''"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/fv==''"); + } + + @Test + public void tesOnlyNoMatchFeaturesInStoreAndModelReturnsZeroScore() throws Exception { + // Tests model with all no matching features + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=nomatchmodel3 reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''"); + } + + @Test + public void tesOnlyNoMatchFeaturesInStoreAndModelReturnsNonzeroScore() throws Exception { + // Tests model with all no matching features but expects a non 0 score + // MultipleAdditiveTrees will return scores even for docs without any feature matches + loadModel( + "nomatchmodel4", + MultipleAdditiveTreesModel.class.getCanonicalName(), + new String[] {"nomatchfeature4"}, + "noMatchFeaturesStore", + "{\"trees\":[{\"weight\":\"1f\", \"root\":{\"feature\": \"matchedTitle\",\"threshold\": \"0.5f\",\"left\":{\"value\" : \"-10\"},\"right\":{\"value\" : \"9\"}}}]}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,fv:[fv]"); + query.add("rows", "4"); 
+ query.add("rq", "{!ltr model=nomatchmodel4 reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/fv==''"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java new file mode 100644 index 00000000000..e525891f3d7 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import java.util.ArrayList; +import java.util.Map; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.noggit.ObjectBuilder; + +public class TestOriginalScoreFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1")); + assertU(adoc("id", "2", "title", "w2")); + assertU(adoc("id", "3", "title", "w3")); + assertU(adoc("id", "4", "title", "w4")); + assertU(adoc("id", "5", "title", "w5")); + assertU(adoc("id", "6", "title", "w1 w2")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testOriginalScore() throws Exception { + loadFeature("score", OriginalScoreFeature.class.getCanonicalName(), "{}"); + + loadModel("originalScore", LinearModel.class.getCanonicalName(), + new String[] {"score"}, "{\"weights\":{\"score\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + + // Normal term match + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + final Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + final 
String doc0Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score")).toString(); + final String doc1Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(1)).get("score")).toString(); + final String doc2Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(2)).get("score")).toString(); + final String doc3Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(3)).get("score")).toString(); + + query.add("fl", "[fv]"); + query.add("rq", "{!ltr model=originalScore reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==" + + doc0Score); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==" + + doc1Score); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==" + + doc2Score); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==" + + doc3Score); + } + + @Test + public void testOriginalScoreWithNonScoringFeatures() throws Exception { + loadFeature("origScore", OriginalScoreFeature.class.getCanonicalName(), + "store2", "{}"); + loadFeature("c2", ValueFeature.class.getCanonicalName(), "store2", + "{\"value\":2.0}"); + + loadModel("origScore", LinearModel.class.getCanonicalName(), + new String[] {"origScore"}, "store2", + "{\"weights\":{\"origScore\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score, fv:[fv]"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr 
model=origScore reRankDocs=4}"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + final Map jsonParse = (Map) ObjectBuilder + .fromJSON(res); + final String doc0Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(0)).get("score")).toString(); + final String doc1Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(1)).get("score")).toString(); + final String doc2Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(2)).get("score")).toString(); + final String doc3Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse + .get("response")).get("docs")).get(3)).get("score")).toString(); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[0]/fv=='origScore:" + doc0Score + ";c2:2.0'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + + assertJQ("/query" + query.toQueryString(), + "/response/docs/[1]/fv=='origScore:" + doc1Score + ";c2:2.0'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[2]/fv=='origScore:" + doc2Score + ";c2:2.0'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + assertJQ("/query" + query.toQueryString(), + "/response/docs/[3]/fv=='origScore:" + doc3Score + ";c2:2.0'"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestRankingFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestRankingFeature.java new file mode 100644 index 00000000000..437e10d2558 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestRankingFeature.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + + +public class TestRankingFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + } + + @AfterClass + public 
static void after() throws Exception { + aftertest(); + } + + @Test + public void testRankingSolrFeature() throws Exception { + // before(); + loadFeature("powpularityS", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!func}pow(popularity,2)\"}"); + loadFeature("unpopularityS", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!func}div(1,popularity)\"}"); + + loadModel("powpularityS-model", LinearModel.class.getCanonicalName(), + new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}"); + loadModel("unpopularityS-model", LinearModel.class.getCanonicalName(), + new String[] {"unpopularityS"}, "{\"weights\":{\"unpopularityS\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); + // Normal term match + + query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}"); + query.set("debugQuery", "on"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==64.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==49.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==36.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0"); + + 
query.remove("rq"); + query.add("rq", "{!ltr model=unpopularityS-model reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==1.0"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='6'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='8'"); + + //bad solr ranking feature + loadFeature("powdesS", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!func}pow(description,2)\"}"); + loadModel("powdesS-model", LinearModel.class.getCanonicalName(), + new String[] {"powdesS"}, "{\"weights\":{\"powdesS\":1.0}}"); + + query.remove("rq"); + query.add("rq", "{!ltr model=powdesS-model reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), + "/error/msg/=='"+FeatureException.class.getCanonicalName()+": " + + "java.lang.UnsupportedOperationException: " + + "Unable to extract feature for powdesS'"); + // aftertest(); + + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScoreWithQ.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScoreWithQ.java new file mode 100644 index 00000000000..754409a658c --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScoreWithQ.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestUserTermScoreWithQ extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testUserTermScoreWithQ() throws Exception { + // before(); + loadFeature("SomeTermQ", SolrFeature.class.getCanonicalName(), + "{\"q\":\"{!terms f=popularity}88888\"}"); + loadModel("Term-modelQ", 
LinearModel.class.getCanonicalName(), + new String[] {"SomeTermQ"}, "{\"weights\":{\"SomeTermQ\":1.0}}"); + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("rq", "{!ltr model=Term-modelQ reRankDocs=4}"); + query.set("debugQuery", "on"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==0.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorerQuery.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorerQuery.java new file mode 100644 index 00000000000..c79207c644e --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorerQuery.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestUserTermScorerQuery extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testUserTermScorerQuery() throws Exception { + // before(); + loadFeature("matchedTitleDFExt", SolrFeature.class.getCanonicalName(), + "{\"q\":\"${user_query}\",\"df\":\"title\"}"); + loadModel("Term-matchedTitleDFExt", LinearModel.class.getCanonicalName(), + new String[] {"matchedTitleDFExt"}, + "{\"weights\":{\"matchedTitleDFExt\":1.1}}"); + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("rq", + "{!ltr model=Term-matchedTitleDFExt reRankDocs=4 efi.user_query=w8}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + 
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'"); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorereQDF.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorereQDF.java new file mode 100644 index 00000000000..f47a883c37a --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestUserTermScorereQDF.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestUserTermScorereQDF extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity", + "4")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w1 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w1 w2 w3 w4 w5 w8", "popularity", "7")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2", "popularity", "8")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testUserTermScorerQWithDF() throws Exception { + // before(); + loadFeature("matchedTitleDF", SolrFeature.class.getCanonicalName(), + "{\"q\":\"w5\",\"df\":\"title\"}"); + loadModel("Term-matchedTitleDF", LinearModel.class.getCanonicalName(), + new String[] {"matchedTitleDF"}, + "{\"weights\":{\"matchedTitleDF\":1.0}}"); + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "2"); + query.add("rq", "{!ltr model=Term-matchedTitleDF reRankDocs=4}"); + query.set("debugQuery", "on"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + 
assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='7'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==0.0"); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestValueFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestValueFeature.java new file mode 100644 index 00000000000..084da4a3695 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestValueFeature.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.feature; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestValueFeature extends TestRerankBase { + + @BeforeClass + public static void before() throws Exception { + setuptest("solrconfig-ltr.xml", "schema.xml"); + + assertU(adoc("id", "1", "title", "w1")); + assertU(adoc("id", "2", "title", "w2")); + assertU(adoc("id", "3", "title", "w3")); + assertU(adoc("id", "4", "title", "w4")); + assertU(adoc("id", "5", "title", "w5")); + assertU(adoc("id", "6", "title", "w1 w2")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2")); + assertU(commit()); + } + + @AfterClass + public static void after() throws Exception { + aftertest(); + } + + @Test + public void testValueFeatureWithEmptyValue() throws Exception { + final RuntimeException expectedException = + new RuntimeException("mismatch: '0'!='500' @ responseHeader/status"); + try { + loadFeature("c2", ValueFeature.class.getCanonicalName(), "{\"value\":\"\"}"); + fail("testValueFeatureWithEmptyValue failed to throw exception: "+expectedException); + } catch (RuntimeException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void testValueFeatureWithWhitespaceValue() throws Exception { + final RuntimeException expectedException = + new RuntimeException("mismatch: '0'!='500' @ responseHeader/status"); + try { + loadFeature("c2", ValueFeature.class.getCanonicalName(), + "{\"value\":\" \"}"); + fail("testValueFeatureWithWhitespaceValue failed to throw exception: "+expectedException); + } catch (RuntimeException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void 
testRerankingWithConstantValueFeatureReplacesDocScore() throws Exception { + loadFeature("c3", ValueFeature.class.getCanonicalName(), "c3", + "{\"value\":2}"); + loadModel("m3", LinearModel.class.getCanonicalName(), new String[] {"c3"}, + "c3", "{\"weights\":{\"c3\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m3 reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==2.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==2.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==2.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==2.0"); + } + + @Test + public void testRerankingWithEfiValueFeatureReplacesDocScore() throws Exception { + loadFeature("c6", ValueFeature.class.getCanonicalName(), "c6", + "{\"value\":\"${val6}\"}"); + loadModel("m6", LinearModel.class.getCanonicalName(), new String[] {"c6"}, + "c6", "{\"weights\":{\"c6\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m6 reRankDocs=4 efi.val6='2'}"); + + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==2.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==2.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==2.0"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==2.0"); + } + + + @Test + public void testValueFeatureImplicitlyNotRequiredShouldReturnOkStatusCode() throws Exception { + loadFeature("c5", ValueFeature.class.getCanonicalName(), "c5", + "{\"value\":\"${val6}\"}"); + loadModel("m5", LinearModel.class.getCanonicalName(), new String[] {"c5"}, + "c5", "{\"weights\":{\"c5\":1.0}}"); + + final SolrQuery 
query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score,fvonly:[fvonly]"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m5 reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/responseHeader/status==0"); + } + + @Test + public void testValueFeatureExplictlyNotRequiredShouldReturnOkStatusCode() throws Exception { + loadFeature("c7", ValueFeature.class.getCanonicalName(), "c7", + "{\"value\":\"${val7}\",\"required\":false}"); + loadModel("m7", LinearModel.class.getCanonicalName(), new String[] {"c7"}, + "c7", "{\"weights\":{\"c7\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score,fvonly:[fvonly]"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m7 reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/responseHeader/status==0"); + } + + @Test + public void testValueFeatureRequiredShouldReturn400StatusCode() throws Exception { + loadFeature("c8", ValueFeature.class.getCanonicalName(), "c8", + "{\"value\":\"${val8}\",\"required\":true}"); + loadModel("m8", LinearModel.class.getCanonicalName(), new String[] {"c8"}, + "c8", "{\"weights\":{\"c8\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:w1"); + query.add("fl", "*, score,fvonly:[fvonly]"); + query.add("rows", "4"); + query.add("wt", "json"); + query.add("rq", "{!ltr model=m8 reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/responseHeader/status==400"); + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java new file mode 100644 index 00000000000..e8ee22482cb --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.model; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.solr.common.SolrException; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.norm.IdentityNormalizer; +import org.apache.solr.ltr.norm.Normalizer; +import org.apache.solr.ltr.store.FeatureStore; +import org.apache.solr.ltr.store.rest.ManagedModelStore; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TestLinearModel extends TestRerankBase { + + public static LTRScoringModel createLinearModel(String name, List features, + List norms, + String featureStoreName, List allFeatures, + Map params) throws ModelException { + final LTRScoringModel model = LTRScoringModel.getInstance(solrResourceLoader, + LinearModel.class.getCanonicalName(), + name, + features, norms, featureStoreName, allFeatures, params); + return model; + } + + static ManagedModelStore store = null; + static FeatureStore fstore = null; + + @BeforeClass + public static void setup() throws Exception { + setuptest(); + // loadFeatures("features-store-test-model.json"); + store = getManagedModelStore(); + fstore = 
getManagedFeatureStore().getFeatureStore("test"); + + } + + @Test + public void getInstanceTest() { + final Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5", 1d); + + Map params = new HashMap(); + final List features = getFeatures(new String[] { + "constant1", "constant5"}); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + params.put("weights", weights); + final LTRScoringModel ltrScoringModel = createLinearModel("test1", + features, norms, "test", fstore.getFeatures(), + params); + + store.addModel(ltrScoringModel); + final LTRScoringModel m = store.getModel("test1"); + assertEquals(ltrScoringModel, m); + } + + @Test + public void nullFeatureWeightsTest() { + final ModelException expectedException = + new ModelException("Model test2 doesn't contain any weights"); + try { + final List features = getFeatures(new String[] + {"constant1", "constant5"}); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + createLinearModel("test2", + features, norms, "test", fstore.getFeatures(), null); + fail("unexpectedly got here instead of catching "+expectedException); + } catch (ModelException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void existingNameTest() { + final SolrException expectedException = + new SolrException(SolrException.ErrorCode.BAD_REQUEST, + ModelException.class.getCanonicalName()+": model 'test3' already exists. 
Please use a different name"); + try { + final List features = getFeatures(new String[] + {"constant1", "constant5"}); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + final Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5", 1d); + + Map params = new HashMap(); + params.put("weights", weights); + final LTRScoringModel ltrScoringModel = createLinearModel("test3", + features, norms, "test", fstore.getFeatures(), + params); + store.addModel(ltrScoringModel); + final LTRScoringModel m = store.getModel("test3"); + assertEquals(ltrScoringModel, m); + store.addModel(ltrScoringModel); + fail("unexpectedly got here instead of catching "+expectedException); + } catch (SolrException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void duplicateFeatureTest() { + final ModelException expectedException = + new ModelException("duplicated feature constant1 in model test4"); + try { + final List features = getFeatures(new String[] + {"constant1", "constant1"}); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + final Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5", 1d); + + Map params = new HashMap(); + params.put("weights", weights); + final LTRScoringModel ltrScoringModel = createLinearModel("test4", + features, norms, "test", fstore.getFeatures(), + params); + store.addModel(ltrScoringModel); + fail("unexpectedly got here instead of catching "+expectedException); + } catch (ModelException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + + } + + @Test + public void missingFeatureWeightTest() { + final ModelException expectedException = + new ModelException("Model test5 lacks weight(s) for [constant5]"); + try { + final List features = getFeatures(new String[] + 
{"constant1", "constant5"}); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + final Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5missing", 1d); + + Map params = new HashMap(); + params.put("weights", weights); + createLinearModel("test5", + features, norms, "test", fstore.getFeatures(), + params); + fail("unexpectedly got here instead of catching "+expectedException); + } catch (ModelException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void emptyFeaturesTest() { + final ModelException expectedException = + new ModelException("no features declared for model test6"); + try { + final List features = getFeatures(new String[] {}); + final List norms = + new ArrayList( + Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE)); + final Map weights = new HashMap<>(); + weights.put("constant1", 1d); + weights.put("constant5missing", 1d); + + Map params = new HashMap(); + params.put("weights", weights); + final LTRScoringModel ltrScoringModel = createLinearModel("test6", + features, norms, "test", fstore.getFeatures(), + params); + store.addModel(ltrScoringModel); + fail("unexpectedly got here instead of catching "+expectedException); + } catch (ModelException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java new file mode 100644 index 00000000000..3748331a43e --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.model;
+
+// Disabled together with the assertions in multipleAdditiveTreesTestExplain (see the FIXME there):
+//import static org.junit.internal.matchers.StringContains.containsString;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.TestRerankBase;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.Test;
+
+/**
+ * Tests reranking with the "multipleadditivetreesmodel" model and the
+ * {@link ModelException}s expected when a model definition file is missing
+ * trees, weights, features, thresholds or child nodes.
+ */
+public class TestMultipleAdditiveTreesModel extends TestRerankBase {
+
+
+  /** Indexes five documents (titles w1..w5) and loads the shared features and model files. */
+  @BeforeClass
+  public static void before() throws Exception {
+    setuptest("solrconfig-ltr.xml", "schema.xml");
+
+    assertU(adoc("id", "1", "title", "w1", "description", "w1", "popularity","1"));
+    assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity","2"));
+    assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity","3"));
+    assertU(adoc("id", "4", "title", "w4", "description", "w4", "popularity","4"));
+    assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity","5"));
+    assertU(commit());
+
+    loadFeatures("multipleadditivetreesmodel_features.json"); // currently needed to force
+    // scoring on all docs
+    loadModels("multipleadditivetreesmodel.json");
+  }
+
+  @AfterClass
+  public static void after() throws Exception {
+    aftertest();
+  }
+
+
+  /**
+   * Compares the scores of a plain match-all query with the reranked scores
+   * for an efi user_query that matches nothing and for one that matches
+   * document 3.
+   */
+  @Test
+  public void testMultipleAdditiveTreesScoringWithAndWithoutEfiFeatureMatches() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("rows", "3");
+    query.add("fl", "*,score");
+
+    // Regular scores
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==1.0");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==1.0");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==1.0");
+
+    // efi.user_query is supplied, but with a value that occurs in none of the
+    // indexed documents, so (presumably) the feature depending on it matches
+    // nothing and all three documents receive the same score.
+    query.add("rq", "{!ltr reRankDocs=3 model=multipleadditivetreesmodel efi.user_query=dsjkafljjk}");
+
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==-120.0");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==-120.0");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==-120.0");
+
+    // user_query now has a value (w3) found in document 3, which is expected
+    // to be ranked first with a higher score than the other two
+    query.remove("rq");
+    query.add("rq", "{!ltr reRankDocs=3 model=multipleadditivetreesmodel efi.user_query=w3}");
+
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='3'");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==-20.0");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==-120.0");
+    assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==-120.0");
+  }
+
+  /**
+   * Ignored: runs a reranked query with debugQuery=on, but every assertion on
+   * the explain output is commented out (see FIXME below), so nothing is
+   * verified at present.
+   */
+  @Ignore
+  @Test
+  public void multipleAdditiveTreesTestExplain() throws Exception {
+    final SolrQuery query = new SolrQuery();
+    query.setQuery("*:*");
+    query.add("fl", "*,score,[fv]");
+    query.add("rows", "3");
+
+    query.add("rq",
+        "{!ltr reRankDocs=3 model=multipleadditivetreesmodel efi.user_query=w3}");
+
+    // test out the explain feature, make sure it returns something
+    query.setParam("debugQuery", "on");
+    String qryResult = JQ("/query" + query.toQueryString());
+
+    qryResult = qryResult.replaceAll("\n", " ");
+    // FIXME containsString doesn't exist.
+    // assertThat(qryResult, containsString("\"debug\":{"));
+    // qryResult = qryResult.substring(qryResult.indexOf("debug"));
+    //
+    // assertThat(qryResult, containsString("\"explain\":{"));
+    // qryResult = qryResult.substring(qryResult.indexOf("explain"));
+    //
+    // assertThat(qryResult, containsString("multipleadditivetreesmodel"));
+    // assertThat(qryResult,
+    // containsString(MultipleAdditiveTreesModel.class.getCanonicalName()));
+    //
+    // assertThat(qryResult, containsString("-100.0 = tree 0"));
+    // assertThat(qryResult, containsString("50.0 = tree 0"));
+    // assertThat(qryResult, containsString("-20.0 = tree 1"));
+    // assertThat(qryResult, containsString("'matchedTitle':1.0 > 0.5"));
+    // assertThat(qryResult, containsString("'matchedTitle':0.0 <= 0.5"));
+    //
+    // assertThat(qryResult, containsString(" Go Right "));
+    // assertThat(qryResult, containsString(" Go Left "));
+    // assertThat(qryResult,
+    // containsString("'this_feature_doesnt_exist' does not exist in FV"));
+  }
+
+  /** Model file multipleadditivetreesmodel_no_params.json: expects a "no trees declared" root cause. */
+  @Test
+  public void multipleAdditiveTreesTestNoParams() throws Exception {
+    final ModelException expectedException =
+        new ModelException("no trees declared for model multipleadditivetreesmodel_no_params");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_params.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestNoParams failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      // compare against the innermost cause of whatever was thrown
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+
+  }
+
+  /** Model file multipleadditivetreesmodel_no_trees.json: expects a "no trees declared" root cause. */
+  @Test
+  public void multipleAdditiveTreesTestEmptyParams() throws Exception {
+    final ModelException expectedException =
+        new ModelException("no trees declared for model multipleadditivetreesmodel_no_trees");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_trees.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestEmptyParams failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+  }
+
+  /** Model file multipleadditivetreesmodel_no_weight.json: expects a missing-weight root cause. */
+  @Test
+  public void multipleAdditiveTreesTestNoWeight() throws Exception {
+    final ModelException expectedException =
+        new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_weight.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestNoWeight failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+  }
+
+  /** Model file multipleadditivetreesmodel_no_tree.json: expects a missing-tree root cause. */
+  @Test
+  public void multipleAdditiveTreesTestTreesParamDoesNotContatinTree() throws Exception {
+    final ModelException expectedException =
+        new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_tree.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestTreesParamDoesNotContatinTree failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+  }
+
+  /**
+   * Model file multipleadditivetreesmodel_no_features.json: here the
+   * ModelException is expected to be thrown directly, not as a wrapped cause.
+   */
+  @Test
+  public void multipleAdditiveTreesTestNoFeaturesSpecified() throws Exception {
+    final ModelException expectedException =
+        new ModelException("no features declared for model multipleadditivetreesmodel_no_features");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_features.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestNoFeaturesSpecified failed to throw exception: "+expectedException);
+    } catch (ModelException actualException) {
+      assertEquals(expectedException.toString(), actualException.toString());
+    }
+  }
+
+  /** Model file multipleadditivetreesmodel_no_right.json: expects a "missing right" root cause. */
+  @Test
+  public void multipleAdditiveTreesTestNoRight() throws Exception {
+    final ModelException expectedException =
+        new ModelException("MultipleAdditiveTreesModel tree node is missing right");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_right.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestNoRight failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+  }
+
+  /** Model file multipleadditivetreesmodel_no_left.json: expects a "missing left" root cause. */
+  @Test
+  public void multipleAdditiveTreesTestNoLeft() throws Exception {
+    final ModelException expectedException =
+        new ModelException("MultipleAdditiveTreesModel tree node is missing left");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_left.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestNoLeft failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+  }
+
+  /** Model file multipleadditivetreesmodel_no_threshold.json: expects a "missing threshold" root cause. */
+  @Test
+  public void multipleAdditiveTreesTestNoThreshold() throws Exception {
+    final ModelException expectedException =
+        new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_threshold.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestNoThreshold failed to throw exception: "+expectedException);
+    } catch (Exception actualException) {
+      Throwable rootError = getRootCause(actualException);
+      assertEquals(expectedException.toString(), rootError.toString());
+    }
+  }
+
+  /**
+   * Model file multipleadditivetreesmodel_no_feature.json: a node without a
+   * feature is reported as a leaf that nevertheless has children; the
+   * ModelException is expected to be thrown directly.
+   */
+  @Test
+  public void multipleAdditiveTreesTestMissingTreeFeature() throws Exception {
+    final ModelException expectedException =
+        new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=-100.0 and right=75.0");
+    try {
+      createModelFromFiles("multipleadditivetreesmodel_no_feature.json",
+          "multipleadditivetreesmodel_features.json");
+      fail("multipleAdditiveTreesTestMissingTreeFeature failed to throw exception: "+expectedException);
+    } catch (ModelException actualException) {
+      assertEquals(expectedException.toString(), actualException.toString());
+    }
+  }
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestMinMaxNormalizer.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestMinMaxNormalizer.java
new file mode 100644
index 00000000000..055b3bccacd
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestMinMaxNormalizer.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.solr.ltr.norm; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.solr.core.SolrResourceLoader; +import org.junit.Test; + +public class TestMinMaxNormalizer { + + private final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(); + + private Normalizer implTestMinMax(Map params, + float expectedMin, float expectedMax) { + final Normalizer n = Normalizer.getInstance( + solrResourceLoader, + MinMaxNormalizer.class.getCanonicalName(), + params); + assertTrue(n instanceof MinMaxNormalizer); + final MinMaxNormalizer mmn = (MinMaxNormalizer)n; + assertEquals(mmn.getMin(), expectedMin, 0.0); + assertEquals(mmn.getMax(), expectedMax, 0.0); + return n; + } + + @Test + public void testInvalidMinMaxNoParams() { + implTestMinMax(new HashMap(), + Float.NEGATIVE_INFINITY, + Float.POSITIVE_INFINITY); + } + + @Test + public void testInvalidMinMaxMissingMax() { + final Map params = new HashMap(); + params.put("min", "0.0f"); + implTestMinMax(params, + 0.0f, + Float.POSITIVE_INFINITY); + } + + @Test + public void testInvalidMinMaxMissingMin() { + final Map params = new HashMap(); + params.put("max", "0.0f"); + implTestMinMax(params, + Float.NEGATIVE_INFINITY, + 0.0f); + } + + @Test + public void testMinMaxNormalizerMinLargerThanMax() { + final Map params = new HashMap(); + params.put("min", "10.0f"); + params.put("max", "0.0f"); + implTestMinMax(params, + 10.0f, + 0.0f); + } + + @Test + public void testMinMaxNormalizerMinEqualToMax() { + final Map params = new HashMap(); + params.put("min", "10.0f"); + params.put("max", "10.0f"); + final NormalizerException expectedException = + new NormalizerException("MinMax Normalizer delta must not be zero " + + "| min = 10.0,max = 10.0,delta = 0.0"); + try { + implTestMinMax(params, + 10.0f, + 10.0f); + fail("testMinMaxNormalizerMinEqualToMax failed to throw 
exception: "+expectedException); + } catch(NormalizerException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void testNormalizer() { + final Map params = new HashMap(); + params.put("min", "5.0f"); + params.put("max", "10.0f"); + final Normalizer n = + implTestMinMax(params, + 5.0f, + 10.0f); + + float value = 8; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = 100; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = 150; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = -1; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + value = 5; + assertEquals((value - 5f) / (10f - 5f), n.normalize(value), 0.0001); + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestStandardNormalizer.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestStandardNormalizer.java new file mode 100644 index 00000000000..10fa9720ccd --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/norm/TestStandardNormalizer.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.solr.ltr.norm; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.solr.core.SolrResourceLoader; +import org.junit.Test; + +public class TestStandardNormalizer { + + private final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(); + + private Normalizer implTestStandard(Map params, + float expectedAvg, float expectedStd) { + final Normalizer n = Normalizer.getInstance( + solrResourceLoader, + StandardNormalizer.class.getCanonicalName(), + params); + assertTrue(n instanceof StandardNormalizer); + final StandardNormalizer sn = (StandardNormalizer)n; + assertEquals(sn.getAvg(), expectedAvg, 0.0); + assertEquals(sn.getStd(), expectedStd, 0.0); + return n; + } + + @Test + public void testNormalizerNoParams() { + implTestStandard(new HashMap(), + 0.0f, + 1.0f); + } + + @Test + public void testInvalidSTD() { + final Map params = new HashMap(); + params.put("std", "0f"); + final NormalizerException expectedException = + new NormalizerException("Standard Normalizer standard deviation must be positive " + + "| avg = 0.0,std = 0.0"); + try { + implTestStandard(params, + 0.0f, + 0.0f); + fail("testInvalidSTD failed to throw exception: "+expectedException); + } catch(NormalizerException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void testInvalidSTD2() { + final Map params = new HashMap(); + params.put("std", "-1f"); + final NormalizerException expectedException = + new NormalizerException("Standard Normalizer standard deviation must be positive " + + "| avg = 0.0,std = -1.0"); + try { + implTestStandard(params, + 0.0f, + -1f); + fail("testInvalidSTD2 failed to throw exception: "+expectedException); + } catch(NormalizerException actualException) { + assertEquals(expectedException.toString(), 
actualException.toString()); + } + } + + @Test + public void testInvalidSTD3() { + final Map params = new HashMap(); + params.put("avg", "1f"); + params.put("std", "0f"); + final NormalizerException expectedException = + new NormalizerException("Standard Normalizer standard deviation must be positive " + + "| avg = 1.0,std = 0.0"); + try { + implTestStandard(params, + 1f, + 0f); + fail("testInvalidSTD3 failed to throw exception: "+expectedException); + } catch(NormalizerException actualException) { + assertEquals(expectedException.toString(), actualException.toString()); + } + } + + @Test + public void testNormalizer() { + Map params = new HashMap(); + params.put("avg", "0f"); + params.put("std", "1f"); + final Normalizer identity = + implTestStandard(params, + 0f, + 1f); + + float value = 8; + assertEquals(value, identity.normalize(value), 0.0001); + value = 150; + assertEquals(value, identity.normalize(value), 0.0001); + params = new HashMap(); + params.put("avg", "10f"); + params.put("std", "1.5f"); + final Normalizer norm = Normalizer.getInstance( + solrResourceLoader, + StandardNormalizer.class.getCanonicalName(), + params); + + for (final float v : new float[] {10f, 20f, 25f, 30f, 31f, 40f, 42f, 100f, + 10000000f}) { + assertEquals((v - 10f) / (1.5f), norm.normalize(v), 0.0001); + } + } +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestManagedFeatureStore.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestManagedFeatureStore.java new file mode 100644 index 00000000000..14373fbe618 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestManagedFeatureStore.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.store.rest; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.util.LuceneTestCase; + +public class TestManagedFeatureStore extends LuceneTestCase { + + public static Map createMap(String name, String className, Map params) { + final Map map = new HashMap(); + map.put(ManagedFeatureStore.NAME_KEY, name); + map.put(ManagedFeatureStore.CLASS_KEY, className); + if (params != null) { + map.put(ManagedFeatureStore.PARAMS_KEY, params); + } + return map; + } + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManager.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManager.java new file mode 100644 index 00000000000..855f053b57c --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManager.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.store.rest;
+
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.TestRerankBase;
+import org.apache.solr.ltr.feature.FieldValueFeature;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.apache.solr.rest.ManagedResource;
+import org.apache.solr.rest.ManagedResourceStorage;
+import org.apache.solr.rest.RestManager;
+import org.apache.solr.search.LTRQParserPlugin;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Tests that drive the feature-store and model-store REST endpoints
+ * ({@link ManagedFeatureStore#REST_END_POINT} and
+ * {@link ManagedModelStore#REST_END_POINT}) and the registration of the two
+ * managed resources with a {@link RestManager}.
+ */
+public class TestModelManager extends TestRerankBase {
+
+  /** One-time setup; delegates to the inherited {@code setuptest()}. */
+  @BeforeClass
+  public static void init() throws Exception {
+    setuptest();
+  }
+
+  /**
+   * Runs before every test and issues a delete for {@code /*} under both
+   * store endpoints (presumably so each test starts from empty stores).
+   */
+  @Before
+  public void restart() throws Exception {
+    restTestHarness.delete(ManagedFeatureStore.REST_END_POINT + "/*");
+    restTestHarness.delete(ManagedModelStore.REST_END_POINT + "/*");
+
+  }
+
+  /**
+   * Registers a {@link ManagedFeatureStore} and a {@link ManagedModelStore}
+   * in the registry of a fresh resource loader, initialises a
+   * {@link RestManager} backed by in-memory storage, and checks that the
+   * feature store can be looked up again under its resource id.
+   */
+  @Test
+  public void test() throws Exception {
+    final SolrResourceLoader loader = new SolrResourceLoader(
+        tmpSolrHome.toPath());
+
+    final RestManager.Registry registry = loader.getManagedResourceRegistry();
+    assertNotNull(
+        "Expected a non-null RestManager.Registry from the SolrResourceLoader!",
+        registry);
+
+    final String resourceId = "/schema/fstore1";
+    registry.registerManagedResource(resourceId, ManagedFeatureStore.class,
+        new LTRQParserPlugin());
+
+    final String resourceId2 = "/schema/mstore1";
+    registry.registerManagedResource(resourceId2, ManagedModelStore.class,
+        new LTRQParserPlugin());
+
+    final NamedList initArgs = new NamedList<>();
+
+    final RestManager restManager = new RestManager();
+    restManager.init(loader, initArgs,
+        new ManagedResourceStorage.InMemoryStorageIO());
+
+    final ManagedResource res = restManager.getManagedResource(resourceId);
+    assertTrue(res instanceof ManagedFeatureStore);
+    // JUnit order is (expected, actual); a swapped order garbles failure output
+    assertEquals(resourceId, res.getResourceId());
+
+  }
+
+  /**
+   * Uploads features and models through the REST endpoints (single objects,
+   * arrays, and deliberately invalid payloads) and verifies the reported
+   * status or error message, then checks what the endpoints list afterwards.
+   */
+  @Test
+  public void testRestManagerEndpoints() throws Exception {
+    // relies on these ManagedResources being activated in the
+    // schema-rest.xml used by this test
+    assertJQ("/schema/managed", "/responseHeader/status==0");
+
+    final String valueFeatureClassName = ValueFeature.class.getCanonicalName();
+
+    // Add features
+    String feature = "{\"name\": \"test1\", \"class\": \""+valueFeatureClassName+"\", \"params\": {\"value\": 1} }";
+    assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
+        "/responseHeader/status==0");
+
+    feature = "{\"name\": \"test2\", \"class\": \""+valueFeatureClassName+"\", \"params\": {\"value\": 1} }";
+    assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
+        "/responseHeader/status==0");
+
+    feature = "{\"name\": \"test3\", \"class\": \""+valueFeatureClassName+"\", \"params\": {\"value\": 1} }";
+    assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
+        "/responseHeader/status==0");
+
+    // this one names an explicit store instead of relying on the default
+    feature = "{\"name\": \"test33\", \"store\": \"TEST\", \"class\": \""+valueFeatureClassName+"\", \"params\": {\"value\": 1} }";
+    assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
+        "/responseHeader/status==0");
+
+    final String multipleFeatures = "[{\"name\": \"test4\", \"class\": \""+valueFeatureClassName+"\", \"params\": {\"value\": 1} }"
+        + ",{\"name\": \"test5\", \"class\": \""+valueFeatureClassName+"\", \"params\": {\"value\": 1} } ]";
+    assertJPut(ManagedFeatureStore.REST_END_POINT, multipleFeatures,
+        "/responseHeader/status==0");
+
+    final String fieldValueFeatureClassName = FieldValueFeature.class.getCanonicalName();
+
+    // Add bad feature (wrong params)
+    final String badfeature = "{\"name\": \"fvalue\", \"class\": \""+fieldValueFeatureClassName+"\", \"params\": {\"value\": 1} }";
+    // NOTE(review): "corrresponding" is spelled this way on purpose here; it
+    // presumably mirrors the server-side message - confirm before changing
+    assertJPut(ManagedFeatureStore.REST_END_POINT, badfeature,
+        "/error/msg/=='No setter corrresponding to \\'value\\' in "+fieldValueFeatureClassName+"'");
+
+    final String linearModelClassName = LinearModel.class.getCanonicalName();
+
+    // Add models
+    String model = "{ \"name\":\"testmodel1\", \"class\":\""+linearModelClassName+"\", \"features\":[] }";
+    // fails since it does not have features
+    assertJPut(ManagedModelStore.REST_END_POINT, model,
+        "/responseHeader/status==400");
+    // fails since it does not have weights
+    model = "{ \"name\":\"testmodel2\", \"class\":\""+linearModelClassName+"\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}] }";
+    assertJPut(ManagedModelStore.REST_END_POINT, model,
+        "/responseHeader/status==400");
+    // success
+    model = "{ \"name\":\"testmodel3\", \"class\":\""+linearModelClassName+"\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}],\"params\":{\"weights\":{\"test1\":1.5,\"test2\":2.0}}}";
+    assertJPut(ManagedModelStore.REST_END_POINT, model,
+        "/responseHeader/status==0");
+    // success
+    final String multipleModels = "[{ \"name\":\"testmodel4\", \"class\":\""+linearModelClassName+"\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}],\"params\":{\"weights\":{\"test1\":1.5,\"test2\":2.0}} }\n"
+        + ",{ \"name\":\"testmodel5\", \"class\":\""+linearModelClassName+"\", \"features\":[{\"name\":\"test1\"}, {\"name\":\"test2\"}],\"params\":{\"weights\":{\"test1\":1.5,\"test2\":2.0}} } ]";
+    assertJPut(ManagedModelStore.REST_END_POINT, multipleModels,
+        "/responseHeader/status==0");
+    final String qryResult = JQ(ManagedModelStore.REST_END_POINT);
+
+    // JUnit assertions instead of the 'assert' keyword: the keyword is a
+    // no-op unless the JVM runs with -ea, which would silently skip the check
+    assertTrue("testmodel3 missing from: " + qryResult,
+        qryResult.contains("\"name\":\"testmodel3\""));
+    assertTrue("testmodel4 missing from: " + qryResult,
+        qryResult.contains("\"name\":\"testmodel4\""));
+    assertTrue("testmodel5 missing from: " + qryResult,
+        qryResult.contains("\"name\":\"testmodel5\""));
+    /*
+     * assertJQ(LTRParams.MSTORE_END_POINT, "/models/[0]/name=='testmodel3'");
+     * assertJQ(LTRParams.MSTORE_END_POINT, "/models/[1]/name=='testmodel4'");
+     * assertJQ(LTRParams.MSTORE_END_POINT, "/models/[2]/name=='testmodel5'");
+     */
+    assertJQ(ManagedFeatureStore.REST_END_POINT,
+        "/featureStores==['TEST','_DEFAULT_']");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/_DEFAULT_",
+        "/features/[0]/name=='test1'");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/TEST",
+        "/features/[0]/name=='test33'");
+  }
+
+  /**
+   * Loads features and models from the named JSON files via the inherited
+   * helpers and checks one entry of each through the REST endpoints.
+   */
+  @Test
+  public void testEndpointsFromFile() throws Exception {
+    loadFeatures("features-linear.json");
+    loadModels("linear-model.json");
+
+    assertJQ(ManagedModelStore.REST_END_POINT,
+        "/models/[0]/name=='6029760550880411648'");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/_DEFAULT_",
+        "/features/[1]/name=='description'");
+  }
+
+}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java
new file mode 100644
index 00000000000..66c26fd7255
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/store/rest/TestModelManagerPersistence.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.store.rest;
+
+import java.util.ArrayList;
+import java.util.Map;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.solr.ltr.TestRerankBase;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.Before;
+import org.junit.Test;
+import org.noggit.ObjectBuilder;
+
+/**
+ * Checks that features and models uploaded through the REST endpoints are
+ * still reported after {@code restTestHarness.reload()}, and that deletions
+ * are likewise still in effect after a reload.
+ */
+public class TestModelManagerPersistence extends TestRerankBase {
+
+  /** Per-test setup; delegates to the inherited {@code setupPersistenttest()}. */
+  @Before
+  public void init() throws Exception {
+    setupPersistenttest();
+  }
+
+  /**
+   * Walks through upload, reload, storage-file inspection and deletion:
+   * features are spread over the stores "test", "test1" and "test2", two
+   * models are added, the content of the feature/model storage files is
+   * checked, and finally single and wildcard deletes are verified before and
+   * after a reload.
+   */
+  @Test
+  public void testFeaturePersistence() throws Exception {
+
+    loadFeature("feature", ValueFeature.class.getCanonicalName(), "test",
+        "{\"value\":2}");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test",
+        "/features/[0]/name=='feature'");
+    restTestHarness.reload();
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test",
+        "/features/[0]/name=='feature'");
+    loadFeature("feature1", ValueFeature.class.getCanonicalName(), "test1",
+        "{\"value\":2}");
+    loadFeature("feature2", ValueFeature.class.getCanonicalName(), "test",
+        "{\"value\":2}");
+    loadFeature("feature3", ValueFeature.class.getCanonicalName(), "test2",
+        "{\"value\":2}");
+    assertAllFeaturesListed();
+    restTestHarness.reload();
+    assertAllFeaturesListed();
+    loadModel("test-model", LinearModel.class.getCanonicalName(),
+        new String[] {"feature"}, "test", "{\"weights\":{\"feature\":1.0}}");
+    loadModel("test-model2", LinearModel.class.getCanonicalName(),
+        new String[] {"feature1"}, "test1", "{\"weights\":{\"feature1\":1.0}}");
+    final String fstorecontent = FileUtils
+        .readFileToString(fstorefile, "UTF-8");
+    final String mstorecontent = FileUtils
+        .readFileToString(mstorefile, "UTF-8");
+
+    //check feature/model stores on deletion
+    assertManagedStores(fstorecontent, "test", "test2", "test1");
+    assertManagedStores(mstorecontent, "test", "test1");
+
+    assertJDelete(ManagedFeatureStore.REST_END_POINT + "/test2",
+        "/responseHeader/status==0");
+    assertJDelete(ManagedModelStore.REST_END_POINT + "/test-model2",
+        "/responseHeader/status==0");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test2",
+        "/features/==[]");
+    assertJQ(ManagedModelStore.REST_END_POINT + "/test-model2",
+        "/models/[0]/name=='test-model'");
+    restTestHarness.reload();
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test2",
+        "/features/==[]");
+    assertJQ(ManagedModelStore.REST_END_POINT + "/test-model2",
+        "/models/[0]/name=='test-model'");
+
+    // models are deleted before the features
+    assertJDelete(ManagedModelStore.REST_END_POINT + "/*",
+        "/responseHeader/status==0");
+    assertJDelete(ManagedFeatureStore.REST_END_POINT + "/*",
+        "/responseHeader/status==0");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test1",
+        "/features/==[]");
+    restTestHarness.reload();
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test1",
+        "/features/==[]");
+
+  }
+
+  /** Asserts the four features are listed under the stores they were loaded into. */
+  private void assertAllFeaturesListed() throws Exception {
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test",
+        "/features/[0]/name=='feature'");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test",
+        "/features/[1]/name=='feature2'");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test1",
+        "/features/[0]/name=='feature1'");
+    assertJQ(ManagedFeatureStore.REST_END_POINT + "/test2",
+        "/features/[0]/name=='feature3'");
+  }
+
+  /**
+   * Parses {@code json}, takes its "managedList" array and asserts that the
+   * "store" value of every entry equals one of {@code allowedStores}.
+   */
+  private void assertManagedStores(String json, String... allowedStores) throws Exception {
+    // wildcard types: the parsed structure is only read, never written to
+    final Map<?,?> root = (Map<?,?>) ObjectBuilder.fromJSON(json);
+    final ArrayList<?> managedList = (ArrayList<?>) root.get("managedList");
+    for (final Object entry : managedList) {
+      final String store = (String) ((Map<?,?>) entry).get("store");
+      boolean allowed = false;
+      for (final String candidate : allowedStores) {
+        allowed |= candidate.equals(store);
+      }
+      assertTrue("unexpected store '" + store + "' in managedList", allowed);
+    }
+  }
+
+}