SOLR-14560: Interleaving for Learning To Rank (#1571)

SOLR-14560: Add interleaving support in Learning To Rank
This commit is contained in:
Alessandro Benedetti 2020-11-18 18:15:24 +00:00 committed by GitHub
parent ea4dd0580f
commit af0455ac83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 2329 additions and 214 deletions

View File

@ -167,6 +167,8 @@ New Features
* SOLR-14907: Add v2 API for configSet upload, including single-file insertion. (Houston Putman) * SOLR-14907: Add v2 API for configSet upload, including single-file insertion. (Houston Putman)
* SOLR-14560: Add interleaving support in Learning To Rank. (Alessandro Benedetti, Christine Poerschke)
Improvements Improvements
--------------------- ---------------------
* SOLR-14942: Reduce leader election time on node shutdown by removing election nodes before closing cores. * SOLR-14942: Reduce leader election time on node shutdown by removing election nodes before closing cores.

View File

@ -31,6 +31,7 @@ import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight; import org.apache.lucene.search.Weight;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.search.SolrIndexSearcher; import org.apache.solr.search.SolrIndexSearcher;
@ -42,12 +43,17 @@ import org.apache.solr.search.SolrIndexSearcher;
* */ * */
public class LTRRescorer extends Rescorer { public class LTRRescorer extends Rescorer {
LTRScoringQuery scoringQuery; final private LTRScoringQuery scoringQuery;
public LTRRescorer() {
this.scoringQuery = null;
}
public LTRRescorer(LTRScoringQuery scoringQuery) { public LTRRescorer(LTRScoringQuery scoringQuery) {
this.scoringQuery = scoringQuery; this.scoringQuery = scoringQuery;
} }
private void heapAdjust(ScoreDoc[] hits, int size, int root) { protected static void heapAdjust(ScoreDoc[] hits, int size, int root) {
final ScoreDoc doc = hits[root]; final ScoreDoc doc = hits[root];
final float score = doc.score; final float score = doc.score;
int i = root; int i = root;
@ -82,7 +88,7 @@ public class LTRRescorer extends Rescorer {
} }
} }
private void heapify(ScoreDoc[] hits, int size) { protected static void heapify(ScoreDoc[] hits, int size) {
for (int i = (size >> 1) - 1; i >= 0; i--) { for (int i = (size >> 1) - 1; i >= 0; i--) {
heapAdjust(hits, size, i); heapAdjust(hits, size, i);
} }
@ -104,23 +110,27 @@ public class LTRRescorer extends Rescorer {
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
return firstPassTopDocs; return firstPassTopDocs;
} }
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs; final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
Arrays.sort(hits, new Comparator<ScoreDoc>() {
@Override
public int compare(ScoreDoc a, ScoreDoc b) {
return a.doc - b.doc;
}
});
assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value)); topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);
return new TopDocs(firstPassTopDocs.totalHits, reranked);
}
private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
final ScoreDoc[] reranked = new ScoreDoc[topN]; final ScoreDoc[] reranked = new ScoreDoc[topN];
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); .createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);
scoreFeatures(searcher, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked); scoreFeatures(searcher,topN, modelWeight, firstPassResults, leaves, reranked);
// Must sort all documents that we reranked, and then select the top // Must sort all documents that we reranked, and then select the top
sortByScore(reranked);
return reranked;
}
protected static void sortByScore(ScoreDoc[] reranked) {
Arrays.sort(reranked, new Comparator<ScoreDoc>() { Arrays.sort(reranked, new Comparator<ScoreDoc>() {
@Override @Override
public int compare(ScoreDoc a, ScoreDoc b) { public int compare(ScoreDoc a, ScoreDoc b) {
@ -136,13 +146,24 @@ public class LTRRescorer extends Rescorer {
} }
} }
}); });
return new TopDocs(firstPassTopDocs.totalHits, reranked);
} }
public void scoreFeatures(IndexSearcher indexSearcher, TopDocs firstPassTopDocs, protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves, final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
ScoreDoc[] reranked) throws IOException { Arrays.sort(hits, new Comparator<ScoreDoc>() {
@Override
public int compare(ScoreDoc a, ScoreDoc b) {
return a.doc - b.doc;
}
});
assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
return hits;
}
public void scoreFeatures(IndexSearcher indexSearcher,
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
ScoreDoc[] reranked) throws IOException {
int readerUpto = -1; int readerUpto = -1;
int endDoc = 0; int endDoc = 0;
@ -150,7 +171,6 @@ public class LTRRescorer extends Rescorer {
LTRScoringQuery.ModelWeight.ModelScorer scorer = null; LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
int hitUpto = 0; int hitUpto = 0;
final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
while (hitUpto < hits.length) { while (hitUpto < hits.length) {
final ScoreDoc hit = hits[hitUpto]; final ScoreDoc hit = hits[hitUpto];
@ -166,64 +186,77 @@ public class LTRRescorer extends Rescorer {
docBase = readerContext.docBase; docBase = readerContext.docBase;
scorer = modelWeight.scorer(readerContext); scorer = modelWeight.scorer(readerContext);
} }
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to scoreSingleHit(indexSearcher, topN, modelWeight, docBase, hitUpto, hit, docID, scoringQuery, scorer, reranked);
// call score hitUpto++;
// even if no feature scorers match, since a model might use that info to }
// return a }
// non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer
// past the target
// doc since the model algorithm still needs to compute a potentially
// non-zero score from blank features.
assert (scorer != null);
final int targetDoc = docID - docBase;
scorer.docID();
scorer.iterator().advance(targetDoc);
scorer.getDocInfo().setOriginalDocScore(hit.score); protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery rerankingQuery, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
hit.score = scorer.score(); final FeatureLogger featureLogger = rerankingQuery.getFeatureLogger();
if (hitUpto < topN) { // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
reranked[hitUpto] = hit; // call score
// if the heap is not full, maybe I want to log the features for this // even if no feature scorers match, since a model might use that info to
// document // return a
// non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer
// past the target
// doc since the model algorithm still needs to compute a potentially
// non-zero score from blank features.
assert (scorer != null);
final int targetDoc = docID - docBase;
scorer.docID();
scorer.iterator().advance(targetDoc);
scorer.getDocInfo().setOriginalDocScore(hit.score);
hit.score = scorer.score();
if (hitUpto < topN) {
reranked[hitUpto] = hit;
// if the heap is not full, maybe I want to log the features for this
// document
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher,
modelWeight.getFeaturesInfo());
}
} else if (hitUpto == topN) {
// collected topN document, I create the heap
heapify(reranked, topN);
}
if (hitUpto >= topN) {
// once that heap is ready, if the score of this document is lower that
// the minimum
// i don't want to log the feature. Otherwise I replace it with the
// minimum and fix the
// heap.
if (hit.score > reranked[0].score) {
reranked[0] = hit;
heapAdjust(reranked, topN, 0);
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher, featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher,
modelWeight.getFeaturesInfo()); modelWeight.getFeaturesInfo());
} }
} else if (hitUpto == topN) {
// collected topN document, I create the heap
heapify(reranked, topN);
} }
if (hitUpto >= topN) {
// once that heap is ready, if the score of this document is lower that
// the minimum
// i don't want to log the feature. Otherwise I replace it with the
// minimum and fix the
// heap.
if (hit.score > reranked[0].score) {
reranked[0] = hit;
heapAdjust(reranked, topN, 0);
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher,
modelWeight.getFeaturesInfo());
}
}
}
hitUpto++;
} }
} }
@Override @Override
public Explanation explain(IndexSearcher searcher, public Explanation explain(IndexSearcher searcher,
Explanation firstPassExplanation, int docID) throws IOException { Explanation firstPassExplanation, int docID) throws IOException {
return getExplanation(searcher, docID, scoringQuery);
}
protected static Explanation getExplanation(IndexSearcher searcher, int docID, LTRScoringQuery rerankingQuery) throws IOException {
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext() final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
.leaves(); .leaves();
final int n = ReaderUtil.subIndex(docID, leafContexts); final int n = ReaderUtil.subIndex(docID, leafContexts);
final LeafReaderContext context = leafContexts.get(n); final LeafReaderContext context = leafContexts.get(n);
final int deBasedDoc = docID - context.docBase; final int deBasedDoc = docID - context.docBase;
final Weight modelWeight = searcher.createWeight(searcher.rewrite(scoringQuery), final Weight rankingWeight;
ScoreMode.COMPLETE, 1); if (rerankingQuery instanceof OriginalRankingLTRScoringQuery) {
return modelWeight.explain(context, deBasedDoc); rankingWeight = rerankingQuery.getOriginalQuery().createWeight(searcher, ScoreMode.COMPLETE, 1);
} else {
rankingWeight = searcher.createWeight(searcher.rewrite(rerankingQuery),
ScoreMode.COMPLETE, 1);
}
return rankingWeight.explain(context, deBasedDoc);
} }
public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight, public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight,

View File

@ -102,6 +102,10 @@ public class LTRScoringQuery extends Query implements Accountable {
return ltrScoringModel; return ltrScoringModel;
} }
public String getScoringModelName() {
return ltrScoringModel.getName();
}
public void setFeatureLogger(FeatureLogger fl) { public void setFeatureLogger(FeatureLogger fl) {
this.fl = fl; this.fl = fl;
} }

View File

@ -26,8 +26,8 @@ public class SolrQueryRequestContextUtils {
/** key of the feature logger in the request context **/ /** key of the feature logger in the request context **/
private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger"; private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger";
/** key of the scoring query in the request context **/ /** key of the scoring queries in the request context **/
private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query"; private static final String SCORING_QUERIES = LTR_PREFIX + "scoring_queries";
/** key of the isExtractingFeatures flag in the request context **/ /** key of the isExtractingFeatures flag in the request context **/
private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures"; private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures";
@ -47,12 +47,12 @@ public class SolrQueryRequestContextUtils {
/** scoring query accessors **/ /** scoring query accessors **/
public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) { public static void setScoringQueries(SolrQueryRequest req, LTRScoringQuery[] scoringQueries) {
req.getContext().put(SCORING_QUERY, scoringQuery); req.getContext().put(SCORING_QUERIES, scoringQueries);
} }
public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) { public static LTRScoringQuery[] getScoringQueries(SolrQueryRequest req) {
return (LTRScoringQuery) req.getContext().get(SCORING_QUERY); return (LTRScoringQuery[]) req.getContext().get(SCORING_QUERIES);
} }
/** isExtractingFeatures flag accessors **/ /** isExtractingFeatures flag accessors **/

View File

@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving;
import org.apache.lucene.search.ScoreDoc;
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
/**
 * Contract for algorithms that merge the rankings of two models into one result list.
 * <p>
 * Two ranking models, modelA and modelB, each return a ranked list of documents for a given query:
 * La = (a1, a2, ...) and Lb = (b1, b2, ...).
 * An interleaving algorithm combines La and Lb into a single ranked list I = (i1, i2, ...),
 * following the rules of the concrete implementation.
 * Every element Ij is labelled TeamA when it was selected from La and TeamB when it was selected from Lb.
 */
public interface Interleaving {

  /** Name identifying the Team Draft interleaving algorithm. */
  String TEAM_DRAFT = "TeamDraft";

  /**
   * Interleaves the rankings produced by two models.
   *
   * @param rerankedA documents in the order produced by model A
   * @param rerankedB documents in the order produced by model B
   * @return the interleaved results together with the picks attributed to each model
   */
  InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB);

  /**
   * Resolves an algorithm name to an implementation.
   * Every non-null name, recognised or not, currently yields a {@link TeamDraftInterleaving}.
   *
   * @param algorithm the algorithm name, e.g. {@link #TEAM_DRAFT}; must not be null
   * @return a new interleaving implementation
   * @throws NullPointerException if {@code algorithm} is null (switching on a null String)
   */
  static Interleaving getImplementation(String algorithm) {
    switch (algorithm) {
      case TEAM_DRAFT:
        return new TeamDraftInterleaving();
      default:
        // Unrecognised names resolve to Team Draft as well.
        return new TeamDraftInterleaving();
    }
  }
}

View File

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving;
import java.util.ArrayList;
import java.util.Set;
import org.apache.lucene.search.ScoreDoc;
/**
 * Holder for the outcome of an interleaving run: the interleaved result list and,
 * for each model, the set of document ids that were picked from that model's ranking.
 * Both values are stored and returned as given; no defensive copies are made.
 */
public class InterleavingResult {

  /** The interleaved documents, in final ranking order. */
  private final ScoreDoc[] interleavedResults;

  /** One set of picked document ids per model, in model order. */
  private final ArrayList<Set<Integer>> interleavingPicks;

  /**
   * @param interleavedResults the interleaved documents
   * @param interleavingPicks the document ids picked per model
   */
  public InterleavingResult(ScoreDoc[] interleavedResults, ArrayList<Set<Integer>> interleavingPicks) {
    this.interleavingPicks = interleavingPicks;
    this.interleavedResults = interleavedResults;
  }

  /** @return the interleaved documents passed at construction time */
  public ScoreDoc[] getInterleavedResults() {
    return this.interleavedResults;
  }

  /** @return the per-model picks passed at construction time */
  public ArrayList<Set<Integer>> getInterleavingPicks() {
    return this.interleavingPicks;
  }
}

View File

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving;
import java.io.IOException;
import java.util.Arrays;
import org.apache.lucene.search.Query;
import org.apache.solr.ltr.search.LTRQuery;
import org.apache.solr.search.RankQuery;
/**
 * A learning to rank query with interleaving: it encapsulates the scoring queries of the models
 * to interleave and delegates the rescoring of the documents to an {@link LTRInterleavingRescorer}
 * built from them.
 **/
public class LTRInterleavingQuery extends LTRQuery {

  /** Scoring queries of the interleaved models; the array is shared with the rescorer. */
  private final LTRInterleavingScoringQuery[] rerankingQueries;

  /** Algorithm used to merge the per-model rankings. */
  private final Interleaving interleavingAlgorithm;

  /**
   * @param interleavingAlgorithm algorithm used to merge the per-model rankings
   * @param rerankingQueries scoring queries of the models to interleave
   * @param rerankDocs number of documents to rerank, forwarded to the superclass
   */
  public LTRInterleavingQuery(Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries, int rerankDocs) {
    // The first superclass argument is null: rescoring is driven by the interleaving rescorer instead.
    super(null, rerankDocs, new LTRInterleavingRescorer(interleavingAlgorithm, rerankingQueries));
    this.interleavingAlgorithm = interleavingAlgorithm;
    this.rerankingQueries = rerankingQueries;
  }

  @Override
  public int hashCode() {
    final int classPart = 31 * classHash();
    // NOTE(review): hashCode() on an array is identity based, so this depends on the array
    // instance rather than on its elements - confirm this is intended.
    final int statePart = mainQuery.hashCode() + rerankingQueries.hashCode() + reRankDocs;
    return classPart + statePart;
  }

  @Override
  public boolean equals(Object o) {
    if (!sameClassAs(o)) {
      return false;
    }
    return equalsTo(getClass().cast(o));
  }

  /** Compares main query, reranking queries array (by reference) and rerank size. */
  private boolean equalsTo(LTRInterleavingQuery other) {
    if (!mainQuery.equals(other.mainQuery)) {
      return false;
    }
    // equals() on an array compares references, consistently with hashCode() above.
    if (!rerankingQueries.equals(other.rerankingQueries)) {
      return false;
    }
    return reRankDocs == other.reRankDocs;
  }

  /**
   * Wraps the main query and hands it to every reranking query as its original query.
   *
   * @param _mainQuery the main query to wrap
   * @return this query
   */
  @Override
  public RankQuery wrap(Query _mainQuery) {
    super.wrap(_mainQuery);
    for (int i = 0; i < rerankingQueries.length; i++) {
      rerankingQueries[i].setOriginalQuery(_mainQuery);
    }
    return this;
  }

  @Override
  public String toString(String field) {
    final String mainQueryText = mainQuery.toString();
    final String rerankingText = Arrays.toString(rerankingQueries);
    return "{!ltr mainQuery='" + mainQueryText + "' rerankingQueries='"
        + rerankingText + "' reRankDocs=" + reRankDocs + "}";
  }

  /**
   * Builds a new interleaving query over the same algorithm and reranking queries,
   * wrapped around the rewritten main query.
   */
  @Override
  protected Query rewrite(Query rewrittenMainQuery) throws IOException {
    final LTRInterleavingQuery rewritten =
        new LTRInterleavingQuery(interleavingAlgorithm, rerankingQueries, reRankDocs);
    return rewritten.wrap(rewrittenMainQuery);
  }
}

View File

@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
/**
 * Implements the rescoring logic for interleaving. The top documents returned by solr are scored
 * by each of the configured {@link LTRScoringQuery} models; a model that is an
 * {@link OriginalRankingLTRScoringQuery} is not scored and keeps the first-pass order instead.
 * The resulting per-model rankings are then merged by the configured interleaving algorithm.
 * */
public class LTRInterleavingRescorer extends LTRRescorer {

  /** Scoring queries of the interleaved models; positions 0 and 1 are the ones interleaved. */
  private final LTRInterleavingScoringQuery[] rerankingQueries;

  /** Algorithm that merges the two per-model rankings. */
  private final Interleaving interleavingAlgorithm;

  /** Position of the original-ranking query in {@link #rerankingQueries}, or null when absent. */
  private Integer originalRankingIndex = null;

  /**
   * @param interleavingAlgorithm algorithm that merges the per-model rankings
   * @param rerankingQueries scoring queries of the models to interleave
   */
  public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries) {
    this.interleavingAlgorithm = interleavingAlgorithm;
    this.rerankingQueries = rerankingQueries;
    int position = 0;
    for (LTRInterleavingScoringQuery candidate : this.rerankingQueries) {
      if (candidate instanceof OriginalRankingLTRScoringQuery) {
        // When several are present the last one wins.
        this.originalRankingIndex = position;
      }
      position++;
    }
  }

  /**
   * rescores the documents and interleaves the per-model rankings:
   *
   * @param searcher
   *          current IndexSearcher
   * @param firstPassTopDocs
   *          documents to rerank;
   * @param topN
   *          documents to return;
   */
  @Override
  public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
                         int topN) throws IOException {
    final boolean nothingToRescore = (topN == 0) || (firstPassTopDocs.scoreDocs.length == 0);
    if (nothingToRescore) {
      return firstPassTopDocs;
    }

    // Snapshot the first-pass order before getFirstPassDocsRanked reorders scoreDocs in place.
    final boolean keepOriginalRanking = originalRankingIndex != null;
    ScoreDoc[] originalOrder = null;
    if (keepOriginalRanking) {
      final int hitCount = firstPassTopDocs.scoreDocs.length;
      originalOrder = new ScoreDoc[hitCount];
      System.arraycopy(firstPassTopDocs.scoreDocs, 0, originalOrder, 0, hitCount);
    }

    topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));

    final ScoreDoc[] firstPassDocs = getFirstPassDocsRanked(firstPassTopDocs);
    final ScoreDoc[][] rankingPerModel = rerank(searcher, topN, firstPassDocs);
    if (keepOriginalRanking) {
      rankingPerModel[originalRankingIndex] = originalOrder;
    }

    final InterleavingResult interleaved =
        interleavingAlgorithm.interleave(rankingPerModel[0], rankingPerModel[1]);
    final ScoreDoc[] interleavedDocs = interleaved.getInterleavedResults();

    // Remember which documents each model contributed; explain() relies on it.
    final ArrayList<Set<Integer>> picksPerModel = interleaved.getInterleavingPicks();
    rerankingQueries[0].setPickedInterleavingDocIds(picksPerModel.get(0));
    rerankingQueries[1].setPickedInterleavingDocIds(picksPerModel.get(1));

    return new TopDocs(firstPassTopDocs.totalHits, interleavedDocs);
  }

  /**
   * Scores the first-pass documents with every model that is not the original ranking
   * and sorts each of those rankings by score.
   */
  private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
    final int modelCount = rerankingQueries.length;
    final ScoreDoc[][] rankingPerModel = new ScoreDoc[modelCount][topN];
    final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();

    // The slot of the original ranking stays null so that scoreFeatures skips it.
    final LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[modelCount];
    for (int model = 0; model < modelCount; model++) {
      if (!isOriginalRanking(model)) {
        modelWeights[model] = (LTRScoringQuery.ModelWeight) searcher
            .createWeight(searcher.rewrite(rerankingQueries[model]), ScoreMode.COMPLETE, 1);
      }
    }

    scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, rankingPerModel);

    for (int model = 0; model < modelCount; model++) {
      if (!isOriginalRanking(model)) {
        sortByScore(rankingPerModel[model]);
      }
    }
    return rankingPerModel;
  }

  /** Tells whether the model at the given position is the original (first-pass) ranking. */
  private boolean isOriginalRanking(int modelIndex) {
    return originalRankingIndex != null && originalRankingIndex == modelIndex;
  }

  /**
   * Walks the hits segment by segment and scores each hit with every model that has a weight.
   * Each model receives its own copy of the hit.
   *
   * @param indexSearcher current IndexSearcher
   * @param topN number of documents to keep per model
   * @param modelWeights one weight per model; null entries are skipped
   * @param hits the first-pass documents; NOTE(review): the segment walk only moves forward,
   *             so hits are presumably expected in increasing doc id order - confirm against callers
   * @param leaves the leaves of the index reader
   * @param rerankedPerModel output, one ranking per model
   */
  public void scoreFeatures(IndexSearcher indexSearcher,
                            int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
                            ScoreDoc[][] rerankedPerModel) throws IOException {
    int readerUpto = -1;
    int endDoc = 0;
    int docBase = 0;
    final LTRScoringQuery.ModelWeight.ModelScorer[] scorers =
        new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length];

    for (int hitUpto = 0; hitUpto < hits.length; hitUpto++) {
      final ScoreDoc hit = hits[hitUpto];
      final int docID = hit.doc;

      LeafReaderContext enteredSegment = null;
      while (docID >= endDoc) {
        readerUpto++;
        enteredSegment = leaves.get(readerUpto);
        endDoc = enteredSegment.docBase + enteredSegment.reader().maxDoc();
      }

      // We advanced to another segment: refresh the doc base and the per-model scorers
      if (enteredSegment != null) {
        docBase = enteredSegment.docBase;
        for (int model = 0; model < modelWeights.length; model++) {
          if (modelWeights[model] != null) {
            scorers[model] = modelWeights[model].scorer(enteredSegment);
          }
        }
      }

      for (int model = 0; model < rerankingQueries.length; model++) {
        if (modelWeights[model] != null) {
          final ScoreDoc hitCopy = new ScoreDoc(hit.doc, hit.score, hit.shardIndex);
          scoreSingleHit(indexSearcher, topN, modelWeights[model], docBase, hitUpto, hitCopy, docID,
              rerankingQueries[model], scorers[model], rerankedPerModel[model]);
        }
      }
    }
  }

  /**
   * Explains a document with the model it was picked from: the second model when its picks
   * contain the document, the first model otherwise.
   */
  @Override
  public Explanation explain(IndexSearcher searcher,
                             Explanation firstPassExplanation, int docID) throws IOException {
    final LTRInterleavingScoringQuery firstModel = rerankingQueries[0];
    final LTRInterleavingScoringQuery secondModel = rerankingQueries[1];
    final boolean pickedFromSecondModel = secondModel.getPickedInterleavingDocIds().contains(docID);
    final LTRScoringQuery pickedRerankModel = pickedFromSecondModel ? secondModel : firstModel;
    return getExplanation(searcher, docID, pickedRerankModel);
  }
}

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving;
import java.util.Map;
import java.util.Set;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.model.LTRScoringModel;
/**
 * An {@link LTRScoringQuery} taking part in interleaving. On top of the scoring query it
 * records the ids of the documents that were picked for its model in the interleaved list.
 */
public class LTRInterleavingScoringQuery extends LTRScoringQuery {

  /** Ids of the documents this model was picked for; unset until the setter is called. */
  private Set<Integer> pickedInterleavingDocIds;

  /**
   * @param ltrScoringModel the model used for scoring
   */
  public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel) {
    super(ltrScoringModel);
  }

  /**
   * @param ltrScoringModel the model used for scoring
   * @param extractAllFeatures forwarded to the superclass constructor
   */
  public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
    super(ltrScoringModel, extractAllFeatures);
  }

  /**
   * @param ltrScoringModel the model used for scoring
   * @param externalFeatureInfo forwarded to the superclass constructor
   * @param extractAllFeatures forwarded to the superclass constructor
   * @param ltrThreadMgr forwarded to the superclass constructor
   */
  public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel,
                                     Map<String, String[]> externalFeatureInfo,
                                     boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
    super(ltrScoringModel, externalFeatureInfo, extractAllFeatures, ltrThreadMgr);
  }

  /** @return the ids of the documents this model was picked for, or null when never set */
  public Set<Integer> getPickedInterleavingDocIds() {
    return this.pickedInterleavingDocIds;
  }

  /** @param pickedInterleavingDocIds the ids of the documents this model was picked for */
  public void setPickedInterleavingDocIds(Set<Integer> pickedInterleavingDocIds) {
    this.pickedInterleavingDocIds = pickedInterleavingDocIds;
  }
}

View File

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving;
/**
 * Interleaving scoring query that carries only a model name and no scoring model:
 * it is constructed with a null {@code LTRScoringModel} and reports the name it was given.
 */
public final class OriginalRankingLTRScoringQuery extends LTRInterleavingScoringQuery {

  /** Name reported in place of a scoring model name. */
  private final String originalRankingModelName;

  /**
   * @param originalRankingModelName the name to report from {@link #getScoringModelName()}
   */
  public OriginalRankingLTRScoringQuery(String originalRankingModelName) {
    super(null /* LTRScoringModel */);
    this.originalRankingModelName = originalRankingModelName;
  }

  /** @return the name given at construction time */
  @Override
  public String getScoringModelName() {
    return originalRankingModelName;
  }
}

View File

@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving.algorithms;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.search.ScoreDoc;
import org.apache.solr.ltr.interleaving.Interleaving;
import org.apache.solr.ltr.interleaving.InterleavingResult;
/**
* Interleaving was introduced the first time by Joachims in [1, 2].
* Team Draft Interleaving is among the most successful and used interleaving approaches[3].
* Team Draft Interleaving implements a method similar to the way in which captains select their players in team-matches.
* Team Draft Interleaving produces a fair distribution of ranking models elements in the final interleaved list.
* "Team draft interleaving" has also proved to overcome an issue of the "Balanced interleaving" approach, in determining the winning model[4].
* <p>
* [1] T. Joachims. Optimizing search engines using clickthrough data. KDD (2002)
* [2] T.Joachims.Evaluatingretrievalperformanceusingclickthroughdata.InJ.Franke, G. Nakhaeizadeh, and I. Renz, editors,
* Text Mining, pages 7996. Physica/Springer (2003)
* [3] F. Radlinski, M. Kurup, and T. Joachims. How does clickthrough data reflect re-
* trieval quality? In CIKM, pages 4352. ACM Press (2008)
* [4] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue.
* Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):141, Feb. (2012)
*/
public class TeamDraftInterleaving implements Interleaving {
  /**
   * Source of the coin flips used to decide which team picks next when both teams
   * have picked the same number of documents. Replaceable through
   * {@link #setRANDOM(Random)}.
   */
  public static Random RANDOM;

  static {
    // We try to make things reproducible in the context of our tests by initializing the random instance
    // based on the current seed
    String seed = System.getProperty("tests.seed");
    if (seed == null) {
      RANDOM = new Random();
    } else {
      RANDOM = new Random(seed.hashCode());
    }
  }

  /**
   * Team Draft Interleaving considers two ranking models: modelA and modelB.
   * For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...).
   * The algorithm creates a unique ranked list I = (i1, i2, ...).
   * This list is created by interleaving elements from the two lists La and Lb as described by Chapelle et al.[1].
   * Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb.
   * <p>
   * [1] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue.
   * Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1-41, Feb. (2012)
   * <p>
   * Assumptions:
   * - rerankedA and rerankedB have the same length.
   * They contain the same search results, ranked differently by two ranking models
   * - each reranked list can not contain the same search result more than once.
   * - results are all from the same shard
   * <p>
   * If the assumptions do not hold and the team whose turn it is has no unpicked
   * document left, the interleaving stops at that point instead of reading past the
   * end of that team's list. The returned array then holds only the documents picked
   * so far and never contains {@code null} entries.
   *
   * @param rerankedA a ranked list of search results produced by a ranking model A
   * @param rerankedB a ranked list of search results produced by a ranking model B
   * @return the interleaved ranking list, together with the doc ids picked by each team
   *         (team A first, team B second)
   */
  public InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB) {
    final int topN = rerankedA.length;
    final ArrayList<ScoreDoc> interleavedResults = new ArrayList<>(topN);
    final HashSet<Integer> alreadyAdded = new HashSet<>();
    final Set<Integer> teamA = new HashSet<>();
    final Set<Integer> teamB = new HashSet<>();
    int indexA = 0, indexB = 0;
    while (interleavedResults.size() < topN && indexA < rerankedA.length && indexB < rerankedB.length) {
      // The team with fewer picks goes next; on a tie a coin flip decides.
      // The coin is only flipped on a tie, so the sequence of random draws is stable for a given seed.
      if (teamA.size() < teamB.size() || (teamA.size() == teamB.size() && !RANDOM.nextBoolean())) {
        indexA = updateIndex(alreadyAdded, indexA, rerankedA);
        if (indexA >= rerankedA.length) {
          break; // team A has nothing left to contribute
        }
        final ScoreDoc picked = rerankedA[indexA];
        interleavedResults.add(picked);
        alreadyAdded.add(picked.doc);
        teamA.add(picked.doc);
        indexA++;
      } else {
        indexB = updateIndex(alreadyAdded, indexB, rerankedB);
        if (indexB >= rerankedB.length) {
          break; // team B has nothing left to contribute
        }
        final ScoreDoc picked = rerankedB[indexB];
        interleavedResults.add(picked);
        alreadyAdded.add(picked.doc);
        teamB.add(picked.doc);
        indexB++;
      }
    }
    final ArrayList<Set<Integer>> interleavingPicks = new ArrayList<>(2);
    interleavingPicks.add(teamA);
    interleavingPicks.add(teamB);
    // Sized from the list so that an early stop cannot leave trailing null slots.
    final ScoreDoc[] interleavedResultArray = interleavedResults.toArray(new ScoreDoc[0]);
    return new InterleavingResult(interleavedResultArray, interleavingPicks);
  }

  /**
   * Advances {@code index} past every document that has already been added to the
   * interleaved list.
   *
   * @param alreadyAdded doc ids already present in the interleaved list
   * @param index        position in {@code reranked} to start checking from
   * @param reranked     the ranked list being scanned
   * @return the position of the first document not yet added, or {@code reranked.length}
   *         if every remaining document has already been added
   */
  private int updateIndex(HashSet<Integer> alreadyAdded, int index, ScoreDoc[] reranked) {
    while (index < reranked.length && alreadyAdded.contains(reranked[index].doc)) {
      index++;
    }
    return index;
  }

  /**
   * Replaces the random generator used to break ties between the two teams.
   *
   * @param RANDOM the generator to use from now on
   */
  public static void setRANDOM(Random RANDOM) {
    TeamDraftInterleaving.RANDOM = RANDOM;
  }
}

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Contains Various Interleaving Algorithms
*/
package org.apache.solr.ltr.interleaving.algorithms;

View File

@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Contains Various Interleaving auxiliary classes
*/
package org.apache.solr.ltr.interleaving;

View File

@ -36,6 +36,8 @@ import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule; import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils; import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.feature.Feature; import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.norm.Normalizer; import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.search.LTRQParserPlugin; import org.apache.solr.ltr.search.LTRQParserPlugin;
@ -126,14 +128,15 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
SolrQueryRequestContextUtils.setIsExtractingFeatures(req); SolrQueryRequestContextUtils.setIsExtractingFeatures(req);
// Communicate which feature store we are requesting features for // Communicate which feature store we are requesting features for
SolrQueryRequestContextUtils.setFvStoreName(req, localparams.get(FV_STORE, defaultStore)); final String fvStoreName = localparams.get(FV_STORE);
SolrQueryRequestContextUtils.setFvStoreName(req, (fvStoreName == null ? defaultStore : fvStoreName));
// Create and supply the feature logger to be used // Create and supply the feature logger to be used
SolrQueryRequestContextUtils.setFeatureLogger(req, SolrQueryRequestContextUtils.setFeatureLogger(req,
createFeatureLogger( createFeatureLogger(
localparams.get(FV_FORMAT))); localparams.get(FV_FORMAT)));
return new FeatureTransformer(name, localparams, req); return new FeatureTransformer(name, localparams, req, (fvStoreName != null) /* hasExplicitFeatureStore */);
} }
/** /**
@ -163,11 +166,18 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
final private String name; final private String name;
final private SolrParams localparams; final private SolrParams localparams;
final private SolrQueryRequest req; final private SolrQueryRequest req;
final private boolean hasExplicitFeatureStore;
private List<LeafReaderContext> leafContexts; private List<LeafReaderContext> leafContexts;
private SolrIndexSearcher searcher; private SolrIndexSearcher searcher;
private LTRScoringQuery scoringQuery; /**
private LTRScoringQuery.ModelWeight modelWeight; * rerankingQueries, modelWeights have:
* length=1 - [Classic LTR] When reranking with a single model
* length=2 - [Interleaving] When reranking with interleaving (two ranking models are involved)
*/
private LTRScoringQuery[] rerankingQueriesFromContext;
private LTRScoringQuery[] rerankingQueries;
private LTRScoringQuery.ModelWeight[] modelWeights;
private FeatureLogger featureLogger; private FeatureLogger featureLogger;
private boolean docsWereNotReranked; private boolean docsWereNotReranked;
@ -177,10 +187,11 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
* feature vectors * feature vectors
*/ */
public FeatureTransformer(String name, SolrParams localparams, public FeatureTransformer(String name, SolrParams localparams,
SolrQueryRequest req) { SolrQueryRequest req, boolean hasExplicitFeatureStore) {
this.name = name; this.name = name;
this.localparams = localparams; this.localparams = localparams;
this.req = req; this.req = req;
this.hasExplicitFeatureStore = hasExplicitFeatureStore;
} }
@Override @Override
@ -209,51 +220,102 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor()); threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
} }
// Setup LTRScoringQuery rerankingQueriesFromContext = SolrQueryRequestContextUtils.getScoringQueries(req);
scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req); docsWereNotReranked = (rerankingQueriesFromContext == null || rerankingQueriesFromContext.length == 0);
docsWereNotReranked = (scoringQuery == null); String transformerFeatureStore = SolrQueryRequestContextUtils.getFvStoreName(req);
String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req); Map<String, String[]> transformerExternalFeatureInfo = LTRQParserPlugin.extractEFIParams(localparams);
if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) {
// if store is set in the transformer we should overwrite the logger
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore()); final LoggingModel loggingModel = createLoggingModel(transformerFeatureStore);
setupRerankingQueriesForLogging(transformerFeatureStore, transformerExternalFeatureInfo, loggingModel);
setupRerankingWeightsForLogging(context);
}
final FeatureStore store = fr.getFeatureStore(featureStoreName); /**
featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name * The loggingModel is an empty model that is just used to extract the features
* and log them
* @param transformerFeatureStore the explicit transformer feature store
*/
private LoggingModel createLoggingModel(String transformerFeatureStore) {
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
try { final FeatureStore store = fr.getFeatureStore(transformerFeatureStore);
final LoggingModel lm = new LoggingModel(loggingModelName, transformerFeatureStore = store.getName(); // if transformerFeatureStore was null before this gets actual name
featureStoreName, store.getFeatures());
scoringQuery = new LTRScoringQuery(lm, return new LoggingModel(loggingModelName,
LTRQParserPlugin.extractEFIParams(localparams), transformerFeatureStore, store.getFeatures());
true, }
threadManager); // request feature weights to be created for all features
}catch (final Exception e) { /**
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, * When preparing the reranking queries for logging features various scenarios apply:
"retrieving the feature store "+featureStoreName, e); *
* No Reranking
* There is the need of a logger model from the default feature store or the explicit feature store passed
* to extract the feature vector
*
* Re Ranking
* 1) If no explicit feature store is passed, the models for each reranking query can be safely re-used
* the feature vector can be fetched from the feature vector cache.
* 2) If an explicit feature store is passed, and no reranking query uses a model with that feature store,
* There is the need of a logger model to extract the feature vector
* 3) If an explicit feature store is passed, and there is a reranking query that uses a model with that feature store,
* the model can be re-used and there is no need for a logging model
*
* @param transformerFeatureStore explicit feature store for the transformer
* @param transformerExternalFeatureInfo explicit efi for the transformer
*/
private void setupRerankingQueriesForLogging(String transformerFeatureStore, Map<String, String[]> transformerExternalFeatureInfo, LoggingModel loggingModel) {
if (docsWereNotReranked) { //no reranking query
LTRScoringQuery loggingQuery = new LTRScoringQuery(loggingModel,
transformerExternalFeatureInfo,
true /* extractAllFeatures */,
threadManager);
rerankingQueries = new LTRScoringQuery[]{loggingQuery};
} else {
rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length];
System.arraycopy(rerankingQueriesFromContext, 0, rerankingQueries, 0, rerankingQueriesFromContext.length);
if (transformerFeatureStore != null) {// explicit feature store for the transformer
LTRScoringModel matchingRerankingModel = loggingModel;
for (LTRScoringQuery rerankingQuery : rerankingQueries) {
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) &&
transformerFeatureStore.equals(rerankingQuery.getScoringModel().getFeatureStoreName())) {
matchingRerankingModel = rerankingQuery.getScoringModel();
}
}
for (int i = 0; i < rerankingQueries.length; i++) {
rerankingQueries[i] = new LTRScoringQuery(
matchingRerankingModel,
(!transformerExternalFeatureInfo.isEmpty() ? transformerExternalFeatureInfo : rerankingQueries[i].getExternalFeatureInfo()),
true /* extractAllFeatures */,
threadManager);
}
} }
} }
}
if (scoringQuery.getOriginalQuery() == null) { private void setupRerankingWeightsForLogging(ResultContext context) {
scoringQuery.setOriginalQuery(context.getQuery()); modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
} for (int i = 0; i < rerankingQueries.length; i++) {
if (scoringQuery.getFeatureLogger() == null){ if (rerankingQueries[i].getOriginalQuery() == null) {
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); rerankingQueries[i].setOriginalQuery(context.getQuery());
} }
scoringQuery.setRequest(req); rerankingQueries[i].setRequest(req);
if (!(rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
featureLogger = scoringQuery.getFeatureLogger(); if (rerankingQueries[i].getFeatureLogger() == null) {
rerankingQueries[i].setFeatureLogger(SolrQueryRequestContextUtils.getFeatureLogger(req));
try { }
modelWeight = scoringQuery.createWeight(searcher, ScoreMode.COMPLETE, 1f); featureLogger = rerankingQueries[i].getFeatureLogger();
} catch (final IOException e) { try {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e); modelWeights[i] = rerankingQueries[i].createWeight(searcher, ScoreMode.COMPLETE, 1f);
} } catch (final IOException e) {
if (modelWeight == null) { throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, }
"error logging the features, model weight is null"); if (modelWeights[i] == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"error logging the features, model weight is null");
}
}
} }
} }
@ -271,17 +333,26 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
private void implTransform(SolrDocument doc, int docid, Float score) private void implTransform(SolrDocument doc, int docid, Float score)
throws IOException { throws IOException {
Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher); LTRScoringQuery rerankingQuery = rerankingQueries[0];
if (fv == null) { // FV for this document was not in the cache LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0];
fv = featureLogger.makeFeatureVector( for (int i = 1; i < rerankingQueries.length; i++) {
LTRRescorer.extractFeaturesInfo( if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) {
modelWeight, rerankingQuery = rerankingQueries[i];
docid, rerankingModelWeight = modelWeights[i];
(docsWereNotReranked ? score : null), }
leafContexts)); }
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher);
if (featureVector == null) { // FV for this document was not in the cache
featureVector = featureLogger.makeFeatureVector(
LTRRescorer.extractFeaturesInfo(
rerankingModelWeight,
docid,
(docsWereNotReranked ? score : null),
leafContexts));
}
doc.addField(name, featureVector);
} }
doc.addField(name, fv);
} }
} }

View File

@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.response.transform;
import java.io.IOException;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.ResultContext;
import org.apache.solr.response.transform.DocTransformer;
import org.apache.solr.response.transform.TransformerFactory;
import org.apache.solr.util.SolrPluginUtils;
/**
 * Factory for a document transformer that adds to each returned document a field
 * holding the name of the model the document was picked from.
 */
public class LTRInterleavingTransformerFactory extends TransformerFactory {

  @Override
  @SuppressWarnings({"unchecked"})
  public void init(@SuppressWarnings("rawtypes") NamedList args) {
    super.init(args);
    SolrPluginUtils.invokeSetters(this, args);
  }

  @Override
  public DocTransformer create(String name, SolrParams localparams,
      SolrQueryRequest req) {
    return new InterleavingTransformer(name, req);
  }

  class InterleavingTransformer extends DocTransformer {

    final private String name;
    final private SolrQueryRequest req;

    /**
     * Reranking queries found in the request context. Kept as the base type because
     * the context array is not guaranteed to be an {@code LTRInterleavingScoringQuery[]},
     * so an array downcast could fail at runtime. May be {@code null} when the request
     * context holds no scoring queries.
     */
    private LTRScoringQuery[] rerankingQueries;

    /**
     * @param name
     *          Name of the field to be added in a document representing the
     *          model picked by the interleaving process
     * @param req
     *          the request whose context holds the reranking queries
     */
    public InterleavingTransformer(String name,
        SolrQueryRequest req) {
      this.name = name;
      this.req = req;
    }

    @Override
    public String getName() {
      return name;
    }

    @Override
    public void setContext(ResultContext context) {
      super.setContext(context);
      if (context == null) {
        return;
      }
      if (context.getRequest() == null) {
        return;
      }
      rerankingQueries = SolrQueryRequestContextUtils.getScoringQueries(req);
      if (rerankingQueries == null) {
        // no scoring queries in the request context: there is no model to report
        return;
      }
      for (final LTRScoringQuery scoringQuery : rerankingQueries) {
        if (scoringQuery.getOriginalQuery() == null) {
          scoringQuery.setOriginalQuery(context.getQuery());
        }
        if (scoringQuery.getFeatureLogger() == null) {
          scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
        }
        scoringQuery.setRequest(req);
      }
    }

    @Override
    public void transform(SolrDocument doc, int docid, float score)
        throws IOException {
      implTransform(doc, docid);
    }

    @Override
    public void transform(SolrDocument doc, int docid)
        throws IOException {
      implTransform(doc, docid);
    }

    /**
     * Adds the name of the model the document is attributed to. The first query is the
     * default; the second one wins when it is an interleaving query whose picked doc ids
     * contain {@code docid}. Nothing is added when no reranking queries are available.
     */
    private void implTransform(SolrDocument doc, int docid) {
      if (rerankingQueries == null || rerankingQueries.length == 0) {
        return;
      }
      LTRScoringQuery rerankingQuery = rerankingQueries[0];
      if (rerankingQueries.length > 1
          && rerankingQueries[1] instanceof LTRInterleavingScoringQuery
          && ((LTRInterleavingScoringQuery) rerankingQueries[1]).getPickedInterleavingDocIds().contains(docid)) {
        rerankingQuery = rerankingQueries[1];
      }
      doc.addField(name, rerankingQuery.getScoringModelName());
    }
  }

}

View File

@ -23,26 +23,26 @@ import java.util.Map;
import org.apache.lucene.util.ResourceLoader; import org.apache.lucene.util.ResourceLoader;
import org.apache.lucene.util.ResourceLoaderAware; import org.apache.lucene.util.ResourceLoaderAware;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams; import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList; import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrResourceLoader; import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.LTRRescorer; import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule; import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils; import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.interleaving.Interleaving;
import org.apache.solr.ltr.interleaving.LTRInterleavingQuery;
import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore; import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.ltr.store.rest.ManagedModelStore; import org.apache.solr.ltr.store.rest.ManagedModelStore;
import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.rest.ManagedResource; import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceObserver; import org.apache.solr.rest.ManagedResourceObserver;
import org.apache.solr.search.AbstractReRankQuery;
import org.apache.solr.search.QParser; import org.apache.solr.search.QParser;
import org.apache.solr.search.QParserPlugin; import org.apache.solr.search.QParserPlugin;
import org.apache.solr.search.RankQuery;
import org.apache.solr.search.SyntaxError; import org.apache.solr.search.SyntaxError;
import org.apache.solr.util.SolrPluginUtils; import org.apache.solr.util.SolrPluginUtils;
@ -55,7 +55,7 @@ import org.apache.solr.util.SolrPluginUtils;
*/ */
public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver { public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver {
public static final String NAME = "ltr"; public static final String NAME = "ltr";
private static Query defaultQuery = new MatchAllDocsQuery(); private static final String ORIGINAL_RANKING = "_OriginalRanking_";
// params for setting custom external info that features can use, like query // params for setting custom external info that features can use, like query
// intent // intent
@ -145,94 +145,73 @@ public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAwa
@Override @Override
public Query parse() throws SyntaxError { public Query parse() throws SyntaxError {
// ReRanking Model
final String modelName = localParams.get(LTRQParserPlugin.MODEL);
if ((modelName == null) || modelName.isEmpty()) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"Must provide model in the request");
}
final LTRScoringModel ltrScoringModel = mr.getModel(modelName);
if (ltrScoringModel == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"cannot find " + LTRQParserPlugin.MODEL + " " + modelName);
}
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
final String fvStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
// Check if features are requested and if the model feature store and feature-transform feature store are the same
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(fvStoreName) || fvStoreName == null) ? extractFeatures:false;
if (threadManager != null) { if (threadManager != null) {
threadManager.setExecutor(req.getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor()); threadManager.setExecutor(req.getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
} }
final LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel, // ReRanking Model
extractEFIParams(localParams), final String[] modelNames = localParams.getParams(LTRQParserPlugin.MODEL);
featuresRequestedFromSameStore, threadManager); if ((modelNames == null) || (modelNames.length!=1 && modelNames.length!=2)) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
// Enable the feature vector caching if we are extracting features, and the features "Must provide one or two models in the request");
// we requested are the same ones we are reranking with }
if (featuresRequestedFromSameStore) { final boolean isInterleaving = (modelNames.length > 1);
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
final String tranformerFeatureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
final Map<String,String[]> externalFeatureInfo = extractEFIParams(localParams);
LTRScoringQuery rerankingQuery = null;
LTRInterleavingScoringQuery[] rerankingQueries = new LTRInterleavingScoringQuery[modelNames.length];
for (int i = 0; i < modelNames.length; i++) {
if (modelNames[i].isEmpty()) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"the " + LTRQParserPlugin.MODEL + " "+ i +" is empty");
}
if (!ORIGINAL_RANKING.equals(modelNames[i])) {
final LTRScoringModel ltrScoringModel = mr.getModel(modelNames[i]);
if (ltrScoringModel == null) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"cannot find " + LTRQParserPlugin.MODEL + " " + modelNames[i]);
}
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
// Check if features are requested and if the model feature store and feature-transform feature store are the same
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(tranformerFeatureStoreName) || tranformerFeatureStoreName == null) ? extractFeatures : false;
if (isInterleaving) {
rerankingQuery = rerankingQueries[i] = new LTRInterleavingScoringQuery(ltrScoringModel,
externalFeatureInfo,
featuresRequestedFromSameStore, threadManager);
} else {
rerankingQuery = new LTRScoringQuery(ltrScoringModel,
externalFeatureInfo,
featuresRequestedFromSameStore, threadManager);
rerankingQueries[i] = null;
}
// Enable the feature vector caching if we are extracting features, and the features
// we requested are the same ones we are reranking with
if (featuresRequestedFromSameStore) {
rerankingQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
}
}else{
rerankingQuery = rerankingQueries[i] = new OriginalRankingLTRScoringQuery(ORIGINAL_RANKING);
}
// External features
rerankingQuery.setRequest(req);
} }
SolrQueryRequestContextUtils.setScoringQuery(req, scoringQuery);
int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS); int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS);
if (reRankDocs <= 0) { if (reRankDocs <= 0) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
"Must rerank at least 1 document"); "Must rerank at least 1 document");
}
if (!isInterleaving) {
SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] { rerankingQuery });
return new LTRQuery(rerankingQuery, reRankDocs);
} else {
SolrQueryRequestContextUtils.setScoringQueries(req, rerankingQueries);
return new LTRInterleavingQuery(Interleaving.getImplementation(Interleaving.TEAM_DRAFT),rerankingQueries, reRankDocs);
} }
// External features
scoringQuery.setRequest(req);
return new LTRQuery(scoringQuery, reRankDocs);
}
}
/**
* A learning to rank Query, will incapsulate a learning to rank model, and delegate to it the rescoring
* of the documents.
**/
public class LTRQuery extends AbstractReRankQuery {
private final LTRScoringQuery scoringQuery;
public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
super(defaultQuery, reRankDocs, new LTRRescorer(scoringQuery));
this.scoringQuery = scoringQuery;
}
@Override
public int hashCode() {
return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs);
}
@Override
public boolean equals(Object o) {
return sameClassAs(o) && equalsTo(getClass().cast(o));
}
private boolean equalsTo(LTRQuery other) {
return (mainQuery.equals(other.mainQuery)
&& scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs));
}
@Override
public RankQuery wrap(Query _mainQuery) {
super.wrap(_mainQuery);
scoringQuery.setOriginalQuery(_mainQuery);
return this;
}
@Override
public String toString(String field) {
return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
+ scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}";
}
@Override
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
} }
} }

View File

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.search;
import java.io.IOException;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.search.AbstractReRankQuery;
import org.apache.solr.search.RankQuery;
/**
* A learning to rank Query, will incapsulate a learning to rank model, and delegate to it the rescoring
* of the documents.
**/
/**
 * A learning to rank Query, will encapsulate a learning to rank model, and delegate to it the rescoring
 * of the documents.
 **/
public class LTRQuery extends AbstractReRankQuery {

  private static final Query defaultQuery = new MatchAllDocsQuery();

  /**
   * The scoring query to rerank with. May be {@code null} when a caller of the
   * protected constructor supplies its own rescorer and no single scoring query.
   */
  private final LTRScoringQuery scoringQuery;

  /**
   * @param scoringQuery the scoring query used to rescore the documents
   * @param reRankDocs   the number of documents to rerank
   */
  public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
    this(scoringQuery, reRankDocs, new LTRRescorer(scoringQuery));
  }

  /**
   * @param scoringQuery the scoring query, possibly {@code null}
   * @param reRankDocs   the number of documents to rerank
   * @param rescorer     the rescorer that performs the reranking
   */
  protected LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs, LTRRescorer rescorer) {
    super(defaultQuery, reRankDocs, rescorer);
    this.scoringQuery = scoringQuery;
  }

  @Override
  public int hashCode() {
    // a null scoring query contributes 0, mirroring the null tolerance of wrap()
    final int scoringQueryHash = (scoringQuery == null) ? 0 : scoringQuery.hashCode();
    return 31 * classHash() + (mainQuery.hashCode() + scoringQueryHash + reRankDocs);
  }

  @Override
  public boolean equals(Object o) {
    return sameClassAs(o) && equalsTo(getClass().cast(o));
  }

  private boolean equalsTo(LTRQuery other) {
    return mainQuery.equals(other.mainQuery)
        && (scoringQuery == null ? other.scoringQuery == null : scoringQuery.equals(other.scoringQuery))
        && (reRankDocs == other.reRankDocs);
  }

  @Override
  public RankQuery wrap(Query _mainQuery) {
    super.wrap(_mainQuery);
    if (scoringQuery != null) {
      scoringQuery.setOriginalQuery(_mainQuery);
    }
    return this;
  }

  @Override
  public String toString(String field) {
    // String.valueOf prints "null" instead of throwing when there is no scoring query
    return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
        + String.valueOf(scoringQuery) + "' reRankDocs=" + reRankDocs + "}";
  }

  @Override
  protected Query rewrite(Query rewrittenMainQuery) throws IOException {
    return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
  }
}

View File

@ -39,7 +39,7 @@ A Learning to Rank model is plugged into the ranking through the {@link org.apac
a {@link org.apache.solr.search.QParserPlugin}. The plugin will a {@link org.apache.solr.search.QParserPlugin}. The plugin will
read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel}) read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel})
used to perform the request plus other used to perform the request plus other
parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQParserPlugin.LTRQuery LTRQuery}: parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQuery LTRQuery}:
a particular {@link org.apache.solr.search.RankQuery} a particular {@link org.apache.solr.search.RankQuery}
that will encapsulate the given model and use it to that will encapsulate the given model and use it to
rescore and rerank the document (by using an {@link org.apache.solr.ltr.LTRRescorer}). rescore and rerank the document (by using an {@link org.apache.solr.ltr.LTRRescorer}).

View File

@ -46,6 +46,15 @@
<str name="fvCacheName">QUERY_DOC_FV</str> <str name="fvCacheName">QUERY_DOC_FV</str>
</transformer> </transformer>
<!-- Add a transformer that reports which model the interleaving process picked each search result from.
For each document the transformer adds an extra field to the response containing the model picked.
The name of the field will be the name of the transformer
enclosed between brackets (in this case [interleaving]).
In order to get the model chosen for each search result
you have to request that field explicitly (e.g., fl="*,[interleaving]"). -->
<transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory">
</transformer>
<updateHandler class="solr.DirectUpdateHandler2"> <updateHandler class="solr.DirectUpdateHandler2">
<autoCommit> <autoCommit>
<maxTime>15000</maxTime> <maxTime>15000</maxTime>

View File

@ -16,7 +16,11 @@
*/ */
package org.apache.solr.ltr; package org.apache.solr.ltr;
import java.util.Random;
import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
import org.apache.solr.ltr.model.LinearModel; import org.apache.solr.ltr.model.LinearModel;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -149,4 +153,160 @@ public class TestLTRQParserExplain extends TestRerankBase {
" 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}"); " 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
} }
/**
 * Interleaves modelA and modelB and checks that every interleaved document is
 * returned at the expected position and that its debug explain is the one of the
 * model that picked it.
 */
@Test
public void interleavingModels_shouldReturnExplainForTheModelPicked() throws Exception {
  // fixed seed so that the interleaving picks are reproducible
  TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]

  loadFeature("featureA1", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=popularity}1\"]}");
  loadFeature("featureA2", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=description}bloomberg\"]}");
  loadFeature("featureAB", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=popularity}2\"]}");
  loadFeature("featureB1", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=popularity}5\"]}");
  loadFeature("featureB2", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=title}different\"]}");

  loadModel("modelA", LinearModel.class.getName(),
      new String[]{"featureA1", "featureA2", "featureAB"},
      "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
  loadModel("modelB", LinearModel.class.getName(),
      new String[]{"featureB1", "featureB2", "featureAB"},
      "{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");

  final SolrQuery query = new SolrQuery();
  query.setQuery("title:bloomberg");
  query.setParam("debugQuery", "on");
  query.add("rows", "10");
  query.add("rq", "{!ltr reRankDocs=10 model=modelA model=modelB}");
  query.add("fl", "*,score");

  /*
  Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12), ScoreB(4)
  Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36), ScoreB(8)
  Doc8 = "featureA2=1.0", ScoreA(9), ScoreB(0)
  Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9), ScoreB(2)
  ModelARerankedList = [7,6,8,9]
  ModelBRerankedList = [7,6,9,8]
  Random Boolean Choices Generation from Seed: [1,0]
  */
  int[] expectedInterleaved = new int[]{7, 6, 8, 9};
  String[] expectedExplains = new String[]{
      "\n8.0 = LinearModel(name=modelB," +
      "featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " +
      "model applied to features, sum of:\n " +
      "0.0 = prod of:\n 2.0 = weight on feature\n 0.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " +
      "0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " +
      "8.0 = prod of:\n 8.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
      "\n12.0 = LinearModel(name=modelA," +
      "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
      "model applied to features, sum of:\n " +
      "3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
      "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
      "0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
      "\n9.0 = LinearModel(name=modelA," +
      "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
      "model applied to features, sum of:\n " +
      "0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
      "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
      "0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
      "\n2.0 = LinearModel(name=modelB," +
      "featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " +
      "model applied to features, sum of:\n " +
      "2.0 = prod of:\n 2.0 = weight on feature\n 1.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " +
      "0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " +
      "0.0 = prod of:\n 8.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n"};

  // one numFound check, then one id check and one explain check per document:
  // sizing the array from the data avoids leaving unused null entries behind
  final int docCount = expectedInterleaved.length;
  final String[] tests = new String[1 + 2 * docCount];
  tests[0] = "/response/numFound/==4";
  for (int rank = 0; rank < docCount; rank++) {
    final int docId = expectedInterleaved[rank];
    tests[1 + rank] = "/response/docs/[" + rank + "]/id==\"" + docId + "\"";
    tests[1 + docCount + rank] = "/debug/explain/" + docId + "=='" + expectedExplains[rank] + "'}";
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
/**
 * Interleaves modelA with the original ranking and checks that every interleaved
 * document is returned at the expected position and that its debug explain comes
 * from whichever of the two picked it (the model explain or the original score explain).
 */
@Test
public void interleavingModelsWithOriginalRanking_shouldReturnExplainForTheModelPicked() throws Exception {
  // fixed seed so that the interleaving picks are reproducible
  TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]

  loadFeature("featureA1", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=popularity}1\"]}");
  loadFeature("featureA2", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=description}bloomberg\"]}");
  loadFeature("featureAB", SolrFeature.class.getName(),
      "{\"fq\":[\"{!terms f=popularity}2\"]}");

  loadModel("modelA", LinearModel.class.getName(),
      new String[]{"featureA1", "featureA2", "featureAB"},
      "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");

  final SolrQuery query = new SolrQuery();
  query.setQuery("title:bloomberg");
  query.setParam("debugQuery", "on");
  query.add("rows", "10");
  query.add("rq", "{!ltr reRankDocs=10 model=modelA model=_OriginalRanking_}");
  query.add("fl", "*,score");

  /*
  Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12)
  Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36)
  Doc8 = "featureA2=1.0", ScoreA(9)
  Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9)
  ModelARerankedList = [7,6,8,9]
  OriginalRanking = [9,8,7,6]
  Random Boolean Choices Generation from Seed: [1,0]
  */
  int[] expectedInterleaved = new int[]{9, 7, 6, 8};
  String[] expectedExplains = new String[]{
      "\n0.07662583 = weight(title:bloomberg in 3) [SchemaSimilarity], result of:\n " +
      "0.07662583 = score(freq=4.0), computed as boost * idf * tf from:\n " +
      "0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " +
      "0.72727275 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 4.0 = freq, occurrences of term within document\n " +
      "1.2 = k1, term saturation parameter\n " +
      "0.75 = b, length normalization parameter\n " +
      "4.0 = dl, length of field\n " +
      "3.0 = avgdl, average length of field\n",
      "\n36.0 = LinearModel(name=modelA," +
      "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
      "model applied to features, sum of:\n " +
      "0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
      "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
      "27.0 = prod of:\n 27.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
      "\n12.0 = LinearModel(name=modelA," +
      "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
      "model applied to features, sum of:\n " +
      "3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
      "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
      "0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
      "\n0.07525751 = weight(title:bloomberg in 2) [SchemaSimilarity], result of:\n " +
      "0.07525751 = score(freq=3.0), computed as boost * idf * tf from:\n " +
      "0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " +
      "0.71428573 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 3.0 = freq, occurrences of term within document\n " +
      "1.2 = k1, term saturation parameter\n " +
      "0.75 = b, length normalization parameter\n " +
      "3.0 = dl, length of field\n " +
      "3.0 = avgdl, average length of field\n"};

  // one numFound check, then one id check and one explain check per document:
  // sizing the array from the data avoids leaving unused null entries behind
  final int docCount = expectedInterleaved.length;
  final String[] tests = new String[1 + 2 * docCount];
  tests[0] = "/response/numFound/==4";
  for (int rank = 0; rank < docCount; rank++) {
    final int docId = expectedInterleaved[rank];
    tests[1 + rank] = "/response/docs/[" + rank + "]/id==\"" + docId + "\"";
    tests[1 + docCount + rank] = "/debug/explain/" + docId + "=='" + expectedExplains[rank] + "'}";
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
} }

View File

@ -48,7 +48,21 @@ public class TestLTRQParserPlugin extends TestRerankBase {
query.add("rq", "{!ltr reRankDocs=100}"); query.add("rq", "{!ltr reRankDocs=100}");
final String res = restTestHarness.query("/query" + query.toQueryString()); final String res = restTestHarness.query("/query" + query.toQueryString());
assert (res.contains("Must provide model in the request")); assert (res.contains("Must provide one or two models in the request"));
}
@Test
public void interleavingLtrTooManyModelsTest() throws Exception {
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
final SolrQuery query = new SolrQuery();
query.setQuery(solrQuery);
query.add("fl", "*, score");
query.add("rows", "4");
query.add("fv", "true");
query.add("rq", "{!ltr model=modelA model=modelB model=C reRankDocs=100}");
final String res = restTestHarness.query("/query" + query.toQueryString());
assert (res.contains("Must provide one or two models in the request"));
} }
@Test @Test
@ -65,6 +79,34 @@ public class TestLTRQParserPlugin extends TestRerankBase {
assert (res.contains("cannot find model")); assert (res.contains("cannot find model"));
} }
/**
 * A rerank query whose only model name is the empty string must be answered
 * with the "the model 0 is empty" error.
 */
@Test
public void ltrModelIsEmptyTest() throws Exception {
  final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
  final SolrQuery query = new SolrQuery();
  query.setQuery(solrQuery);
  query.add("fl", "*, score");
  query.add("rows", "4");
  query.add("fv", "true");
  query.add("rq", "{!ltr model=\"\" reRankDocs=100}");

  final String res = restTestHarness.query("/query" + query.toQueryString());
  // assertTrue rather than the `assert` keyword: the check must also run when
  // JVM assertions are disabled, and a failure should show the actual response
  assertTrue("unexpected response: " + res, res.contains("the model 0 is empty"));
}
/**
 * A rerank query whose second model name is the empty string must be answered
 * with the "the model 1 is empty" error.
 */
@Test
public void interleavingLtrModelIsEmptyTest() throws Exception {
  final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
  final SolrQuery query = new SolrQuery();
  query.setQuery(solrQuery);
  query.add("fl", "*, score");
  query.add("rows", "4");
  query.add("fv", "true");
  query.add("rq", "{!ltr model=6029760550880411648 model=\"\" reRankDocs=100}");

  final String res = restTestHarness.query("/query" + query.toQueryString());
  // assertTrue rather than the `assert` keyword: the check must also run when
  // JVM assertions are disabled, and a failure should show the actual response
  assertTrue("unexpected response: " + res, res.contains("the model 1 is empty"));
}
@Test @Test
public void ltrBadRerankDocsTest() throws Exception { public void ltrBadRerankDocsTest() throws Exception {
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";

View File

@ -17,8 +17,11 @@
package org.apache.solr.ltr; package org.apache.solr.ltr;
import java.util.Random;
import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.feature.SolrFeature; import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
import org.apache.solr.ltr.model.LinearModel; import org.apache.solr.ltr.model.LinearModel;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
@ -97,4 +100,104 @@ public class TestLTRWithSort extends TestRerankBase {
} }
/**
 * Interleaves two models on top of a query sorted by description and checks
 * the ids of the first four documents returned.
 */
@Test
public void interleavingTwoModelsWithSort_shouldInterleave() throws Exception {
  TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]

  final String solrFeatureClass = SolrFeature.class.getName();
  final String linearModelClass = LinearModel.class.getName();

  loadFeature("featureA", solrFeatureClass, "{\"q\":\"{!func}pow(popularity,2)\"}");
  loadFeature("featureB", solrFeatureClass, "{\"q\":\"{!func}pow(popularity,-2)\"}");
  loadModel("modelA", linearModelClass,
      new String[] {"featureA"}, "{\"weights\":{\"featureA\":1.0}}");
  loadModel("modelB", linearModelClass,
      new String[] {"featureB"}, "{\"weights\":{\"featureB\":1.0}}");

  final SolrQuery query = new SolrQuery();
  query.setQuery("title:a1");
  query.add("rows", "10");
  query.add("rq", "{!ltr reRankDocs=4 model=modelA model=modelB}");
  query.add("fl", "*,score");
  query.add("sort", "description desc");

  /*
  Doc1 = "popularity=1", ScoreA(1) ScoreB(1)
  Doc5 = "popularity=5", ScoreA(25) ScoreB(0.04)
  Doc7 = "popularity=7", ScoreA(49) ScoreB(0.02)
  Doc8 = "popularity=8", ScoreA(64) ScoreB(0.01)
  ModelARerankedList = [8,7,5,1]
  ModelBRerankedList = [1,5,7,8]
  OriginalRanking = [1,5,8,7]
  Random Boolean Choices Generation from Seed: [1,0]
  */
  final int[] expectedInterleaved = {1, 8, 7, 5};

  // slot 0 checks the total, slots 1..4 check the id at each position
  final String[] tests = new String[5];
  tests[0] = "/response/numFound/==8";
  for (int position = 0; position < 4; position++) {
    tests[position + 1] =
        "/response/docs/[" + position + "]/id==\"" + expectedInterleaved[position] + "\"";
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
/**
 * Interleaves one model with the original (sorted) ranking, once with the
 * original ranking listed last and once with it listed first, and checks the
 * ids of the first four documents returned in each case.
 */
@Test
public void interleavingModelsWithOriginalRankingSort_shouldInterleave() throws Exception {
  loadFeature("powpularityS", SolrFeature.class.getName(),
      "{\"q\":\"{!func}pow(popularity,2)\"}");
  loadModel("powpularityS-model", LinearModel.class.getName(),
      new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");

  final boolean[] originalRankingLastChoices = { true, false };
  for (boolean originalRankingLast : originalRankingLastChoices) {
    // reset the seed for every run so both runs see the same random choices
    TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]

    final String rerankQuery = originalRankingLast
        ? "{!ltr reRankDocs=4 model=powpularityS-model model=_OriginalRanking_}"
        : "{!ltr reRankDocs=4 model=_OriginalRanking_ model=powpularityS-model}";

    final SolrQuery query = new SolrQuery();
    query.setQuery("title:a1");
    query.add("rows", "10");
    query.add("rq", rerankQuery);
    query.add("fl", "*,score");
    query.add("sort", "description desc");

    /*
    Doc1 = "popularity=1", ScorePowpularityS(1)
    Doc5 = "popularity=5", ScorePowpularityS(25)
    Doc7 = "popularity=7", ScorePowpularityS(49)
    Doc8 = "popularity=8", ScorePowpularityS(64)
    PowpularitySRerankedList = [8,7,5,1]
    OriginalRanking = [1,5,8,7]
    Random Boolean Choices Generation from Seed: [1,0]
    */
    final int[] expectedInterleaved = originalRankingLast
        ? new int[]{1, 8, 7, 5}
        : new int[]{8, 1, 5, 7};

    // slot 0 checks the total, slots 1..4 check the id at each position
    final String[] tests = new String[5];
    tests[0] = "/response/numFound/==8";
    for (int position = 0; position < 4; position++) {
      tests[position + 1] =
          "/response/docs/[" + position + "]/id==\"" + expectedInterleaved[position] + "\"";
    }
    assertJQ("/query" + query.toQueryString(), tests);
  }
}
} }

View File

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.interleaving.algorithms;
import java.util.ArrayList;
import java.util.Random;
import java.util.Set;
import org.apache.lucene.search.ScoreDoc;
import org.apache.solr.ltr.interleaving.InterleavingResult;
import org.junit.Before;
import org.junit.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertTrue;
/**
 * Tests for {@link TeamDraftInterleaving}: two reranked lists are interleaved with a
 * fixed random seed and both the interleaved order and the per-model picks are checked.
 */
public class TeamDraftInterleavingTest {
  private TeamDraftInterleaving toTest;
  private ScoreDoc[] rerankedA,rerankedB;
  private ScoreDoc a1,a2,a3,a4,a5;
  private ScoreDoc b1,b2,b3,b4,b5;

  @Before
  public void setup() {
    toTest = new TeamDraftInterleaving();
    TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  }

  /** Creates a ScoreDoc with the given doc id and score; the shard index is always 1. */
  private static ScoreDoc scoreDoc(int docId, float score) {
    return new ScoreDoc(docId, score, 1);
  }

  /** Fills a1..a5 and rerankedA; the first list is the same in every scenario. */
  private void initRerankedA() {
    a1 = scoreDoc(1, 10);
    a2 = scoreDoc(5, 7);
    a3 = scoreDoc(4, 6);
    a4 = scoreDoc(2, 5);
    a5 = scoreDoc(3, 4);
    rerankedA = new ScoreDoc[]{a1, a2, a3, a4, a5};
  }

  /** Prepares two lists that hold the same documents in a different order. */
  protected void initDifferentOrderRerankLists() {
    initRerankedA();

    b1 = scoreDoc(1, 10);
    b2 = scoreDoc(4, 7);
    b3 = scoreDoc(5, 6);
    b4 = scoreDoc(3, 5);
    b5 = scoreDoc(2, 4);
    rerankedB = new ScoreDoc[]{b1, b2, b3, b4, b5};
  }

  /**
   * Random Boolean Choices Generation from Seed: [0,1,1]
   */
  @Test
  public void interleaving_twoDifferentLists_shouldInterleaveTeamDraft() {
    initDifferentOrderRerankLists();

    final ScoreDoc[] actual = toTest.interleave(rerankedA, rerankedB).getInterleavedResults();

    assertThat(actual.length,is(5));
    final ScoreDoc[] expected = {a1, b2, b3, a4, b4};
    for (int position = 0; position < expected.length; position++) {
      assertThat(actual[position],is(expected[position]));
    }
  }

  /**
   * Random Boolean Choices Generation from Seed: [0,1,1]
   */
  @Test
  public void interleaving_twoDifferentLists_shouldBuildCorrectInterleavingPicks() {
    initDifferentOrderRerankLists();

    final InterleavingResult result = toTest.interleave(rerankedA, rerankedB);
    final ArrayList<Set<Integer>> picksPerModel = result.getInterleavingPicks();
    final Set<Integer> picksOfA = picksPerModel.get(0);
    final Set<Integer> picksOfB = picksPerModel.get(1);

    assertThat(picksOfA.size(),is(2));
    assertThat(picksOfB.size(),is(3));
    for (ScoreDoc picked : new ScoreDoc[]{a1, a4}) {
      assertTrue(picksOfA.contains(picked.doc));
    }
    for (ScoreDoc picked : new ScoreDoc[]{b2, b3, b4}) {
      assertTrue(picksOfB.contains(picked.doc));
    }
  }

  /** Prepares two lists that hold the same documents in the same order. */
  protected void initIdenticalOrderRerankLists() {
    initRerankedA();

    b1 = scoreDoc(1, 10);
    b2 = scoreDoc(5, 7);
    b3 = scoreDoc(4, 6);
    b4 = scoreDoc(2, 5);
    b5 = scoreDoc(3, 4);
    rerankedB = new ScoreDoc[]{b1, b2, b3, b4, b5};
  }

  /**
   * Random Boolean Choices Generation from Seed: [0,1,1]
   */
  @Test
  public void interleaving_identicalRerankLists_shouldInterleaveTeamDraft() {
    initIdenticalOrderRerankLists();

    final ScoreDoc[] actual = toTest.interleave(rerankedA, rerankedB).getInterleavedResults();

    assertThat(actual.length,is(5));
    final ScoreDoc[] expected = {a1, b2, b3, a4, b5};
    for (int position = 0; position < expected.length; position++) {
      assertThat(actual[position],is(expected[position]));
    }
  }

  /**
   * Random Boolean Choices Generation from Seed: [0,1,1]
   */
  @Test
  public void interleaving_identicalRerankLists_shouldBuildCorrectInterleavingPicks() {
    initIdenticalOrderRerankLists();

    final InterleavingResult result = toTest.interleave(rerankedA, rerankedB);
    final ArrayList<Set<Integer>> picksPerModel = result.getInterleavingPicks();
    final Set<Integer> picksOfA = picksPerModel.get(0);
    final Set<Integer> picksOfB = picksPerModel.get(1);

    assertThat(picksOfA.size(),is(2));
    assertThat(picksOfB.size(),is(3));
    for (ScoreDoc picked : new ScoreDoc[]{a1, a4}) {
      assertTrue(picksOfA.contains(picked.doc));
    }
    for (ScoreDoc picked : new ScoreDoc[]{b2, b3, b5}) {
      assertTrue(picksOfB.contains(picked.doc));
    }
  }
}

View File

@ -0,0 +1,400 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.response.transform;
import java.util.Random;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
import org.apache.solr.ltr.model.LinearModel;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
public class TestFeatureLoggerTransformer extends TestRerankBase {
/**
 * Sets up the test core and indexes eight documents (ids 1 to 8), each with a
 * title, a description and a popularity, then commits them.
 */
@Before
public void before() throws Exception {
  setuptest(false);

  // one row per document: id, title, description, popularity
  final String[][] documents = {
      {"1", "w1", "w5", "1"},
      {"2", "w2 2asd asdd didid", "w2 2asd asdd didid", "2"},
      {"3", "w1", "w5", "3"},
      {"4", "w1", "w1", "6"},
      {"5", "w5", "w5", "5"},
      {"6", "w6 w2", "w1 w2", "6"},
      {"7", "w1 w2 w3 w4 w5", "w6 w2 w3 w4 w5 w8", "88888"},
      {"8", "w1 w1 w1 w2 w2 w8", "w1 w1 w1 w2 w2 w5", "88888"},
  };
  for (String[] document : documents) {
    assertU(adoc("id", document[0], "title", document[1], "description", document[2],
        "popularity", document[3]));
  }
  assertU(commit());
}
/** Runs after each test and delegates the cleanup to aftertest(). */
@After
public void after() throws Exception {
aftertest();
}
/**
 * Registers the features and models shared by the tests: five features and two
 * models (modelA, modelB) without an explicit feature store, plus three features
 * and one model (modelC) in "featureStore2".
 */
protected void loadFeaturesAndModels() throws Exception {
  final String solrFeature = SolrFeature.class.getName();
  final String linearModel = LinearModel.class.getName();
  final String secondStore = "featureStore2";

  // feature definitions that are used by more than one feature
  final String popularityIsSix = "{\"fq\":[\"{!terms f=popularity}6\"]}";
  final String titleHasUserQuery = "{\"fq\":[\"{!terms f=title}${user_query}\"]}";
  final String descriptionHasUserQuery = "{\"fq\":[\"{!terms f=description}${user_query}\"]}";

  loadFeature("featureA1", solrFeature, "{\"fq\":[\"{!terms f=popularity}88888\"]}");
  loadFeature("featureA2", solrFeature, titleHasUserQuery);
  loadFeature("featureAB", solrFeature, titleHasUserQuery);
  loadFeature("featureB1", solrFeature, popularityIsSix);
  loadFeature("featureB2", solrFeature, descriptionHasUserQuery);

  loadFeature("featureC1", solrFeature, secondStore, popularityIsSix);
  loadFeature("featureC2", solrFeature, secondStore, descriptionHasUserQuery);
  loadFeature("featureC3", solrFeature, secondStore, descriptionHasUserQuery);

  loadModel("modelA", linearModel,
      new String[]{"featureA1", "featureA2", "featureAB"},
      "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
  loadModel("modelB", linearModel,
      new String[]{"featureB1", "featureB2", "featureAB"},
      "{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
  loadModel("modelC", linearModel,
      new String[]{"featureC1", "featureC2"}, secondStore,
      "{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
}
/**
 * Interleaves modelA and modelB and checks the id and the sparse-format feature
 * vector returned for each of the five matching documents.
 */
@Test
public void interleaving_featureTransformer_shouldWorkInSparseFormat() throws Exception {
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery query = new SolrQuery();
  query.setQuery("*:*");
  query.add("fl", "*, score,features:[fv format=sparse]");
  query.add("rows", "10");
  query.add("debugQuery", "true");
  query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  query.add("rq",
      "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
  Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
  ModelARerankedList = [7,8,1,3,4]
  ModelBRerankedList = [7,1,3,8,4]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  final int[] expectedInterleaved = {7, 1, 3, 8, 4};
  final String[] expectedFeatureVectors = {
      "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
      "featureB2\\=1.0",
      "featureB2\\=1.0",
      "featureA1\\=1.0\\,featureB2\\=1.0",
      "featureB1\\=1.0"};

  // slot 0: total, slots 1..5: ids, slots 6..10: feature vectors
  final String[] tests = new String[11];
  tests[0] = "/response/numFound/==5";
  for (int position = 0; position < 5; position++) {
    final String docPath = "/response/docs/[" + position + "]";
    tests[position + 1] = docPath + "/id==\"" + expectedInterleaved[position] + "\"";
    tests[position + 6] = docPath + "/features==" + expectedFeatureVectors[position];
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
/**
 * Interleaves modelA and modelB and checks the id and the dense-format feature
 * vector returned for each of the five matching documents.
 */
@Test
public void interleaving_featureTransformer_shouldWorkInDenseFormat() throws Exception {
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery query = new SolrQuery();
  query.setQuery("*:*");
  query.add("fl", "*, score,features:[fv format=dense]");
  query.add("rows", "10");
  query.add("debugQuery", "true");
  query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  query.add("rq",
      "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
  Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
  ModelARerankedList = [7,8,1,3,4]
  ModelBRerankedList = [7,1,3,8,4]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  final int[] expectedInterleaved = {7, 1, 3, 8, 4};
  final String[] expectedFeatureVectors = {
      "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
      "featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
      "featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
      "featureA1\\=1.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
      "featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=1.0\\,featureB2\\=0.0"};

  // slot 0: total, slots 1..5: ids, slots 6..10: feature vectors
  final String[] tests = new String[11];
  tests[0] = "/response/numFound/==5";
  for (int position = 0; position < 5; position++) {
    final String docPath = "/response/docs/[" + position + "]";
    tests[position + 1] = docPath + "/id==\"" + expectedInterleaved[position] + "\"";
    tests[position + 6] = docPath + "/features==" + expectedFeatureVectors[position];
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
/**
 * Interleaves modelA and modelB while the feature logger is pointed at
 * "featureStore2", and checks that the dense feature vector returned for each
 * document lists the featureStore2 features (featureC1..featureC3).
 */
@Test
public void interleaving_explicitNewFeatureStore_shouldExtractAllFeaturesFromNewStore() throws Exception {
  // fixed seed so that the interleaving picks are reproducible
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery query = new SolrQuery();
  query.setQuery("*:*");
  query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=dense]");
  query.add("rows", "10");
  query.add("debugQuery", "true");
  query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  query.add("rq",
      "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
  Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
  ModelARerankedList = [7,8,1,3,4]
  ModelBRerankedList = [7,1,3,8,4]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  String[] expectedFeatureVectors = new String[]{
      "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC1\\=1.0\\,featureC2\\=0.0\\,featureC3\\=0.0"};
  int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

  // one numFound check, then one id check and one feature-vector check per document:
  // sizing the array from the data avoids leaving unused null entries behind
  final int docCount = expectedInterleaved.length;
  final String[] tests = new String[1 + 2 * docCount];
  tests[0] = "/response/numFound/==5";
  for (int rank = 0; rank < docCount; rank++) {
    tests[1 + rank] = "/response/docs/[" + rank + "]/id==\"" + expectedInterleaved[rank] + "\"";
    tests[1 + docCount + rank] = "/response/docs/[" + rank + "]/features==" + expectedFeatureVectors[rank];
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
/**
 * Interleaves modelA with the original ranking while the feature logger is pointed
 * at "featureStore2", and checks that every returned document, whichever side
 * picked it, carries the expected sparse feature vector built from the
 * featureStore2 features.
 */
@Test
public void interleaving_withOriginalRankingAndExplicitFeatureStore_shouldReturnNewCalculatedFeatureVector() throws Exception {
  // fixed seed so that the interleaving picks are reproducible
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery query = new SolrQuery();
  query.setQuery("*:*");
  query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=sparse]");
  query.add("rows", "10");
  query.add("debugQuery", "true");
  query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  query.add("rq",
      "{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureB2=1.0", ScoreA(0)
  Doc3 = "featureB2=1.0", ScoreA(0)
  Doc4 = "featureB1=1.0", ScoreA(0)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
  Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
  ModelARerankedList = [7,8,1,3,4]
  _OriginalRanking_ = [1,3,4,7,8]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  String[] expectedFeatureVectors = new String[]{
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC1\\=1.0"};
  int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

  // Every document is expected to carry a feature vector here, so there is one id
  // check and one feature-vector check per document and no null entries to skip.
  // The former follow-up loop that tolerated a "Path not found" error for the
  // documents at indexes 1, 2 and 4 contradicted these assertions and could not
  // fail when no error was raised, so it verified nothing and has been dropped.
  final int docCount = expectedInterleaved.length;
  final String[] tests = new String[1 + 2 * docCount];
  tests[0] = "/response/numFound/==5";
  for (int rank = 0; rank < docCount; rank++) {
    tests[1 + rank] = "/response/docs/[" + rank + "]/id==\"" + expectedInterleaved[rank] + "\"";
    tests[1 + docCount + rank] = "/response/docs/[" + rank + "]/features==" + expectedFeatureVectors[rank];
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
/**
 * Interleaves modelA with modelC, which is loaded into "featureStore2", and checks
 * the id and the sparse feature vector returned for each of the five matching
 * documents.
 */
@Test
public void interleaving_modelsFromDifferentFeatureStores_shouldLogFeaturesCorrectly() throws Exception {
  // fixed seed so that the interleaving picks are reproducible
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery query = new SolrQuery();
  query.setQuery("*:*");
  query.add("fl", "*, score,features:[fv format=sparse]");
  query.add("rows", "10");
  query.add("debugQuery", "true");
  query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  query.add("rq",
      "{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25)
  Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25)
  Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0)
  Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25)
  ModelARerankedList = [7,8,1,3,4]
  ModelCRerankedList = [1,3,8,4,7]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  String[] expectedFeatureVectors = new String[]{
      "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureA1\\=1.0\\,featureB2\\=1.0",
      "featureC1\\=1.0"};
  int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

  // one numFound check, then one id check and one feature-vector check per document:
  // sizing the array from the data avoids leaving unused null entries behind
  final int docCount = expectedInterleaved.length;
  final String[] tests = new String[1 + 2 * docCount];
  tests[0] = "/response/numFound/==5";
  for (int rank = 0; rank < docCount; rank++) {
    tests[1 + rank] = "/response/docs/[" + rank + "]/id==\"" + expectedInterleaved[rank] + "\"";
    tests[1 + docCount + rank] = "/response/docs/[" + rank + "]/features==" + expectedFeatureVectors[rank];
  }
  assertJQ("/query" + query.toQueryString(), tests);
}
@Test
public void interleaving_featureLoggerFromNewFeatureStoreWithDifferentEfi_shouldReturnNewCalculatedFeatureVector() throws Exception {
  // Interleaves modelA with modelB (efi.user_query='w5') while the [fv] transformer uses
  // featureStore2 with its own efi.user_query='w8'; asserts the interleaved order and the
  // features entry of each document.
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery solrQuery = new SolrQuery();
  solrQuery.setQuery("*:*");
  solrQuery.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w8' format=sparse]");
  solrQuery.add("rows", "10");
  solrQuery.add("debugQuery", "true");
  solrQuery.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  solrQuery.add("rq", "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
  Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
  Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
  ModelARerankedList = [7,8,1,3,4]
  ModelBRerankedList = [7,1,3,8,4]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  // NOTE(review): four expectations are empty strings, so their assertions end in "==" with
  // nothing after it; what assertJQ verifies for such an entry is not visible here - confirm.
  final String[] featureVectors = {
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "",
      "",
      "",
      ""};
  final int[] interleavedIds = {7, 1, 3, 8, 4};

  // Layout: [0] numFound, [1..5] ids, [11..15] feature vectors; slots 6..10 are left unset.
  final String[] assertions = new String[16];
  assertions[0] = "/response/numFound/==5";
  for (int pos = 0; pos < 5; pos++) {
    final String docPath = "/response/docs/[" + pos + "]";
    assertions[pos + 1] = docPath + "/id==\"" + interleavedIds[pos] + "\"";
    assertions[pos + 11] = docPath + "/features==" + featureVectors[pos];
  }
  assertJQ("/query" + solrQuery.toQueryString(), assertions);
}
@Test
public void interleaving_explicitFeatureStoreReusableFromModel_shouldLogFeaturesCorrectly() throws Exception {
  // Interleaves modelA with modelC while the [fv] transformer explicitly names featureStore2
  // (no efi of its own); asserts the interleaved order and each document's feature vector.
  TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
  loadFeaturesAndModels();

  final SolrQuery solrQuery = new SolrQuery();
  solrQuery.setQuery("*:*");
  solrQuery.add("fl", "*, score,features:[fv store=featureStore2 format=sparse]");
  solrQuery.add("rows", "10");
  solrQuery.add("debugQuery", "true");
  solrQuery.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
  solrQuery.add("rq", "{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}");

  /*
  Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25)
  Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25)
  Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5)
  Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0)
  Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25)
  ModelARerankedList = [7,8,1,3,4]
  ModelCRerankedList = [1,3,8,4,7]
  Random Boolean Choices Generation from Seed: [0,1,1]
  */
  final String[] featureVectors = {
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC2\\=1.0\\,featureC3\\=1.0",
      "featureC1\\=1.0"};
  final int[] interleavedIds = {7, 1, 3, 8, 4};

  // Layout: [0] numFound, [1..5] ids, [11..15] feature vectors; slots 6..10 are left unset.
  final String[] assertions = new String[16];
  assertions[0] = "/response/numFound/==5";
  for (int pos = 0; pos < 5; pos++) {
    final String docPath = "/response/docs/[" + pos + "]";
    assertions[pos + 1] = docPath + "/id==\"" + interleavedIds[pos] + "\"";
    assertions[pos + 11] = docPath + "/features==" + featureVectors[pos];
  }
  assertJQ("/query" + solrQuery.toQueryString(), assertions);
}
}

View File

@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.response.transform;
import java.util.Random;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.ltr.TestRerankBase;
import org.apache.solr.ltr.feature.SolrFeature;
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
import org.apache.solr.ltr.model.LinearModel;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
/**
 * Tests for the {@code [interleaving]} document transformer, requested here as
 * {@code fl=interleavingPick:[interleaving]}. For every returned document the tests expect the
 * name of the model (or {@code _OriginalRanking_}) that interleaving picked the document from.
 *
 * <p>Each test pins the random source of {@link TeamDraftInterleaving} to the same seed before
 * querying, so the interleaved order is reproducible.
 */
public class TestInterleavingTransformer extends TestRerankBase {

  /** Seed for the interleaving random source; the tests note it yields the choices [0,1,1]. */
  private static final long INTERLEAVING_SEED = 10101010;

  /** Indexes the eight documents that all tests in this class query against. */
  @Before
  public void before() throws Exception {
    setuptest(false);

    assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity",
        "1"));
    assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
        "w2 2asd asdd didid", "popularity", "2"));
    assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity",
        "3"));
    assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity",
        "6"));
    assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
        "5"));
    assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2",
        "popularity", "6"));
    assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
        "w6 w2 w3 w4 w5 w8", "popularity", "88888"));
    assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
        "w1 w1 w1 w2 w2 w5", "popularity", "88888"));
    assertU(commit());
  }

  @After
  public void after() throws Exception {
    aftertest();
  }

  /**
   * Registers the features and the three linear models (modelA, modelB, modelC) used by the
   * tests. featureC1/featureC2 and modelC are loaded with the explicit store name
   * "featureStore2"; the remaining features and models are loaded without a store name.
   */
  protected void loadFeaturesAndModelsForInterleaving() throws Exception {
    loadFeature("featureA1", SolrFeature.class.getName(),
        "{\"fq\":[\"{!terms f=popularity}88888\"]}");
    loadFeature("featureA2", SolrFeature.class.getName(),
        "{\"fq\":[\"{!terms f=title}${user_query}\"]}");
    loadFeature("featureAB", SolrFeature.class.getName(),
        "{\"fq\":[\"{!terms f=title}${user_query}\"]}");
    loadFeature("featureB1", SolrFeature.class.getName(),
        "{\"fq\":[\"{!terms f=popularity}6\"]}");
    loadFeature("featureB2", SolrFeature.class.getName(),
        "{\"fq\":[\"{!terms f=description}${user_query}\"]}");
    loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2",
        "{\"fq\":[\"{!terms f=popularity}6\"]}");
    loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2",
        "{\"fq\":[\"{!terms f=popularity}1\"]}");

    loadModel("modelA", LinearModel.class.getName(),
        new String[]{"featureA1", "featureA2", "featureAB"},
        "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
    loadModel("modelB", LinearModel.class.getName(),
        new String[]{"featureB1", "featureB2", "featureAB"},
        "{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
    loadModel("modelC", LinearModel.class.getName(),
        new String[]{"featureC1", "featureC2"},"featureStore2",
        "{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
  }

  /** Resets the interleaving random source so the next query sees the same random choices. */
  private static void seedInterleaving() {
    TeamDraftInterleaving.setRANDOM(new Random(INTERLEAVING_SEED));//Random Boolean Choices Generation from Seed: [0,1,1]
  }

  /**
   * Builds the query shared by all tests: match-all, ten rows, debug enabled, filtered to the
   * documents whose title contains w1, with the given field list and rerank query.
   *
   * @param fieldList   value of the {@code fl} parameter
   * @param rerankQuery value of the {@code rq} parameter
   */
  private static SolrQuery newInterleavingQuery(String fieldList, String rerankQuery) {
    final SolrQuery query = new SolrQuery();
    query.setQuery("*:*");
    query.add("fl", fieldList);
    query.add("rows", "10");
    query.add("debugQuery", "true");
    query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
    query.add("rq", rerankQuery);
    return query;
  }

  /**
   * Runs the query and asserts the number of hits, the document id at every position, the
   * interleaving pick at every position and, when requested, the feature vector.
   *
   * @param query                  the query to execute
   * @param expectedIds            expected document ids in result order
   * @param expectedPicks          expected {@code interleavingPick} value per position
   * @param expectedFeatureVectors expected {@code features} value per position, or {@code null}
   *                               to skip feature assertions entirely; a {@code null} element
   *                               skips the assertion for that position only
   */
  private void assertInterleavedResults(SolrQuery query, int[] expectedIds, String[] expectedPicks,
      String[] expectedFeatureVectors) throws Exception {
    final int docCount = expectedIds.length;
    final boolean checkFeatures = expectedFeatureVectors != null;
    // Layout: [0] numFound, then one run of docCount slots each for ids, picks and (optionally)
    // feature vectors. Slots for positions without an expected feature vector stay null
    // (NOTE(review): this relies on assertJQ ignoring null entries - confirm).
    final String[] tests = new String[1 + docCount * (checkFeatures ? 3 : 2)];
    tests[0] = "/response/numFound/==" + docCount;
    for (int i = 0; i < docCount; i++) {
      final String docPath = "/response/docs/[" + i + "]";
      tests[1 + i] = docPath + "/id==\"" + expectedIds[i] + "\"";
      tests[1 + docCount + i] = docPath + "/interleavingPick==" + expectedPicks[i];
      if (checkFeatures && expectedFeatureVectors[i] != null) {
        tests[1 + 2 * docCount + i] = docPath + "/features==" + expectedFeatureVectors[i];
      }
    }
    assertJQ("/query" + query.toQueryString(), tests);
  }

  /**
   * Asserts that the document at the given result position has no {@code features} entry, i.e.
   * that looking the path up fails with "Path not found". Fails explicitly when the lookup does
   * not raise an exception, so an unexpectedly present feature vector cannot go unnoticed.
   */
  private void assertNoFeatureVector(SolrQuery query, int docIndex) throws Exception {
    seedInterleaving();
    final String featuresPath = "/response/docs/[" + docIndex + "]/features";
    try {
      assertJQ("/query" + query.toQueryString(), new String[]{featuresPath + "=="});
    } catch (Exception e) {
      assertEquals("Path not found: " + featuresPath, e.getMessage());
      return;
    }
    // fail() raises an AssertionError, which the catch (Exception) above does not intercept.
    fail("Expected no feature vector at " + featuresPath + " but the path was found");
  }

  @Test
  public void interleavingTransformer_shouldReturnInterleavingPickInTheResults() throws Exception {
    seedInterleaving();
    loadFeaturesAndModelsForInterleaving();

    final SolrQuery query = newInterleavingQuery(
        "*, score,interleavingPick:[interleaving]",
        "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");

    /*
    Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
    Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
    Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
    Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
    Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
    ModelARerankedList = [7,8,1,3,4]
    ModelBRerankedList = [7,1,3,8,4]
    Random Boolean Choices Generation from Seed: [0,1,1]
    */
    String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"};
    int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

    assertInterleavedResults(query, expectedInterleaved, expectedInterleavingPicks, null);
  }

  @Test
  public void interleavingTransformerWithOriginalRanking_shouldReturnInterleavingPickInTheResults() throws Exception {
    seedInterleaving();
    loadFeaturesAndModelsForInterleaving();

    final SolrQuery query = newInterleavingQuery(
        "*, score,interleavingPick:[interleaving]",
        "{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");

    /*
    Doc1 = "featureB2=1.0", ScoreA(0)
    Doc3 = "featureB2=1.0", ScoreA(0)
    Doc4 = "featureB1=1.0", ScoreA(0)
    Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
    Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
    ModelARerankedList = [7,8,1,3,4]
    OriginalRanking = [1,3,4,7,8]
    Random Boolean Choices Generation from Seed: [0,1,1]
    */
    String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"};
    int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

    assertInterleavedResults(query, expectedInterleaved, expectedInterleavingPicks, null);
  }

  @Test
  public void interleavingTransformer_shouldBeCompatibleWithFeatureTransformer() throws Exception {
    seedInterleaving();
    loadFeaturesAndModelsForInterleaving();

    final SolrQuery query = newInterleavingQuery(
        "*, score,interleavingPick:[interleaving],features:[fv format=sparse]",
        "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");

    /*
    Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
    Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
    Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
    Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
    Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
    ModelARerankedList = [7,8,1,3,4]
    ModelBRerankedList = [7,1,3,8,4]
    Random Boolean Choices Generation from Seed: [0,1,1]
    */
    String[] expectedFeatureVectors = new String[]{
        "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
        "featureB2\\=1.0",
        "featureB2\\=1.0",
        "featureA1\\=1.0\\,featureB2\\=1.0",
        "featureB1\\=1.0"};
    String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"};
    int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

    assertInterleavedResults(query, expectedInterleaved, expectedInterleavingPicks, expectedFeatureVectors);
  }

  @Test
  public void interleavingTransformerWithOriginalRanking_shouldBeCompatibleWithFeatureTransformer() throws Exception {
    seedInterleaving();
    loadFeaturesAndModelsForInterleaving();

    final SolrQuery query = newInterleavingQuery(
        "*, score,interleavingPick:[interleaving],features:[fv format=sparse]",
        "{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");

    /*
    Doc1 = "featureB2=1.0", ScoreA(0)
    Doc3 = "featureB2=1.0", ScoreA(0)
    Doc4 = "featureB1=1.0", ScoreA(0)
    Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
    Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
    ModelARerankedList = [7,8,1,3,4]
    _OriginalRanking_ = [1,3,4,7,8]
    Random Boolean Choices Generation from Seed: [0,1,1]
    */
    // null marks the positions picked from _OriginalRanking_, where no feature vector is expected.
    String[] expectedFeatureVectors = new String[]{
        "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
        null,
        null,
        "featureA1\\=1.0\\,featureB2\\=1.0",
        null};
    String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"};
    int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};

    assertInterleavedResults(query, expectedInterleaved, expectedInterleavingPicks, expectedFeatureVectors);

    // The positions without an expected feature vector must not expose a features entry at all.
    int[] nullFeatureVectorIndexes = new int[]{1, 2, 4};
    for (int index : nullFeatureVectorIndexes) {
      assertNoFeatureVector(query, index);
    }
  }
}

View File

@ -38,6 +38,13 @@ A ranking model computes the scores used to rerank documents. Irrespective of an
* features that represent the document being scored * features that represent the document being scored
* features that represent the query for which the document is being scored * features that represent the query for which the document is being scored
==== Interleaving
Interleaving is an approach to Online Search Quality evaluation that compares two models by interleaving their results in the final ranked list returned to the user.
* currently only the Team Draft Interleaving algorithm is supported (and its implementation assumes all results are from the same shard)
==== Feature ==== Feature
A feature is a value, a number, that represents some quantity or quality of the document being scored or of the query for which documents are being scored. For example documents often have a 'recency' quality and 'number of past purchases' might be a quantity that is passed to Solr as part of the search query. A feature is a value, a number, that represents some quantity or quality of the document being scored or of the query for which documents are being scored. For example documents often have a 'recency' quality and 'number of past purchases' might be a quantity that is passed to Solr as part of the search query.
@ -247,6 +254,81 @@ The output XML will include feature values as a comma-separated list, resembling
}} }}
---- ----
=== Running a Rerank Query Interleaving Two Models
To rerank the results of a query by interleaving two models (myModelA, myModelB), add the `rq` parameter to your search and pass both models to it, for example:
[source,text]
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score
To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example:
[source,text]
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score,[interleaving]
The output will include the model picked for each search result, resembling the JSON output shown here:
[source,json]
----
{
"responseHeader":{
"status":0,
"QTime":0,
"params":{
"q":"test",
"fl":"id,score,[interleaving]",
"rq":"{!ltr model=myModelA model=myModelB reRankDocs=100}"}},
"response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[
{
"id":"GB18030TEST",
"score":1.0005897,
"[interleaving]":"myModelB"},
{
"id":"UTF8TEST",
"score":0.79656565,
"[interleaving]":"myModelA"}]
}}
----
=== Running a Rerank Query Interleaving a Model with the Original Ranking
When approaching Search Quality Evaluation with interleaving it may be useful to compare a model with the original ranking.
To rerank the results of a query, interleaving a model with the original ranking, add the `rq` parameter to your search, passing the special inbuilt `_OriginalRanking_` model identifier as one model and your comparison model as the other model, for example:
[source,text]
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score
The addition of the `rq` parameter will not change the output XML of the search.
To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example:
[source,text]
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score,[interleaving]
The output will include the model picked for each search result, resembling the JSON output shown here:
[source,json]
----
{
"responseHeader":{
"status":0,
"QTime":0,
"params":{
"q":"test",
"fl":"id,score,[interleaving]",
"rq":"{!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}"}},
"response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[
{
"id":"GB18030TEST",
"score":1.0005897,
"[interleaving]":"_OriginalRanking_"},
{
"id":"UTF8TEST",
"score":0.79656565,
"[interleaving]":"myModel"}]
}}
----
=== External Feature Information === External Feature Information
The {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/ValueFeature.html[ValueFeature] and {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/SolrFeature.html[SolrFeature] classes support the use of external feature information, `efi` for short. The {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/ValueFeature.html[ValueFeature] and {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/SolrFeature.html[SolrFeature] classes support the use of external feature information, `efi` for short.
@ -418,6 +500,13 @@ Learning-To-Rank is a contrib module and therefore its plugins must be configure
</transformer> </transformer>
---- ----
* Declaration of the `[interleaving]` transformer.
+
[source,xml]
----
<transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory"/>
----
=== Advanced Options === Advanced Options
==== LTRThreadModule ==== LTRThreadModule
@ -446,11 +535,12 @@ How does Solr Learning-To-Rank work under the hood?::
Please refer to the `ltr` {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/package-summary.html[javadocs] for an implementation overview. Please refer to the `ltr` {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/package-summary.html[javadocs] for an implementation overview.
How could I write additional models and/or features?:: How could I write additional models and/or features?::
Contributions for further models, features and normalizers are welcome. Related links: Contributions for further models, features, normalizers and interleaving algorithms are welcome. Related links:
+ +
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel javadocs] * {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel javadocs]
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/Feature.html[Feature javadocs] * {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/Feature.html[Feature javadocs]
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/norm/Normalizer.html[Normalizer javadocs] * {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/norm/Normalizer.html[Normalizer javadocs]
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/interleaving/Interleaving.html[Interleaving javadocs]
* https://cwiki.apache.org/confluence/display/solr/HowToContribute * https://cwiki.apache.org/confluence/display/solr/HowToContribute
* https://cwiki.apache.org/confluence/display/LUCENE/HowToContribute * https://cwiki.apache.org/confluence/display/LUCENE/HowToContribute
@ -779,3 +869,7 @@ The feature store and the model store are both <<managed-resources.adoc#managed-
* "Learning to Rank in Solr" presentation at Lucene/Solr Revolution 2015 in Austin: * "Learning to Rank in Solr" presentation at Lucene/Solr Revolution 2015 in Austin:
** Slides: http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp ** Slides: http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp
** Video: https://www.youtube.com/watch?v=M7BKwJoh96s ** Video: https://www.youtube.com/watch?v=M7BKwJoh96s
* The importance of Online Testing in Learning To Rank:
** Blog: https://sease.io/2020/04/the-importance-of-online-testing-in-learning-to-rank-part-1.html
** Blog: https://sease.io/2020/05/online-testing-for-learning-to-rank-interleaving.html