mirror of https://github.com/apache/lucene.git
SOLR-14560: Interleaving for Learning To Rank (#1571)
SOLR-14560: Add interleaving support in Learning To Rank
This commit is contained in:
parent
ea4dd0580f
commit
af0455ac83
|
@ -167,6 +167,8 @@ New Features
|
||||||
|
|
||||||
* SOLR-14907: Add v2 API for configSet upload, including single-file insertion. (Houston Putman)
|
* SOLR-14907: Add v2 API for configSet upload, including single-file insertion. (Houston Putman)
|
||||||
|
|
||||||
|
* SOLR-14560: Add interleaving support in Learning To Rank. (Alessandro Benedetti, Christine Poerschke)
|
||||||
|
|
||||||
Improvements
|
Improvements
|
||||||
---------------------
|
---------------------
|
||||||
* SOLR-14942: Reduce leader election time on node shutdown by removing election nodes before closing cores.
|
* SOLR-14942: Reduce leader election time on node shutdown by removing election nodes before closing cores.
|
||||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.lucene.search.ScoreMode;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.search.Weight;
|
import org.apache.lucene.search.Weight;
|
||||||
|
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
|
||||||
import org.apache.solr.search.SolrIndexSearcher;
|
import org.apache.solr.search.SolrIndexSearcher;
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,12 +43,17 @@ import org.apache.solr.search.SolrIndexSearcher;
|
||||||
* */
|
* */
|
||||||
public class LTRRescorer extends Rescorer {
|
public class LTRRescorer extends Rescorer {
|
||||||
|
|
||||||
LTRScoringQuery scoringQuery;
|
final private LTRScoringQuery scoringQuery;
|
||||||
|
|
||||||
|
public LTRRescorer() {
|
||||||
|
this.scoringQuery = null;
|
||||||
|
}
|
||||||
|
|
||||||
public LTRRescorer(LTRScoringQuery scoringQuery) {
|
public LTRRescorer(LTRScoringQuery scoringQuery) {
|
||||||
this.scoringQuery = scoringQuery;
|
this.scoringQuery = scoringQuery;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void heapAdjust(ScoreDoc[] hits, int size, int root) {
|
protected static void heapAdjust(ScoreDoc[] hits, int size, int root) {
|
||||||
final ScoreDoc doc = hits[root];
|
final ScoreDoc doc = hits[root];
|
||||||
final float score = doc.score;
|
final float score = doc.score;
|
||||||
int i = root;
|
int i = root;
|
||||||
|
@ -82,7 +88,7 @@ public class LTRRescorer extends Rescorer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void heapify(ScoreDoc[] hits, int size) {
|
protected static void heapify(ScoreDoc[] hits, int size) {
|
||||||
for (int i = (size >> 1) - 1; i >= 0; i--) {
|
for (int i = (size >> 1) - 1; i >= 0; i--) {
|
||||||
heapAdjust(hits, size, i);
|
heapAdjust(hits, size, i);
|
||||||
}
|
}
|
||||||
|
@ -104,23 +110,27 @@ public class LTRRescorer extends Rescorer {
|
||||||
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
|
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
|
||||||
return firstPassTopDocs;
|
return firstPassTopDocs;
|
||||||
}
|
}
|
||||||
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
|
final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
|
||||||
Arrays.sort(hits, new Comparator<ScoreDoc>() {
|
|
||||||
@Override
|
|
||||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
|
||||||
return a.doc - b.doc;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
|
|
||||||
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
|
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
|
||||||
|
|
||||||
|
final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);
|
||||||
|
|
||||||
|
return new TopDocs(firstPassTopDocs.totalHits, reranked);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
|
||||||
final ScoreDoc[] reranked = new ScoreDoc[topN];
|
final ScoreDoc[] reranked = new ScoreDoc[topN];
|
||||||
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
|
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||||
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
|
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
|
||||||
.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);
|
.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);
|
||||||
|
|
||||||
scoreFeatures(searcher, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked);
|
scoreFeatures(searcher,topN, modelWeight, firstPassResults, leaves, reranked);
|
||||||
// Must sort all documents that we reranked, and then select the top
|
// Must sort all documents that we reranked, and then select the top
|
||||||
|
sortByScore(reranked);
|
||||||
|
return reranked;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected static void sortByScore(ScoreDoc[] reranked) {
|
||||||
Arrays.sort(reranked, new Comparator<ScoreDoc>() {
|
Arrays.sort(reranked, new Comparator<ScoreDoc>() {
|
||||||
@Override
|
@Override
|
||||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||||
|
@ -136,13 +146,24 @@ public class LTRRescorer extends Rescorer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return new TopDocs(firstPassTopDocs.totalHits, reranked);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void scoreFeatures(IndexSearcher indexSearcher, TopDocs firstPassTopDocs,
|
protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
|
||||||
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
|
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
|
||||||
ScoreDoc[] reranked) throws IOException {
|
Arrays.sort(hits, new Comparator<ScoreDoc>() {
|
||||||
|
@Override
|
||||||
|
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||||
|
return a.doc - b.doc;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
|
||||||
|
return hits;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void scoreFeatures(IndexSearcher indexSearcher,
|
||||||
|
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
|
||||||
|
ScoreDoc[] reranked) throws IOException {
|
||||||
|
|
||||||
int readerUpto = -1;
|
int readerUpto = -1;
|
||||||
int endDoc = 0;
|
int endDoc = 0;
|
||||||
|
@ -150,7 +171,6 @@ public class LTRRescorer extends Rescorer {
|
||||||
|
|
||||||
LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
|
LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
|
||||||
int hitUpto = 0;
|
int hitUpto = 0;
|
||||||
final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
|
|
||||||
|
|
||||||
while (hitUpto < hits.length) {
|
while (hitUpto < hits.length) {
|
||||||
final ScoreDoc hit = hits[hitUpto];
|
final ScoreDoc hit = hits[hitUpto];
|
||||||
|
@ -166,64 +186,77 @@ public class LTRRescorer extends Rescorer {
|
||||||
docBase = readerContext.docBase;
|
docBase = readerContext.docBase;
|
||||||
scorer = modelWeight.scorer(readerContext);
|
scorer = modelWeight.scorer(readerContext);
|
||||||
}
|
}
|
||||||
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
|
scoreSingleHit(indexSearcher, topN, modelWeight, docBase, hitUpto, hit, docID, scoringQuery, scorer, reranked);
|
||||||
// call score
|
hitUpto++;
|
||||||
// even if no feature scorers match, since a model might use that info to
|
}
|
||||||
// return a
|
}
|
||||||
// non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer
|
|
||||||
// past the target
|
|
||||||
// doc since the model algorithm still needs to compute a potentially
|
|
||||||
// non-zero score from blank features.
|
|
||||||
assert (scorer != null);
|
|
||||||
final int targetDoc = docID - docBase;
|
|
||||||
scorer.docID();
|
|
||||||
scorer.iterator().advance(targetDoc);
|
|
||||||
|
|
||||||
scorer.getDocInfo().setOriginalDocScore(hit.score);
|
protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery rerankingQuery, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
|
||||||
hit.score = scorer.score();
|
final FeatureLogger featureLogger = rerankingQuery.getFeatureLogger();
|
||||||
if (hitUpto < topN) {
|
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
|
||||||
reranked[hitUpto] = hit;
|
// call score
|
||||||
// if the heap is not full, maybe I want to log the features for this
|
// even if no feature scorers match, since a model might use that info to
|
||||||
// document
|
// return a
|
||||||
|
// non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer
|
||||||
|
// past the target
|
||||||
|
// doc since the model algorithm still needs to compute a potentially
|
||||||
|
// non-zero score from blank features.
|
||||||
|
assert (scorer != null);
|
||||||
|
final int targetDoc = docID - docBase;
|
||||||
|
scorer.docID();
|
||||||
|
scorer.iterator().advance(targetDoc);
|
||||||
|
|
||||||
|
scorer.getDocInfo().setOriginalDocScore(hit.score);
|
||||||
|
hit.score = scorer.score();
|
||||||
|
if (hitUpto < topN) {
|
||||||
|
reranked[hitUpto] = hit;
|
||||||
|
// if the heap is not full, maybe I want to log the features for this
|
||||||
|
// document
|
||||||
|
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
||||||
|
featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher,
|
||||||
|
modelWeight.getFeaturesInfo());
|
||||||
|
}
|
||||||
|
} else if (hitUpto == topN) {
|
||||||
|
// collected topN document, I create the heap
|
||||||
|
heapify(reranked, topN);
|
||||||
|
}
|
||||||
|
if (hitUpto >= topN) {
|
||||||
|
// once that heap is ready, if the score of this document is lower that
|
||||||
|
// the minimum
|
||||||
|
// i don't want to log the feature. Otherwise I replace it with the
|
||||||
|
// minimum and fix the
|
||||||
|
// heap.
|
||||||
|
if (hit.score > reranked[0].score) {
|
||||||
|
reranked[0] = hit;
|
||||||
|
heapAdjust(reranked, topN, 0);
|
||||||
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
||||||
featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher,
|
featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher,
|
||||||
modelWeight.getFeaturesInfo());
|
modelWeight.getFeaturesInfo());
|
||||||
}
|
}
|
||||||
} else if (hitUpto == topN) {
|
|
||||||
// collected topN document, I create the heap
|
|
||||||
heapify(reranked, topN);
|
|
||||||
}
|
}
|
||||||
if (hitUpto >= topN) {
|
|
||||||
// once that heap is ready, if the score of this document is lower that
|
|
||||||
// the minimum
|
|
||||||
// i don't want to log the feature. Otherwise I replace it with the
|
|
||||||
// minimum and fix the
|
|
||||||
// heap.
|
|
||||||
if (hit.score > reranked[0].score) {
|
|
||||||
reranked[0] = hit;
|
|
||||||
heapAdjust(reranked, topN, 0);
|
|
||||||
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
|
||||||
featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher,
|
|
||||||
modelWeight.getFeaturesInfo());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hitUpto++;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Explanation explain(IndexSearcher searcher,
|
public Explanation explain(IndexSearcher searcher,
|
||||||
Explanation firstPassExplanation, int docID) throws IOException {
|
Explanation firstPassExplanation, int docID) throws IOException {
|
||||||
|
return getExplanation(searcher, docID, scoringQuery);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected static Explanation getExplanation(IndexSearcher searcher, int docID, LTRScoringQuery rerankingQuery) throws IOException {
|
||||||
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
|
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
|
||||||
.leaves();
|
.leaves();
|
||||||
final int n = ReaderUtil.subIndex(docID, leafContexts);
|
final int n = ReaderUtil.subIndex(docID, leafContexts);
|
||||||
final LeafReaderContext context = leafContexts.get(n);
|
final LeafReaderContext context = leafContexts.get(n);
|
||||||
final int deBasedDoc = docID - context.docBase;
|
final int deBasedDoc = docID - context.docBase;
|
||||||
final Weight modelWeight = searcher.createWeight(searcher.rewrite(scoringQuery),
|
final Weight rankingWeight;
|
||||||
ScoreMode.COMPLETE, 1);
|
if (rerankingQuery instanceof OriginalRankingLTRScoringQuery) {
|
||||||
return modelWeight.explain(context, deBasedDoc);
|
rankingWeight = rerankingQuery.getOriginalQuery().createWeight(searcher, ScoreMode.COMPLETE, 1);
|
||||||
|
} else {
|
||||||
|
rankingWeight = searcher.createWeight(searcher.rewrite(rerankingQuery),
|
||||||
|
ScoreMode.COMPLETE, 1);
|
||||||
|
}
|
||||||
|
return rankingWeight.explain(context, deBasedDoc);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight,
|
public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight,
|
||||||
|
|
|
@ -102,6 +102,10 @@ public class LTRScoringQuery extends Query implements Accountable {
|
||||||
return ltrScoringModel;
|
return ltrScoringModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getScoringModelName() {
|
||||||
|
return ltrScoringModel.getName();
|
||||||
|
}
|
||||||
|
|
||||||
public void setFeatureLogger(FeatureLogger fl) {
|
public void setFeatureLogger(FeatureLogger fl) {
|
||||||
this.fl = fl;
|
this.fl = fl;
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,8 +26,8 @@ public class SolrQueryRequestContextUtils {
|
||||||
/** key of the feature logger in the request context **/
|
/** key of the feature logger in the request context **/
|
||||||
private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger";
|
private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger";
|
||||||
|
|
||||||
/** key of the scoring query in the request context **/
|
/** key of the scoring queries in the request context **/
|
||||||
private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query";
|
private static final String SCORING_QUERIES = LTR_PREFIX + "scoring_queries";
|
||||||
|
|
||||||
/** key of the isExtractingFeatures flag in the request context **/
|
/** key of the isExtractingFeatures flag in the request context **/
|
||||||
private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures";
|
private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures";
|
||||||
|
@ -47,12 +47,12 @@ public class SolrQueryRequestContextUtils {
|
||||||
|
|
||||||
/** scoring query accessors **/
|
/** scoring query accessors **/
|
||||||
|
|
||||||
public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) {
|
public static void setScoringQueries(SolrQueryRequest req, LTRScoringQuery[] scoringQueries) {
|
||||||
req.getContext().put(SCORING_QUERY, scoringQuery);
|
req.getContext().put(SCORING_QUERIES, scoringQueries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) {
|
public static LTRScoringQuery[] getScoringQueries(SolrQueryRequest req) {
|
||||||
return (LTRScoringQuery) req.getContext().get(SCORING_QUERY);
|
return (LTRScoringQuery[]) req.getContext().get(SCORING_QUERIES);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** isExtractingFeatures flag accessors **/
|
/** isExtractingFeatures flag accessors **/
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interleaving considers two ranking models: modelA and modelB.
|
||||||
|
* For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...).
|
||||||
|
* An Interleaving algorithm creates a unique ranked list I = (i1, i2, ...).
|
||||||
|
* This list is created by interleaving elements from the two lists la and lb as described by the implementation algorithm.
|
||||||
|
* Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb.
|
||||||
|
*/
|
||||||
|
public interface Interleaving {
|
||||||
|
|
||||||
|
String TEAM_DRAFT = "TeamDraft";
|
||||||
|
|
||||||
|
InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB);
|
||||||
|
|
||||||
|
static Interleaving getImplementation(String algorithm) {
|
||||||
|
switch(algorithm) {
|
||||||
|
case TEAM_DRAFT:
|
||||||
|
default:
|
||||||
|
return new TeamDraftInterleaving();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
|
||||||
|
public class InterleavingResult {
|
||||||
|
final private ScoreDoc[] interleavedResults;
|
||||||
|
final private ArrayList<Set<Integer>> interleavingPicks;
|
||||||
|
|
||||||
|
public InterleavingResult(ScoreDoc[] interleavedResults, ArrayList<Set<Integer>> interleavingPicks) {
|
||||||
|
this.interleavedResults = interleavedResults;
|
||||||
|
this.interleavingPicks = interleavingPicks;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ScoreDoc[] getInterleavedResults() {
|
||||||
|
return interleavedResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArrayList<Set<Integer>> getInterleavingPicks() {
|
||||||
|
return interleavingPicks;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.solr.ltr.search.LTRQuery;
|
||||||
|
import org.apache.solr.search.RankQuery;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A learning to rank Query with Interleaving, will incapsulate two models, and delegate to it the rescoring
|
||||||
|
* of the documents.
|
||||||
|
**/
|
||||||
|
public class LTRInterleavingQuery extends LTRQuery {
|
||||||
|
private final LTRInterleavingScoringQuery[] rerankingQueries;
|
||||||
|
private final Interleaving interlavingAlgorithm;
|
||||||
|
|
||||||
|
public LTRInterleavingQuery(Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries, int rerankDocs) {
|
||||||
|
super(null, rerankDocs, new LTRInterleavingRescorer(interleavingAlgorithm, rerankingQueries));
|
||||||
|
this.rerankingQueries = rerankingQueries;
|
||||||
|
this.interlavingAlgorithm = interleavingAlgorithm;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return 31 * classHash() + (mainQuery.hashCode() + rerankingQueries.hashCode() + reRankDocs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
return sameClassAs(o) && equalsTo(getClass().cast(o));
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean equalsTo(LTRInterleavingQuery other) {
|
||||||
|
return (mainQuery.equals(other.mainQuery)
|
||||||
|
&& rerankingQueries.equals(other.rerankingQueries) && (reRankDocs == other.reRankDocs));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RankQuery wrap(Query _mainQuery) {
|
||||||
|
super.wrap(_mainQuery);
|
||||||
|
for(LTRInterleavingScoringQuery rerankingQuery: rerankingQueries){
|
||||||
|
rerankingQuery.setOriginalQuery(_mainQuery);
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(String field) {
|
||||||
|
return "{!ltr mainQuery='" + mainQuery.toString() + "' rerankingQueries='"
|
||||||
|
+ Arrays.toString(rerankingQueries) + "' reRankDocs=" + reRankDocs + "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
|
||||||
|
return new LTRInterleavingQuery(interlavingAlgorithm, rerankingQueries, reRankDocs).wrap(rewrittenMainQuery);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,162 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.search.Explanation;
|
||||||
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.lucene.search.ScoreMode;
|
||||||
|
import org.apache.lucene.search.TopDocs;
|
||||||
|
import org.apache.solr.ltr.LTRRescorer;
|
||||||
|
import org.apache.solr.ltr.LTRScoringQuery;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements the rescoring logic. The top documents returned by solr with their
|
||||||
|
* original scores, will be processed by a {@link LTRScoringQuery} that will assign a
|
||||||
|
* new score to each document. The top documents will be resorted based on the
|
||||||
|
* new score.
|
||||||
|
* */
|
||||||
|
public class LTRInterleavingRescorer extends LTRRescorer {
|
||||||
|
|
||||||
|
final private LTRInterleavingScoringQuery[] rerankingQueries;
|
||||||
|
private Integer originalRankingIndex = null;
|
||||||
|
final private Interleaving interleavingAlgorithm;
|
||||||
|
|
||||||
|
public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries) {
|
||||||
|
this.rerankingQueries = rerankingQueries;
|
||||||
|
this.interleavingAlgorithm = interleavingAlgorithm;
|
||||||
|
for(int i=0;i<this.rerankingQueries.length;i++){
|
||||||
|
if(this.rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery){
|
||||||
|
this.originalRankingIndex = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* rescores the documents:
|
||||||
|
*
|
||||||
|
* @param searcher
|
||||||
|
* current IndexSearcher
|
||||||
|
* @param firstPassTopDocs
|
||||||
|
* documents to rerank;
|
||||||
|
* @param topN
|
||||||
|
* documents to return;
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
|
||||||
|
int topN) throws IOException {
|
||||||
|
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
|
||||||
|
return firstPassTopDocs;
|
||||||
|
}
|
||||||
|
|
||||||
|
ScoreDoc[] firstPassResults = null;
|
||||||
|
if(originalRankingIndex != null) {
|
||||||
|
firstPassResults = new ScoreDoc[firstPassTopDocs.scoreDocs.length];
|
||||||
|
System.arraycopy(firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length);
|
||||||
|
}
|
||||||
|
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
|
||||||
|
|
||||||
|
ScoreDoc[][] reRankedPerModel = rerank(searcher,topN,getFirstPassDocsRanked(firstPassTopDocs));
|
||||||
|
if (originalRankingIndex != null) {
|
||||||
|
reRankedPerModel[originalRankingIndex] = firstPassResults;
|
||||||
|
}
|
||||||
|
InterleavingResult interleaved = interleavingAlgorithm.interleave(reRankedPerModel[0], reRankedPerModel[1]);
|
||||||
|
ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();
|
||||||
|
|
||||||
|
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
|
||||||
|
rerankingQueries[0].setPickedInterleavingDocIds(interleavingPicks.get(0));
|
||||||
|
rerankingQueries[1].setPickedInterleavingDocIds(interleavingPicks.get(1));
|
||||||
|
|
||||||
|
return new TopDocs(firstPassTopDocs.totalHits, interleavedResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
|
||||||
|
ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN];
|
||||||
|
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||||
|
LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
|
||||||
|
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||||
|
if (originalRankingIndex == null || originalRankingIndex != i) {
|
||||||
|
modelWeights[i] = (LTRScoringQuery.ModelWeight) searcher
|
||||||
|
.createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel);
|
||||||
|
|
||||||
|
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||||
|
if (originalRankingIndex == null || originalRankingIndex != i) {
|
||||||
|
sortByScore(reRankedPerModel[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reRankedPerModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void scoreFeatures(IndexSearcher indexSearcher,
|
||||||
|
int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
|
||||||
|
ScoreDoc[][] rerankedPerModel) throws IOException {
|
||||||
|
|
||||||
|
int readerUpto = -1;
|
||||||
|
int endDoc = 0;
|
||||||
|
int docBase = 0;
|
||||||
|
int hitUpto = 0;
|
||||||
|
LTRScoringQuery.ModelWeight.ModelScorer[] scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length];
|
||||||
|
while (hitUpto < hits.length) {
|
||||||
|
final ScoreDoc hit = hits[hitUpto];
|
||||||
|
final int docID = hit.doc;
|
||||||
|
LeafReaderContext readerContext = null;
|
||||||
|
while (docID >= endDoc) {
|
||||||
|
readerUpto++;
|
||||||
|
readerContext = leaves.get(readerUpto);
|
||||||
|
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
|
||||||
|
}
|
||||||
|
|
||||||
|
// We advanced to another segment
|
||||||
|
if (readerContext != null) {
|
||||||
|
docBase = readerContext.docBase;
|
||||||
|
for (int i = 0; i < modelWeights.length; i++) {
|
||||||
|
if (modelWeights[i] != null) {
|
||||||
|
scorers[i] = modelWeights[i].scorer(readerContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||||
|
if (modelWeights[i] != null) {
|
||||||
|
scoreSingleHit(indexSearcher, topN, modelWeights[i], docBase, hitUpto, new ScoreDoc(hit.doc, hit.score, hit.shardIndex), docID, rerankingQueries[i], scorers[i], rerankedPerModel[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hitUpto++;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Explanation explain(IndexSearcher searcher,
|
||||||
|
Explanation firstPassExplanation, int docID) throws IOException {
|
||||||
|
LTRScoringQuery pickedRerankModel = rerankingQueries[0];
|
||||||
|
if (rerankingQueries[1].getPickedInterleavingDocIds().contains(docID)) {
|
||||||
|
pickedRerankModel = rerankingQueries[1];
|
||||||
|
}
|
||||||
|
return getExplanation(searcher, docID, pickedRerankModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.apache.solr.ltr.LTRScoringQuery;
|
||||||
|
import org.apache.solr.ltr.LTRThreadModule;
|
||||||
|
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||||
|
|
||||||
|
public class LTRInterleavingScoringQuery extends LTRScoringQuery {
|
||||||
|
|
||||||
|
// Model was picked for this Docs
|
||||||
|
private Set<Integer> pickedInterleavingDocIds;
|
||||||
|
|
||||||
|
public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel) {
|
||||||
|
super(ltrScoringModel);
|
||||||
|
}
|
||||||
|
|
||||||
|
public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
|
||||||
|
super(ltrScoringModel, extractAllFeatures);
|
||||||
|
}
|
||||||
|
|
||||||
|
public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel,
|
||||||
|
Map<String, String[]> externalFeatureInfo,
|
||||||
|
boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
|
||||||
|
super(ltrScoringModel, externalFeatureInfo, extractAllFeatures, ltrThreadMgr);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<Integer> getPickedInterleavingDocIds() {
|
||||||
|
return pickedInterleavingDocIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setPickedInterleavingDocIds(Set<Integer> pickedInterleavingDocIds) {
|
||||||
|
this.pickedInterleavingDocIds = pickedInterleavingDocIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,33 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
||||||
|
|
||||||
|
public final class OriginalRankingLTRScoringQuery extends LTRInterleavingScoringQuery {
|
||||||
|
|
||||||
|
private final String originalRankingModelName;
|
||||||
|
|
||||||
|
public OriginalRankingLTRScoringQuery(String originalRankingModelName) {
|
||||||
|
super(null /* LTRScoringModel */);
|
||||||
|
this.originalRankingModelName = originalRankingModelName;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getScoringModelName() {
|
||||||
|
return this.originalRankingModelName;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,127 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.interleaving.algorithms;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.LinkedList;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.solr.ltr.interleaving.Interleaving;
|
||||||
|
import org.apache.solr.ltr.interleaving.InterleavingResult;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interleaving was introduced the first time by Joachims in [1, 2].
|
||||||
|
* Team Draft Interleaving is among the most successful and used interleaving approaches[3].
|
||||||
|
* Team Draft Interleaving implements a method similar to the way in which captains select their players in team-matches.
|
||||||
|
* Team Draft Interleaving produces a fair distribution of ranking models’ elements in the final interleaved list.
|
||||||
|
* "Team draft interleaving" has also proved to overcome an issue of the "Balanced interleaving" approach, in determining the winning model[4].
|
||||||
|
* <p>
|
||||||
|
* [1] T. Joachims. Optimizing search engines using clickthrough data. KDD (2002)
|
||||||
|
* [2] T. Joachims. Evaluating retrieval performance using clickthrough data. In J. Franke, G. Nakhaeizadeh, and I. Renz, editors,
|
||||||
|
* Text Mining, pages 79–96. Physica/Springer (2003)
|
||||||
|
* [3] F. Radlinski, M. Kurup, and T. Joachims. How does clickthrough data reflect re-
|
||||||
|
* trieval quality? In CIKM, pages 43–52. ACM Press (2008)
|
||||||
|
* [4] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue.
|
||||||
|
* Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1–41, Feb. (2012)
|
||||||
|
*/
|
||||||
|
public class TeamDraftInterleaving implements Interleaving {
|
||||||
|
public static Random RANDOM;
|
||||||
|
|
||||||
|
static {
|
||||||
|
// We try to make things reproducible in the context of our tests by initializing the random instance
|
||||||
|
// based on the current seed
|
||||||
|
String seed = System.getProperty("tests.seed");
|
||||||
|
if (seed == null) {
|
||||||
|
RANDOM = new Random();
|
||||||
|
} else {
|
||||||
|
RANDOM = new Random(seed.hashCode());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Team Draft Interleaving considers two ranking models: modelA and modelB.
|
||||||
|
* For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...).
|
||||||
|
* The algorithm creates a unique ranked list I = (i1, i2, ...).
|
||||||
|
* This list is created by interleaving elements from the two lists la and lb as described by Chapelle et al.[1].
|
||||||
|
* Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb.
|
||||||
|
* <p>
|
||||||
|
* [1] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue.
|
||||||
|
* Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1–41, Feb. (2012)
|
||||||
|
* <p>
|
||||||
|
* Assumptions:
|
||||||
|
* - rerankedA and rerankedB has the same length.
|
||||||
|
* They contains the same search results, ranked differently by two ranking models
|
||||||
|
* - each reranked list can not contain the same search result more than once.
|
||||||
|
* - results are all from the same shard
|
||||||
|
*
|
||||||
|
* @param rerankedA a ranked list of search results produced by a ranking model A
|
||||||
|
* @param rerankedB a ranked list of search results produced by a ranking model B
|
||||||
|
* @return the interleaved ranking list
|
||||||
|
*/
|
||||||
|
public InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB) {
|
||||||
|
LinkedList<ScoreDoc> interleavedResults = new LinkedList<>();
|
||||||
|
HashSet<Integer> alreadyAdded = new HashSet<>();
|
||||||
|
ScoreDoc[] interleavedResultArray = new ScoreDoc[rerankedA.length];
|
||||||
|
ArrayList<Set<Integer>> interleavingPicks = new ArrayList<>(2);
|
||||||
|
Set<Integer> teamA = new HashSet<>();
|
||||||
|
Set<Integer> teamB = new HashSet<>();
|
||||||
|
int topN = rerankedA.length;
|
||||||
|
int indexA = 0, indexB = 0;
|
||||||
|
|
||||||
|
while (interleavedResults.size() < topN && indexA < rerankedA.length && indexB < rerankedB.length) {
|
||||||
|
if(teamA.size()<teamB.size() || (teamA.size()==teamB.size() && !RANDOM.nextBoolean())){
|
||||||
|
indexA = updateIndex(alreadyAdded, indexA, rerankedA);
|
||||||
|
interleavedResults.add(rerankedA[indexA]);
|
||||||
|
alreadyAdded.add(rerankedA[indexA].doc);
|
||||||
|
teamA.add(rerankedA[indexA].doc);
|
||||||
|
indexA++;
|
||||||
|
} else{
|
||||||
|
indexB = updateIndex(alreadyAdded,indexB,rerankedB);
|
||||||
|
interleavedResults.add(rerankedB[indexB]);
|
||||||
|
alreadyAdded.add(rerankedB[indexB].doc);
|
||||||
|
teamB.add(rerankedB[indexB].doc);
|
||||||
|
indexB++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
interleavingPicks.add(teamA);
|
||||||
|
interleavingPicks.add(teamB);
|
||||||
|
interleavedResultArray = interleavedResults.toArray(interleavedResultArray);
|
||||||
|
|
||||||
|
return new InterleavingResult(interleavedResultArray,interleavingPicks);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int updateIndex(HashSet<Integer> alreadyAdded, int index, ScoreDoc[] reranked) {
|
||||||
|
boolean foundElementToAdd = false;
|
||||||
|
while (index < reranked.length && !foundElementToAdd) {
|
||||||
|
ScoreDoc elementToCheck = reranked[index];
|
||||||
|
if (alreadyAdded.contains(elementToCheck.doc)) {
|
||||||
|
index++;
|
||||||
|
} else {
|
||||||
|
foundElementToAdd = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return index;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void setRANDOM(Random RANDOM) {
|
||||||
|
TeamDraftInterleaving.RANDOM = RANDOM;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains Various Interleaving Algorithms
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.interleaving.algorithms;
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains Various Interleaving auxiliary classes
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.interleaving;
|
|
@ -36,6 +36,8 @@ import org.apache.solr.ltr.LTRScoringQuery;
|
||||||
import org.apache.solr.ltr.LTRThreadModule;
|
import org.apache.solr.ltr.LTRThreadModule;
|
||||||
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
||||||
import org.apache.solr.ltr.feature.Feature;
|
import org.apache.solr.ltr.feature.Feature;
|
||||||
|
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
|
||||||
|
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
|
||||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||||
import org.apache.solr.ltr.norm.Normalizer;
|
import org.apache.solr.ltr.norm.Normalizer;
|
||||||
import org.apache.solr.ltr.search.LTRQParserPlugin;
|
import org.apache.solr.ltr.search.LTRQParserPlugin;
|
||||||
|
@ -126,14 +128,15 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
||||||
SolrQueryRequestContextUtils.setIsExtractingFeatures(req);
|
SolrQueryRequestContextUtils.setIsExtractingFeatures(req);
|
||||||
|
|
||||||
// Communicate which feature store we are requesting features for
|
// Communicate which feature store we are requesting features for
|
||||||
SolrQueryRequestContextUtils.setFvStoreName(req, localparams.get(FV_STORE, defaultStore));
|
final String fvStoreName = localparams.get(FV_STORE);
|
||||||
|
SolrQueryRequestContextUtils.setFvStoreName(req, (fvStoreName == null ? defaultStore : fvStoreName));
|
||||||
|
|
||||||
// Create and supply the feature logger to be used
|
// Create and supply the feature logger to be used
|
||||||
SolrQueryRequestContextUtils.setFeatureLogger(req,
|
SolrQueryRequestContextUtils.setFeatureLogger(req,
|
||||||
createFeatureLogger(
|
createFeatureLogger(
|
||||||
localparams.get(FV_FORMAT)));
|
localparams.get(FV_FORMAT)));
|
||||||
|
|
||||||
return new FeatureTransformer(name, localparams, req);
|
return new FeatureTransformer(name, localparams, req, (fvStoreName != null) /* hasExplicitFeatureStore */);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -163,11 +166,18 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
||||||
final private String name;
|
final private String name;
|
||||||
final private SolrParams localparams;
|
final private SolrParams localparams;
|
||||||
final private SolrQueryRequest req;
|
final private SolrQueryRequest req;
|
||||||
|
final private boolean hasExplicitFeatureStore;
|
||||||
|
|
||||||
private List<LeafReaderContext> leafContexts;
|
private List<LeafReaderContext> leafContexts;
|
||||||
private SolrIndexSearcher searcher;
|
private SolrIndexSearcher searcher;
|
||||||
private LTRScoringQuery scoringQuery;
|
/**
|
||||||
private LTRScoringQuery.ModelWeight modelWeight;
|
* rerankingQueries, modelWeights have:
|
||||||
|
* length=1 - [Classic LTR] When reranking with a single model
|
||||||
|
* length=2 - [Interleaving] When reranking with interleaving (two ranking models are involved)
|
||||||
|
*/
|
||||||
|
private LTRScoringQuery[] rerankingQueriesFromContext;
|
||||||
|
private LTRScoringQuery[] rerankingQueries;
|
||||||
|
private LTRScoringQuery.ModelWeight[] modelWeights;
|
||||||
private FeatureLogger featureLogger;
|
private FeatureLogger featureLogger;
|
||||||
private boolean docsWereNotReranked;
|
private boolean docsWereNotReranked;
|
||||||
|
|
||||||
|
@ -177,10 +187,11 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
||||||
* feature vectors
|
* feature vectors
|
||||||
*/
|
*/
|
||||||
public FeatureTransformer(String name, SolrParams localparams,
|
public FeatureTransformer(String name, SolrParams localparams,
|
||||||
SolrQueryRequest req) {
|
SolrQueryRequest req, boolean hasExplicitFeatureStore) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.localparams = localparams;
|
this.localparams = localparams;
|
||||||
this.req = req;
|
this.req = req;
|
||||||
|
this.hasExplicitFeatureStore = hasExplicitFeatureStore;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -209,51 +220,102 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
||||||
threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
|
threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup LTRScoringQuery
|
rerankingQueriesFromContext = SolrQueryRequestContextUtils.getScoringQueries(req);
|
||||||
scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req);
|
docsWereNotReranked = (rerankingQueriesFromContext == null || rerankingQueriesFromContext.length == 0);
|
||||||
docsWereNotReranked = (scoringQuery == null);
|
String transformerFeatureStore = SolrQueryRequestContextUtils.getFvStoreName(req);
|
||||||
String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
|
Map<String, String[]> transformerExternalFeatureInfo = LTRQParserPlugin.extractEFIParams(localparams);
|
||||||
if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) {
|
|
||||||
// if store is set in the transformer we should overwrite the logger
|
|
||||||
|
|
||||||
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
|
final LoggingModel loggingModel = createLoggingModel(transformerFeatureStore);
|
||||||
|
setupRerankingQueriesForLogging(transformerFeatureStore, transformerExternalFeatureInfo, loggingModel);
|
||||||
|
setupRerankingWeightsForLogging(context);
|
||||||
|
}
|
||||||
|
|
||||||
final FeatureStore store = fr.getFeatureStore(featureStoreName);
|
/**
|
||||||
featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name
|
* The loggingModel is an empty model that is just used to extract the features
|
||||||
|
* and log them
|
||||||
|
* @param transformerFeatureStore the explicit transformer feature store
|
||||||
|
*/
|
||||||
|
private LoggingModel createLoggingModel(String transformerFeatureStore) {
|
||||||
|
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
|
||||||
|
|
||||||
try {
|
final FeatureStore store = fr.getFeatureStore(transformerFeatureStore);
|
||||||
final LoggingModel lm = new LoggingModel(loggingModelName,
|
transformerFeatureStore = store.getName(); // if transformerFeatureStore was null before this gets actual name
|
||||||
featureStoreName, store.getFeatures());
|
|
||||||
|
|
||||||
scoringQuery = new LTRScoringQuery(lm,
|
return new LoggingModel(loggingModelName,
|
||||||
LTRQParserPlugin.extractEFIParams(localparams),
|
transformerFeatureStore, store.getFeatures());
|
||||||
true,
|
}
|
||||||
threadManager); // request feature weights to be created for all features
|
|
||||||
|
|
||||||
}catch (final Exception e) {
|
/**
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
* When preparing the reranking queries for logging features various scenarios apply:
|
||||||
"retrieving the feature store "+featureStoreName, e);
|
*
|
||||||
|
* No Reranking
|
||||||
|
* There is the need of a logger model from the default feature store or the explicit feature store passed
|
||||||
|
* to extract the feature vector
|
||||||
|
*
|
||||||
|
* Re Ranking
|
||||||
|
* 1) If no explicit feature store is passed, the models for each reranking query can be safely re-used
|
||||||
|
* the feature vector can be fetched from the feature vector cache.
|
||||||
|
* 2) If an explicit feature store is passed, and no reranking query uses a model with that feature store,
|
||||||
|
* There is the need of a logger model to extract the feature vector
|
||||||
|
* 3) If an explicit feature store is passed, and there is a reranking query that uses a model with that feature store,
|
||||||
|
* the model can be re-used and there is no need for a logging model
|
||||||
|
*
|
||||||
|
* @param transformerFeatureStore explicit feature store for the transformer
|
||||||
|
* @param transformerExternalFeatureInfo explicit efi for the transformer
|
||||||
|
*/
|
||||||
|
private void setupRerankingQueriesForLogging(String transformerFeatureStore, Map<String, String[]> transformerExternalFeatureInfo, LoggingModel loggingModel) {
|
||||||
|
if (docsWereNotReranked) { //no reranking query
|
||||||
|
LTRScoringQuery loggingQuery = new LTRScoringQuery(loggingModel,
|
||||||
|
transformerExternalFeatureInfo,
|
||||||
|
true /* extractAllFeatures */,
|
||||||
|
threadManager);
|
||||||
|
rerankingQueries = new LTRScoringQuery[]{loggingQuery};
|
||||||
|
} else {
|
||||||
|
rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length];
|
||||||
|
System.arraycopy(rerankingQueriesFromContext, 0, rerankingQueries, 0, rerankingQueriesFromContext.length);
|
||||||
|
|
||||||
|
if (transformerFeatureStore != null) {// explicit feature store for the transformer
|
||||||
|
LTRScoringModel matchingRerankingModel = loggingModel;
|
||||||
|
for (LTRScoringQuery rerankingQuery : rerankingQueries) {
|
||||||
|
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) &&
|
||||||
|
transformerFeatureStore.equals(rerankingQuery.getScoringModel().getFeatureStoreName())) {
|
||||||
|
matchingRerankingModel = rerankingQuery.getScoringModel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||||
|
rerankingQueries[i] = new LTRScoringQuery(
|
||||||
|
matchingRerankingModel,
|
||||||
|
(!transformerExternalFeatureInfo.isEmpty() ? transformerExternalFeatureInfo : rerankingQueries[i].getExternalFeatureInfo()),
|
||||||
|
true /* extractAllFeatures */,
|
||||||
|
threadManager);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (scoringQuery.getOriginalQuery() == null) {
|
private void setupRerankingWeightsForLogging(ResultContext context) {
|
||||||
scoringQuery.setOriginalQuery(context.getQuery());
|
modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
|
||||||
}
|
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||||
if (scoringQuery.getFeatureLogger() == null){
|
if (rerankingQueries[i].getOriginalQuery() == null) {
|
||||||
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
rerankingQueries[i].setOriginalQuery(context.getQuery());
|
||||||
}
|
}
|
||||||
scoringQuery.setRequest(req);
|
rerankingQueries[i].setRequest(req);
|
||||||
|
if (!(rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
|
||||||
featureLogger = scoringQuery.getFeatureLogger();
|
if (rerankingQueries[i].getFeatureLogger() == null) {
|
||||||
|
rerankingQueries[i].setFeatureLogger(SolrQueryRequestContextUtils.getFeatureLogger(req));
|
||||||
try {
|
}
|
||||||
modelWeight = scoringQuery.createWeight(searcher, ScoreMode.COMPLETE, 1f);
|
featureLogger = rerankingQueries[i].getFeatureLogger();
|
||||||
} catch (final IOException e) {
|
try {
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
|
modelWeights[i] = rerankingQueries[i].createWeight(searcher, ScoreMode.COMPLETE, 1f);
|
||||||
}
|
} catch (final IOException e) {
|
||||||
if (modelWeight == null) {
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
}
|
||||||
"error logging the features, model weight is null");
|
if (modelWeights[i] == null) {
|
||||||
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||||
|
"error logging the features, model weight is null");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -271,17 +333,26 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
||||||
|
|
||||||
private void implTransform(SolrDocument doc, int docid, Float score)
|
private void implTransform(SolrDocument doc, int docid, Float score)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher);
|
LTRScoringQuery rerankingQuery = rerankingQueries[0];
|
||||||
if (fv == null) { // FV for this document was not in the cache
|
LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0];
|
||||||
fv = featureLogger.makeFeatureVector(
|
for (int i = 1; i < rerankingQueries.length; i++) {
|
||||||
LTRRescorer.extractFeaturesInfo(
|
if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) {
|
||||||
modelWeight,
|
rerankingQuery = rerankingQueries[i];
|
||||||
docid,
|
rerankingModelWeight = modelWeights[i];
|
||||||
(docsWereNotReranked ? score : null),
|
}
|
||||||
leafContexts));
|
}
|
||||||
|
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
|
||||||
|
Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher);
|
||||||
|
if (featureVector == null) { // FV for this document was not in the cache
|
||||||
|
featureVector = featureLogger.makeFeatureVector(
|
||||||
|
LTRRescorer.extractFeaturesInfo(
|
||||||
|
rerankingModelWeight,
|
||||||
|
docid,
|
||||||
|
(docsWereNotReranked ? score : null),
|
||||||
|
leafContexts));
|
||||||
|
}
|
||||||
|
doc.addField(name, featureVector);
|
||||||
}
|
}
|
||||||
|
|
||||||
doc.addField(name, fv);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.response.transform;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.solr.common.SolrDocument;
|
||||||
|
import org.apache.solr.common.params.SolrParams;
|
||||||
|
import org.apache.solr.common.util.NamedList;
|
||||||
|
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
|
||||||
|
import org.apache.solr.ltr.LTRScoringQuery;
|
||||||
|
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
||||||
|
import org.apache.solr.request.SolrQueryRequest;
|
||||||
|
import org.apache.solr.response.ResultContext;
|
||||||
|
import org.apache.solr.response.transform.DocTransformer;
|
||||||
|
import org.apache.solr.response.transform.TransformerFactory;
|
||||||
|
import org.apache.solr.util.SolrPluginUtils;
|
||||||
|
|
||||||
|
public class LTRInterleavingTransformerFactory extends TransformerFactory {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@SuppressWarnings({"unchecked"})
|
||||||
|
public void init(@SuppressWarnings("rawtypes") NamedList args) {
|
||||||
|
super.init(args);
|
||||||
|
SolrPluginUtils.invokeSetters(this, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DocTransformer create(String name, SolrParams localparams,
|
||||||
|
SolrQueryRequest req) {
|
||||||
|
return new InterleavingTransformer(name, req);
|
||||||
|
}
|
||||||
|
|
||||||
|
class InterleavingTransformer extends DocTransformer {
|
||||||
|
|
||||||
|
final private String name;
|
||||||
|
final private SolrQueryRequest req;
|
||||||
|
|
||||||
|
private LTRInterleavingScoringQuery[] rerankingQueries;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param name
|
||||||
|
* Name of the field to be added in a document representing the
|
||||||
|
* model picked by the interleaving process
|
||||||
|
*/
|
||||||
|
public InterleavingTransformer(String name,
|
||||||
|
SolrQueryRequest req) {
|
||||||
|
this.name = name;
|
||||||
|
this.req = req;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getName() {
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setContext(ResultContext context) {
|
||||||
|
super.setContext(context);
|
||||||
|
if (context == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (context.getRequest() == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rerankingQueries = (LTRInterleavingScoringQuery[])SolrQueryRequestContextUtils.getScoringQueries(req);
|
||||||
|
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||||
|
LTRScoringQuery scoringQuery = rerankingQueries[i];
|
||||||
|
|
||||||
|
if (scoringQuery.getOriginalQuery() == null) {
|
||||||
|
scoringQuery.setOriginalQuery(context.getQuery());
|
||||||
|
}
|
||||||
|
if (scoringQuery.getFeatureLogger() == null) {
|
||||||
|
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
||||||
|
}
|
||||||
|
scoringQuery.setRequest(req);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void transform(SolrDocument doc, int docid, float score)
|
||||||
|
throws IOException {
|
||||||
|
implTransform(doc, docid);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void transform(SolrDocument doc, int docid)
|
||||||
|
throws IOException {
|
||||||
|
implTransform(doc, docid);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void implTransform(SolrDocument doc, int docid) {
|
||||||
|
LTRScoringQuery rerankingQuery = rerankingQueries[0];
|
||||||
|
if (rerankingQueries.length > 1 && rerankingQueries[1].getPickedInterleavingDocIds().contains(docid)) {
|
||||||
|
rerankingQuery = rerankingQueries[1];
|
||||||
|
}
|
||||||
|
doc.addField(name, rerankingQuery.getScoringModelName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -23,26 +23,26 @@ import java.util.Map;
|
||||||
|
|
||||||
import org.apache.lucene.util.ResourceLoader;
|
import org.apache.lucene.util.ResourceLoader;
|
||||||
import org.apache.lucene.util.ResourceLoaderAware;
|
import org.apache.lucene.util.ResourceLoaderAware;
|
||||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.solr.common.SolrException;
|
import org.apache.solr.common.SolrException;
|
||||||
import org.apache.solr.common.params.SolrParams;
|
import org.apache.solr.common.params.SolrParams;
|
||||||
import org.apache.solr.common.util.NamedList;
|
import org.apache.solr.common.util.NamedList;
|
||||||
import org.apache.solr.core.SolrResourceLoader;
|
import org.apache.solr.core.SolrResourceLoader;
|
||||||
import org.apache.solr.ltr.LTRRescorer;
|
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
|
||||||
|
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
|
||||||
import org.apache.solr.ltr.LTRScoringQuery;
|
import org.apache.solr.ltr.LTRScoringQuery;
|
||||||
import org.apache.solr.ltr.LTRThreadModule;
|
import org.apache.solr.ltr.LTRThreadModule;
|
||||||
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
||||||
|
import org.apache.solr.ltr.interleaving.Interleaving;
|
||||||
|
import org.apache.solr.ltr.interleaving.LTRInterleavingQuery;
|
||||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||||
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
|
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
|
||||||
import org.apache.solr.ltr.store.rest.ManagedModelStore;
|
import org.apache.solr.ltr.store.rest.ManagedModelStore;
|
||||||
import org.apache.solr.request.SolrQueryRequest;
|
import org.apache.solr.request.SolrQueryRequest;
|
||||||
import org.apache.solr.rest.ManagedResource;
|
import org.apache.solr.rest.ManagedResource;
|
||||||
import org.apache.solr.rest.ManagedResourceObserver;
|
import org.apache.solr.rest.ManagedResourceObserver;
|
||||||
import org.apache.solr.search.AbstractReRankQuery;
|
|
||||||
import org.apache.solr.search.QParser;
|
import org.apache.solr.search.QParser;
|
||||||
import org.apache.solr.search.QParserPlugin;
|
import org.apache.solr.search.QParserPlugin;
|
||||||
import org.apache.solr.search.RankQuery;
|
|
||||||
import org.apache.solr.search.SyntaxError;
|
import org.apache.solr.search.SyntaxError;
|
||||||
import org.apache.solr.util.SolrPluginUtils;
|
import org.apache.solr.util.SolrPluginUtils;
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ import org.apache.solr.util.SolrPluginUtils;
|
||||||
*/
|
*/
|
||||||
public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver {
|
public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver {
|
||||||
public static final String NAME = "ltr";
|
public static final String NAME = "ltr";
|
||||||
private static Query defaultQuery = new MatchAllDocsQuery();
|
private static final String ORIGINAL_RANKING = "_OriginalRanking_";
|
||||||
|
|
||||||
// params for setting custom external info that features can use, like query
|
// params for setting custom external info that features can use, like query
|
||||||
// intent
|
// intent
|
||||||
|
@ -145,94 +145,73 @@ public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAwa
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Query parse() throws SyntaxError {
|
public Query parse() throws SyntaxError {
|
||||||
// ReRanking Model
|
|
||||||
final String modelName = localParams.get(LTRQParserPlugin.MODEL);
|
|
||||||
if ((modelName == null) || modelName.isEmpty()) {
|
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
|
||||||
"Must provide model in the request");
|
|
||||||
}
|
|
||||||
|
|
||||||
final LTRScoringModel ltrScoringModel = mr.getModel(modelName);
|
|
||||||
if (ltrScoringModel == null) {
|
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
|
||||||
"cannot find " + LTRQParserPlugin.MODEL + " " + modelName);
|
|
||||||
}
|
|
||||||
|
|
||||||
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
|
|
||||||
final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
|
|
||||||
final String fvStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
|
|
||||||
// Check if features are requested and if the model feature store and feature-transform feature store are the same
|
|
||||||
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(fvStoreName) || fvStoreName == null) ? extractFeatures:false;
|
|
||||||
if (threadManager != null) {
|
if (threadManager != null) {
|
||||||
threadManager.setExecutor(req.getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
|
threadManager.setExecutor(req.getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
|
||||||
}
|
}
|
||||||
final LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel,
|
// ReRanking Model
|
||||||
extractEFIParams(localParams),
|
final String[] modelNames = localParams.getParams(LTRQParserPlugin.MODEL);
|
||||||
featuresRequestedFromSameStore, threadManager);
|
if ((modelNames == null) || (modelNames.length!=1 && modelNames.length!=2)) {
|
||||||
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||||
// Enable the feature vector caching if we are extracting features, and the features
|
"Must provide one or two models in the request");
|
||||||
// we requested are the same ones we are reranking with
|
}
|
||||||
if (featuresRequestedFromSameStore) {
|
final boolean isInterleaving = (modelNames.length > 1);
|
||||||
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
|
||||||
|
final String tranformerFeatureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
|
||||||
|
final Map<String,String[]> externalFeatureInfo = extractEFIParams(localParams);
|
||||||
|
|
||||||
|
LTRScoringQuery rerankingQuery = null;
|
||||||
|
LTRInterleavingScoringQuery[] rerankingQueries = new LTRInterleavingScoringQuery[modelNames.length];
|
||||||
|
for (int i = 0; i < modelNames.length; i++) {
|
||||||
|
if (modelNames[i].isEmpty()) {
|
||||||
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||||
|
"the " + LTRQParserPlugin.MODEL + " "+ i +" is empty");
|
||||||
|
}
|
||||||
|
if (!ORIGINAL_RANKING.equals(modelNames[i])) {
|
||||||
|
final LTRScoringModel ltrScoringModel = mr.getModel(modelNames[i]);
|
||||||
|
if (ltrScoringModel == null) {
|
||||||
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||||
|
"cannot find " + LTRQParserPlugin.MODEL + " " + modelNames[i]);
|
||||||
|
}
|
||||||
|
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
|
||||||
|
// Check if features are requested and if the model feature store and feature-transform feature store are the same
|
||||||
|
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(tranformerFeatureStoreName) || tranformerFeatureStoreName == null) ? extractFeatures : false;
|
||||||
|
|
||||||
|
if (isInterleaving) {
|
||||||
|
rerankingQuery = rerankingQueries[i] = new LTRInterleavingScoringQuery(ltrScoringModel,
|
||||||
|
externalFeatureInfo,
|
||||||
|
featuresRequestedFromSameStore, threadManager);
|
||||||
|
} else {
|
||||||
|
rerankingQuery = new LTRScoringQuery(ltrScoringModel,
|
||||||
|
externalFeatureInfo,
|
||||||
|
featuresRequestedFromSameStore, threadManager);
|
||||||
|
rerankingQueries[i] = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable the feature vector caching if we are extracting features, and the features
|
||||||
|
// we requested are the same ones we are reranking with
|
||||||
|
if (featuresRequestedFromSameStore) {
|
||||||
|
rerankingQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
rerankingQuery = rerankingQueries[i] = new OriginalRankingLTRScoringQuery(ORIGINAL_RANKING);
|
||||||
|
}
|
||||||
|
|
||||||
|
// External features
|
||||||
|
rerankingQuery.setRequest(req);
|
||||||
}
|
}
|
||||||
SolrQueryRequestContextUtils.setScoringQuery(req, scoringQuery);
|
|
||||||
|
|
||||||
int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS);
|
int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS);
|
||||||
if (reRankDocs <= 0) {
|
if (reRankDocs <= 0) {
|
||||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||||
"Must rerank at least 1 document");
|
"Must rerank at least 1 document");
|
||||||
|
}
|
||||||
|
if (!isInterleaving) {
|
||||||
|
SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] { rerankingQuery });
|
||||||
|
return new LTRQuery(rerankingQuery, reRankDocs);
|
||||||
|
} else {
|
||||||
|
SolrQueryRequestContextUtils.setScoringQueries(req, rerankingQueries);
|
||||||
|
return new LTRInterleavingQuery(Interleaving.getImplementation(Interleaving.TEAM_DRAFT),rerankingQueries, reRankDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
// External features
|
|
||||||
scoringQuery.setRequest(req);
|
|
||||||
|
|
||||||
return new LTRQuery(scoringQuery, reRankDocs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A learning to rank Query: it will encapsulate a learning to rank model, and delegate to it the rescoring
|
|
||||||
* of the documents.
|
|
||||||
**/
|
|
||||||
public class LTRQuery extends AbstractReRankQuery {
|
|
||||||
private final LTRScoringQuery scoringQuery;
|
|
||||||
|
|
||||||
public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
|
|
||||||
super(defaultQuery, reRankDocs, new LTRRescorer(scoringQuery));
|
|
||||||
this.scoringQuery = scoringQuery;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
return sameClassAs(o) && equalsTo(getClass().cast(o));
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean equalsTo(LTRQuery other) {
|
|
||||||
return (mainQuery.equals(other.mainQuery)
|
|
||||||
&& scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RankQuery wrap(Query _mainQuery) {
|
|
||||||
super.wrap(_mainQuery);
|
|
||||||
scoringQuery.setOriginalQuery(_mainQuery);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString(String field) {
|
|
||||||
return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
|
|
||||||
+ scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
|
|
||||||
return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.search;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.solr.ltr.LTRRescorer;
|
||||||
|
import org.apache.solr.ltr.LTRScoringQuery;
|
||||||
|
import org.apache.solr.search.AbstractReRankQuery;
|
||||||
|
import org.apache.solr.search.RankQuery;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A learning to rank Query: it will encapsulate a learning to rank model, and delegate to it the rescoring
|
||||||
|
* of the documents.
|
||||||
|
**/
|
||||||
|
public class LTRQuery extends AbstractReRankQuery {
|
||||||
|
private static final Query defaultQuery = new MatchAllDocsQuery();
|
||||||
|
private final LTRScoringQuery scoringQuery;
|
||||||
|
|
||||||
|
public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
|
||||||
|
this(scoringQuery, reRankDocs, new LTRRescorer(scoringQuery));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs, LTRRescorer rescorer) {
|
||||||
|
super(defaultQuery, reRankDocs, rescorer);
|
||||||
|
this.scoringQuery = scoringQuery;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
return sameClassAs(o) && equalsTo(getClass().cast(o));
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean equalsTo(LTRQuery other) {
|
||||||
|
return (mainQuery.equals(other.mainQuery)
|
||||||
|
&& scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RankQuery wrap(Query _mainQuery) {
|
||||||
|
super.wrap(_mainQuery);
|
||||||
|
if (scoringQuery != null) {
|
||||||
|
scoringQuery.setOriginalQuery(_mainQuery);
|
||||||
|
}
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(String field) {
|
||||||
|
return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
|
||||||
|
+ scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
|
||||||
|
return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
|
||||||
|
}
|
||||||
|
}
|
|
@ -39,7 +39,7 @@ A Learning to Rank model is plugged into the ranking through the {@link org.apac
|
||||||
a {@link org.apache.solr.search.QParserPlugin}. The plugin will
|
a {@link org.apache.solr.search.QParserPlugin}. The plugin will
|
||||||
read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel})
|
read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel})
|
||||||
used to perform the request plus other
|
used to perform the request plus other
|
||||||
parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQParserPlugin.LTRQuery LTRQuery}:
|
parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQuery LTRQuery}:
|
||||||
a particular {@link org.apache.solr.search.RankQuery}
|
a particular {@link org.apache.solr.search.RankQuery}
|
||||||
that will encapsulate the given model and use it to
|
that will encapsulate the given model and use it to
|
||||||
rescore and rerank the document (by using an {@link org.apache.solr.ltr.LTRRescorer}).
|
rescore and rerank the document (by using an {@link org.apache.solr.ltr.LTRRescorer}).
|
||||||
|
|
|
@ -46,6 +46,15 @@
|
||||||
<str name="fvCacheName">QUERY_DOC_FV</str>
|
<str name="fvCacheName">QUERY_DOC_FV</str>
|
||||||
</transformer>
|
</transformer>
|
||||||
|
|
||||||
|
<!-- add a transformer that will encode the model the interleaving process chose the search result from.
|
||||||
|
For each document the transformer will add an extra field in the response with the model picked.
|
||||||
|
The name of the field will be the name of the transformer
|
||||||
|
enclosed between brackets (in this case [interleaving]).
|
||||||
|
In order to get the model chosen for the search result
|
||||||
|
you will have to specify that you want the field (e.g., fl="*,[interleaving]") -->
|
||||||
|
<transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory">
|
||||||
|
</transformer>
|
||||||
|
|
||||||
<updateHandler class="solr.DirectUpdateHandler2">
|
<updateHandler class="solr.DirectUpdateHandler2">
|
||||||
<autoCommit>
|
<autoCommit>
|
||||||
<maxTime>15000</maxTime>
|
<maxTime>15000</maxTime>
|
||||||
|
|
|
@ -16,7 +16,11 @@
|
||||||
*/
|
*/
|
||||||
package org.apache.solr.ltr;
|
package org.apache.solr.ltr;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
import org.apache.solr.client.solrj.SolrQuery;
|
import org.apache.solr.client.solrj.SolrQuery;
|
||||||
|
import org.apache.solr.ltr.feature.SolrFeature;
|
||||||
|
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||||
import org.apache.solr.ltr.model.LinearModel;
|
import org.apache.solr.ltr.model.LinearModel;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -149,4 +153,160 @@ public class TestLTRQParserExplain extends TestRerankBase {
|
||||||
" 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
|
" 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingModels_shouldReturnExplainForTheModelPicked() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
|
||||||
|
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}1\"]}");
|
||||||
|
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=description}bloomberg\"]}");
|
||||||
|
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}2\"]}");
|
||||||
|
loadFeature("featureB1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}5\"]}");
|
||||||
|
loadFeature("featureB2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=title}different\"]}");
|
||||||
|
|
||||||
|
loadModel("modelA", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||||
|
|
||||||
|
loadModel("modelB", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureB1", "featureB2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("title:bloomberg");
|
||||||
|
query.setParam("debugQuery", "on");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("rq", "{!ltr reRankDocs=10 model=modelA model=modelB}");
|
||||||
|
query.add("fl", "*,score");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12), ScoreB(4)
|
||||||
|
Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36), ScoreB(8)
|
||||||
|
Doc8 = "featureA2=1.0", ScoreA(9), ScoreB(0)
|
||||||
|
Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9), ScoreB(2)
|
||||||
|
|
||||||
|
ModelARerankedList = [7,6,8,9]
|
||||||
|
ModelBRerankedList = [7,6,9,8]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
int[] expectedInterleaved = new int[]{7, 6, 8, 9};
|
||||||
|
String[] expectedExplains = new String[]{
|
||||||
|
"\n8.0 = LinearModel(name=modelB," +
|
||||||
|
"featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " +
|
||||||
|
"model applied to features, sum of:\n " +
|
||||||
|
"0.0 = prod of:\n 2.0 = weight on feature\n 0.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " +
|
||||||
|
"0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " +
|
||||||
|
"8.0 = prod of:\n 8.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||||
|
"\n12.0 = LinearModel(name=modelA," +
|
||||||
|
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||||
|
"model applied to features, sum of:\n " +
|
||||||
|
"3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||||
|
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||||
|
"0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||||
|
"\n9.0 = LinearModel(name=modelA," +
|
||||||
|
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||||
|
"model applied to features, sum of:\n " +
|
||||||
|
"0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||||
|
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||||
|
"0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||||
|
"\n2.0 = LinearModel(name=modelB," +
|
||||||
|
"featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " +
|
||||||
|
"model applied to features, sum of:\n " +
|
||||||
|
"2.0 = prod of:\n 2.0 = weight on feature\n 1.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " +
|
||||||
|
"0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " +
|
||||||
|
"0.0 = prod of:\n 8.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n"};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==4";
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 4] = "/debug/explain/" + expectedInterleaved[(i - 1)] + "=='" + expectedExplains[(i - 1)]+"'}";
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingModelsWithOriginalRanking_shouldReturnExplainForTheModelPicked() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
|
||||||
|
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}1\"]}");
|
||||||
|
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=description}bloomberg\"]}");
|
||||||
|
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}2\"]}");
|
||||||
|
|
||||||
|
loadModel("modelA", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("title:bloomberg");
|
||||||
|
query.setParam("debugQuery", "on");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("rq", "{!ltr reRankDocs=10 model=modelA model=_OriginalRanking_}");
|
||||||
|
query.add("fl", "*,score");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12)
|
||||||
|
Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36)
|
||||||
|
Doc8 = "featureA2=1.0", ScoreA(9)
|
||||||
|
Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9)
|
||||||
|
|
||||||
|
ModelARerankedList = [7,6,8,9]
|
||||||
|
OriginalRanking = [9,8,7,6]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
int[] expectedInterleaved = new int[]{9, 7, 6, 8};
|
||||||
|
String[] expectedExplains = new String[]{
|
||||||
|
"\n0.07662583 = weight(title:bloomberg in 3) [SchemaSimilarity], result of:\n " +
|
||||||
|
"0.07662583 = score(freq=4.0), computed as boost * idf * tf from:\n " +
|
||||||
|
"0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " +
|
||||||
|
"0.72727275 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 4.0 = freq, occurrences of term within document\n " +
|
||||||
|
"1.2 = k1, term saturation parameter\n " +
|
||||||
|
"0.75 = b, length normalization parameter\n " +
|
||||||
|
"4.0 = dl, length of field\n " +
|
||||||
|
"3.0 = avgdl, average length of field\n",
|
||||||
|
"\n36.0 = LinearModel(name=modelA," +
|
||||||
|
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||||
|
"model applied to features, sum of:\n " +
|
||||||
|
"0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||||
|
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||||
|
"27.0 = prod of:\n 27.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||||
|
"\n12.0 = LinearModel(name=modelA," +
|
||||||
|
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||||
|
"model applied to features, sum of:\n " +
|
||||||
|
"3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||||
|
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||||
|
"0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||||
|
"\n0.07525751 = weight(title:bloomberg in 2) [SchemaSimilarity], result of:\n " +
|
||||||
|
"0.07525751 = score(freq=3.0), computed as boost * idf * tf from:\n " +
|
||||||
|
"0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " +
|
||||||
|
"0.71428573 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 3.0 = freq, occurrences of term within document\n " +
|
||||||
|
"1.2 = k1, term saturation parameter\n " +
|
||||||
|
"0.75 = b, length normalization parameter\n " +
|
||||||
|
"3.0 = dl, length of field\n " +
|
||||||
|
"3.0 = avgdl, average length of field\n"};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==4";
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 4] = "/debug/explain/" + expectedInterleaved[(i - 1)] + "=='" + expectedExplains[(i - 1)]+"'}";
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,21 @@ public class TestLTRQParserPlugin extends TestRerankBase {
|
||||||
query.add("rq", "{!ltr reRankDocs=100}");
|
query.add("rq", "{!ltr reRankDocs=100}");
|
||||||
|
|
||||||
final String res = restTestHarness.query("/query" + query.toQueryString());
|
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||||
assert (res.contains("Must provide model in the request"));
|
assert (res.contains("Must provide one or two models in the request"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingLtrTooManyModelsTest() throws Exception {
|
||||||
|
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery(solrQuery);
|
||||||
|
query.add("fl", "*, score");
|
||||||
|
query.add("rows", "4");
|
||||||
|
query.add("fv", "true");
|
||||||
|
query.add("rq", "{!ltr model=modelA model=modelB model=C reRankDocs=100}");
|
||||||
|
|
||||||
|
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||||
|
assert (res.contains("Must provide one or two models in the request"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -65,6 +79,34 @@ public class TestLTRQParserPlugin extends TestRerankBase {
|
||||||
assert (res.contains("cannot find model"));
|
assert (res.contains("cannot find model"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void ltrModelIsEmptyTest() throws Exception {
|
||||||
|
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery(solrQuery);
|
||||||
|
query.add("fl", "*, score");
|
||||||
|
query.add("rows", "4");
|
||||||
|
query.add("fv", "true");
|
||||||
|
query.add("rq", "{!ltr model=\"\" reRankDocs=100}");
|
||||||
|
|
||||||
|
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||||
|
assert (res.contains("the model 0 is empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingLtrModelIsEmptyTest() throws Exception {
|
||||||
|
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery(solrQuery);
|
||||||
|
query.add("fl", "*, score");
|
||||||
|
query.add("rows", "4");
|
||||||
|
query.add("fv", "true");
|
||||||
|
query.add("rq", "{!ltr model=6029760550880411648 model=\"\" reRankDocs=100}");
|
||||||
|
|
||||||
|
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||||
|
assert (res.contains("the model 1 is empty"));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void ltrBadRerankDocsTest() throws Exception {
|
public void ltrBadRerankDocsTest() throws Exception {
|
||||||
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||||
|
|
|
@ -17,8 +17,11 @@
|
||||||
|
|
||||||
package org.apache.solr.ltr;
|
package org.apache.solr.ltr;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
import org.apache.solr.client.solrj.SolrQuery;
|
import org.apache.solr.client.solrj.SolrQuery;
|
||||||
import org.apache.solr.ltr.feature.SolrFeature;
|
import org.apache.solr.ltr.feature.SolrFeature;
|
||||||
|
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||||
import org.apache.solr.ltr.model.LinearModel;
|
import org.apache.solr.ltr.model.LinearModel;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
@ -97,4 +100,104 @@ public class TestLTRWithSort extends TestRerankBase {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingTwoModelsWithSort_shouldInterleave() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
|
||||||
|
loadFeature("featureA", SolrFeature.class.getName(),
|
||||||
|
"{\"q\":\"{!func}pow(popularity,2)\"}");
|
||||||
|
|
||||||
|
loadFeature("featureB", SolrFeature.class.getName(),
|
||||||
|
"{\"q\":\"{!func}pow(popularity,-2)\"}");
|
||||||
|
|
||||||
|
loadModel("modelA", LinearModel.class.getName(),
|
||||||
|
new String[] {"featureA"}, "{\"weights\":{\"featureA\":1.0}}");
|
||||||
|
|
||||||
|
loadModel("modelB", LinearModel.class.getName(),
|
||||||
|
new String[] {"featureB"}, "{\"weights\":{\"featureB\":1.0}}");
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("title:a1");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("rq", "{!ltr reRankDocs=4 model=modelA model=modelB}");
|
||||||
|
query.add("fl", "*,score");
|
||||||
|
query.add("sort", "description desc");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "popularity=1", ScoreA(1) ScoreB(1)
|
||||||
|
Doc5 = "popularity=5", ScoreA(25) ScoreB(0.04)
|
||||||
|
Doc7 = "popularity=7", ScoreA(49) ScoreB(0.02)
|
||||||
|
Doc8 = "popularity=8", ScoreA(64) ScoreB(0.01)
|
||||||
|
|
||||||
|
ModelARerankedList = [8,7,5,1]
|
||||||
|
ModelBRerankedList = [1,5,7,8]
|
||||||
|
|
||||||
|
OriginalRanking = [1,5,8,7]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
*/
|
||||||
|
|
||||||
|
int[] expectedInterleaved = new int[]{1, 8, 7, 5};
|
||||||
|
|
||||||
|
String[] tests = new String[5];
|
||||||
|
tests[0] = "/response/numFound/==8";
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingModelsWithOriginalRankingSort_shouldInterleave() throws Exception {
|
||||||
|
|
||||||
|
loadFeature("powpularityS", SolrFeature.class.getName(),
|
||||||
|
"{\"q\":\"{!func}pow(popularity,2)\"}");
|
||||||
|
|
||||||
|
loadModel("powpularityS-model", LinearModel.class.getName(),
|
||||||
|
new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
|
||||||
|
|
||||||
|
for (boolean originalRankingLast : new boolean[] { true, false }) {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("title:a1");
|
||||||
|
query.add("rows", "10");
|
||||||
|
if (originalRankingLast) {
|
||||||
|
query.add("rq", "{!ltr reRankDocs=4 model=powpularityS-model model=_OriginalRanking_}");
|
||||||
|
} else {
|
||||||
|
query.add("rq", "{!ltr reRankDocs=4 model=_OriginalRanking_ model=powpularityS-model}");
|
||||||
|
}
|
||||||
|
query.add("fl", "*,score");
|
||||||
|
query.add("sort", "description desc");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "popularity=1", ScorePowpularityS(1)
|
||||||
|
Doc5 = "popularity=5", ScorePowpularityS(25)
|
||||||
|
Doc7 = "popularity=7", ScorePowpularityS(49)
|
||||||
|
Doc8 = "popularity=8", ScorePowpularityS(64)
|
||||||
|
|
||||||
|
PowpularitySRerankedList = [8,7,5,1]
|
||||||
|
OriginalRanking = [1,5,8,7]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [1,0]
|
||||||
|
*/
|
||||||
|
|
||||||
|
final int[] expectedInterleaved;
|
||||||
|
if (originalRankingLast) {
|
||||||
|
expectedInterleaved = new int[]{1, 8, 7, 5};
|
||||||
|
} else {
|
||||||
|
expectedInterleaved = new int[]{8, 1, 5, 7};
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] tests = new String[5];
|
||||||
|
tests[0] = "/response/numFound/==8";
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,170 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.solr.ltr.interleaving.algorithms;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
import org.apache.solr.ltr.interleaving.InterleavingResult;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class TeamDraftInterleavingTest {
|
||||||
|
private TeamDraftInterleaving toTest;
|
||||||
|
private ScoreDoc[] rerankedA,rerankedB;
|
||||||
|
private ScoreDoc a1,a2,a3,a4,a5;
|
||||||
|
private ScoreDoc b1,b2,b3,b4,b5;
|
||||||
|
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
toTest = new TeamDraftInterleaving();
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void initDifferentOrderRerankLists() {
|
||||||
|
a1 = new ScoreDoc(1,10,1);
|
||||||
|
a2 = new ScoreDoc(5,7,1);
|
||||||
|
a3 = new ScoreDoc(4,6,1);
|
||||||
|
a4 = new ScoreDoc(2,5,1);
|
||||||
|
a5 = new ScoreDoc(3,4,1);
|
||||||
|
rerankedA = new ScoreDoc[]{a1,a2,a3,a4,a5};
|
||||||
|
|
||||||
|
b1 = new ScoreDoc(1,10,1);
|
||||||
|
b2 = new ScoreDoc(4,7,1);
|
||||||
|
b3 = new ScoreDoc(5,6,1);
|
||||||
|
b4 = new ScoreDoc(3,5,1);
|
||||||
|
b5 = new ScoreDoc(2,4,1);
|
||||||
|
rerankedB = new ScoreDoc[]{b1,b2,b3,b4,b5};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void interleaving_twoDifferentLists_shouldInterleaveTeamDraft() {
|
||||||
|
initDifferentOrderRerankLists();
|
||||||
|
|
||||||
|
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||||
|
ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();
|
||||||
|
|
||||||
|
assertThat(interleavedResults.length,is(5));
|
||||||
|
|
||||||
|
assertThat(interleavedResults[0],is(a1));
|
||||||
|
assertThat(interleavedResults[1],is(b2));
|
||||||
|
|
||||||
|
assertThat(interleavedResults[2],is(b3));
|
||||||
|
assertThat(interleavedResults[3],is(a4));
|
||||||
|
|
||||||
|
assertThat(interleavedResults[4],is(b4));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void interleaving_twoDifferentLists_shouldBuildCorrectInterleavingPicks() {
|
||||||
|
initDifferentOrderRerankLists();
|
||||||
|
|
||||||
|
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||||
|
|
||||||
|
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
|
||||||
|
Set<Integer> modelAPicks = interleavingPicks.get(0);
|
||||||
|
Set<Integer> modelBPicks = interleavingPicks.get(1);
|
||||||
|
|
||||||
|
assertThat(modelAPicks.size(),is(2));
|
||||||
|
assertThat(modelBPicks.size(),is(3));
|
||||||
|
|
||||||
|
assertTrue(modelAPicks.contains(a1.doc));
|
||||||
|
assertTrue(modelAPicks.contains(a4.doc));
|
||||||
|
|
||||||
|
assertTrue(modelBPicks.contains(b2.doc));
|
||||||
|
assertTrue(modelBPicks.contains(b3.doc));
|
||||||
|
assertTrue(modelBPicks.contains(b4.doc));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void initIdenticalOrderRerankLists() {
|
||||||
|
a1 = new ScoreDoc(1,10,1);
|
||||||
|
a2 = new ScoreDoc(5,7,1);
|
||||||
|
a3 = new ScoreDoc(4,6,1);
|
||||||
|
a4 = new ScoreDoc(2,5,1);
|
||||||
|
a5 = new ScoreDoc(3,4,1);
|
||||||
|
rerankedA = new ScoreDoc[]{a1,a2,a3,a4,a5};
|
||||||
|
|
||||||
|
b1 = new ScoreDoc(1,10,1);
|
||||||
|
b2 = new ScoreDoc(5,7,1);
|
||||||
|
b3 = new ScoreDoc(4,6,1);
|
||||||
|
b4 = new ScoreDoc(2,5,1);
|
||||||
|
b5 = new ScoreDoc(3,4,1);
|
||||||
|
rerankedB = new ScoreDoc[]{b1,b2,b3,b4,b5};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void interleaving_identicalRerankLists_shouldInterleaveTeamDraft() {
|
||||||
|
initIdenticalOrderRerankLists();
|
||||||
|
|
||||||
|
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||||
|
ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();
|
||||||
|
|
||||||
|
assertThat(interleavedResults.length,is(5));
|
||||||
|
|
||||||
|
assertThat(interleavedResults[0],is(a1));
|
||||||
|
assertThat(interleavedResults[1],is(b2));
|
||||||
|
|
||||||
|
assertThat(interleavedResults[2],is(b3));
|
||||||
|
assertThat(interleavedResults[3],is(a4));
|
||||||
|
|
||||||
|
assertThat(interleavedResults[4],is(b5));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void interleaving_identicalRerankLists_shouldBuildCorrectInterleavingPicks() {
|
||||||
|
initIdenticalOrderRerankLists();
|
||||||
|
|
||||||
|
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||||
|
|
||||||
|
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
|
||||||
|
Set<Integer> modelAPicks = interleavingPicks.get(0);
|
||||||
|
Set<Integer> modelBPicks = interleavingPicks.get(1);
|
||||||
|
|
||||||
|
assertThat(modelAPicks.size(),is(2));
|
||||||
|
assertThat(modelBPicks.size(),is(3));
|
||||||
|
|
||||||
|
assertTrue(modelAPicks.contains(a1.doc));
|
||||||
|
assertTrue(modelAPicks.contains(a4.doc));
|
||||||
|
|
||||||
|
assertTrue(modelBPicks.contains(b2.doc));
|
||||||
|
assertTrue(modelBPicks.contains(b3.doc));
|
||||||
|
assertTrue(modelBPicks.contains(b5.doc));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,400 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.response.transform;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import org.apache.solr.client.solrj.SolrQuery;
|
||||||
|
import org.apache.solr.ltr.TestRerankBase;
|
||||||
|
import org.apache.solr.ltr.feature.SolrFeature;
|
||||||
|
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||||
|
import org.apache.solr.ltr.model.LinearModel;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class TestFeatureLoggerTransformer extends TestRerankBase {
|
||||||
|
@Before
|
||||||
|
public void before() throws Exception {
|
||||||
|
setuptest(false);
|
||||||
|
|
||||||
|
assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity",
|
||||||
|
"1"));
|
||||||
|
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
|
||||||
|
"w2 2asd asdd didid", "popularity", "2"));
|
||||||
|
assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity",
|
||||||
|
"3"));
|
||||||
|
assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity",
|
||||||
|
"6"));
|
||||||
|
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
|
||||||
|
"5"));
|
||||||
|
assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2",
|
||||||
|
"popularity", "6"));
|
||||||
|
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
|
||||||
|
"w6 w2 w3 w4 w5 w8", "popularity", "88888"));
|
||||||
|
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
|
||||||
|
"w1 w1 w1 w2 w2 w5", "popularity", "88888"));
|
||||||
|
assertU(commit());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void after() throws Exception {
|
||||||
|
aftertest();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void loadFeaturesAndModels() throws Exception {
|
||||||
|
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}88888\"]}");
|
||||||
|
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||||
|
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||||
|
loadFeature("featureB1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||||
|
loadFeature("featureB2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||||
|
loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2",
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||||
|
loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2",
|
||||||
|
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||||
|
loadFeature("featureC3", SolrFeature.class.getName(),"featureStore2",
|
||||||
|
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||||
|
|
||||||
|
loadModel("modelA", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||||
|
|
||||||
|
loadModel("modelB", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureB1", "featureB2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
|
||||||
|
|
||||||
|
loadModel("modelC", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureC1", "featureC2"},"featureStore2",
|
||||||
|
"{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_featureTransformer_shouldWorkInSparseFormat() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelBRerankedList = [7,1,3,8,4]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0", "featureB2\\=1.0", "featureB2\\=1.0", "featureA1\\=1.0\\,featureB2\\=1.0", "featureB1\\=1.0"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[11];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_featureTransformer_shouldWorkInDenseFormat() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv format=dense]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelBRerankedList = [7,1,3,8,4]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||||
|
"featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||||
|
"featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||||
|
"featureA1\\=1.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||||
|
"featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=1.0\\,featureB2\\=0.0"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[11];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_explicitNewFeatureStore_shouldExtractAllFeaturesFromNewStore() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=dense]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelBRerankedList = [7,1,3,8,4]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC1\\=1.0\\,featureC2\\=0.0\\,featureC3\\=0.0"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_withOriginalRankingAndExplicitFeatureStore_shouldReturnNewCalculatedFeatureVector() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
|
||||||
|
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
_OriginalRanking_ = [1,3,4,7,8]
|
||||||
|
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC1\\=1.0"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
if (expectedFeatureVectors[(i - 1)] != null) {
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
|
||||||
|
int[] nullFeatureVectorIndexes = new int[]{1, 2, 4};
|
||||||
|
for (int index : nullFeatureVectorIndexes) {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));
|
||||||
|
String[] nullFeatureVectorTests = new String[1];
|
||||||
|
try {
|
||||||
|
nullFeatureVectorTests[0] = "/response/docs/[" + index + "]/features==";
|
||||||
|
assertJQ("/query" + query.toQueryString(), nullFeatureVectorTests);
|
||||||
|
} catch (Exception e) {
|
||||||
|
assertEquals("Path not found: /response/docs/[" + index + "]/features", e.getMessage());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_modelsFromDifferentFeatureStores_shouldLogFeaturesCorrectly() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||||
|
Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||||
|
Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0)
|
||||||
|
Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelCRerankedList = [1,3,8,4,7]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureA1\\=1.0\\,featureB2\\=1.0",
|
||||||
|
"featureC1\\=1.0"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_featureLoggerFromNewFeatureStoreWithDifferentEfi_shouldReturnNewCalculatedFeatureVector() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w8' format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelBRerankedList = [7,1,3,8,4]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
""};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleaving_explicitFeatureStoreReusableFromModel_shouldLogFeaturesCorrectly() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModels();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,features:[fv store=featureStore2 format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||||
|
Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||||
|
Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0)
|
||||||
|
Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelCRerankedList = [1,3,8,4,7]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||||
|
"featureC1\\=1.0"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,277 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.solr.ltr.response.transform;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import org.apache.solr.client.solrj.SolrQuery;
|
||||||
|
import org.apache.solr.ltr.TestRerankBase;
|
||||||
|
import org.apache.solr.ltr.feature.SolrFeature;
|
||||||
|
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||||
|
import org.apache.solr.ltr.model.LinearModel;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
public class TestInterleavingTransformer extends TestRerankBase {
|
||||||
|
@Before
|
||||||
|
public void before() throws Exception {
|
||||||
|
setuptest(false);
|
||||||
|
|
||||||
|
assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity",
|
||||||
|
"1"));
|
||||||
|
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
|
||||||
|
"w2 2asd asdd didid", "popularity", "2"));
|
||||||
|
assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity",
|
||||||
|
"3"));
|
||||||
|
assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity",
|
||||||
|
"6"));
|
||||||
|
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
|
||||||
|
"5"));
|
||||||
|
assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2",
|
||||||
|
"popularity", "6"));
|
||||||
|
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
|
||||||
|
"w6 w2 w3 w4 w5 w8", "popularity", "88888"));
|
||||||
|
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
|
||||||
|
"w1 w1 w1 w2 w2 w5", "popularity", "88888"));
|
||||||
|
assertU(commit());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void after() throws Exception {
|
||||||
|
aftertest();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void loadFeaturesAndModelsForInterleaving() throws Exception {
|
||||||
|
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}88888\"]}");
|
||||||
|
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||||
|
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||||
|
loadFeature("featureB1", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||||
|
loadFeature("featureB2", SolrFeature.class.getName(),
|
||||||
|
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||||
|
loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2",
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||||
|
loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2",
|
||||||
|
"{\"fq\":[\"{!terms f=popularity}1\"]}");
|
||||||
|
|
||||||
|
loadModel("modelA", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||||
|
|
||||||
|
loadModel("modelB", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureB1", "featureB2", "featureAB"},
|
||||||
|
"{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
|
||||||
|
|
||||||
|
loadModel("modelC", LinearModel.class.getName(),
|
||||||
|
new String[]{"featureC1", "featureC2"},"featureStore2",
|
||||||
|
"{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingTransformer_shouldReturnInterleavingPickInTheResults() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModelsForInterleaving();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,interleavingPick:[interleaving]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelBRerankedList = [7,1,3,8,4]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[11];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingTransformerWithOriginalRanking_shouldReturnInterleavingPickInTheResults() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModelsForInterleaving();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,interleavingPick:[interleaving]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
|
||||||
|
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
OriginalRanking = [1,3,4,7,8]
|
||||||
|
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[11];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingTransformer_shouldBeCompatibleWithFeatureTransformer() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModelsForInterleaving();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,interleavingPick:[interleaving],features:[fv format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
ModelBRerankedList = [7,1,3,8,4]
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
|
||||||
|
"featureB2\\=1.0",
|
||||||
|
"featureB2\\=1.0",
|
||||||
|
"featureA1\\=1.0\\,featureB2\\=1.0",
|
||||||
|
"featureB1\\=1.0"};
|
||||||
|
String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void interleavingTransformerWithOriginalRanking_shouldBeCompatibleWithFeatureTransformer() throws Exception {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
loadFeaturesAndModelsForInterleaving();
|
||||||
|
|
||||||
|
final SolrQuery query = new SolrQuery();
|
||||||
|
query.setQuery("*:*");
|
||||||
|
query.add("fl", "*, score,interleavingPick:[interleaving],features:[fv format=sparse]");
|
||||||
|
query.add("rows", "10");
|
||||||
|
query.add("debugQuery", "true");
|
||||||
|
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||||
|
query.add("rq",
|
||||||
|
"{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Doc1 = "featureB2=1.0", ScoreA(0)
|
||||||
|
Doc3 = "featureB2=1.0", ScoreA(0)
|
||||||
|
Doc4 = "featureB1=1.0", ScoreA(0)
|
||||||
|
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
|
||||||
|
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
|
||||||
|
|
||||||
|
ModelARerankedList = [7,8,1,3,4]
|
||||||
|
_OriginalRanking_ = [1,3,4,7,8]
|
||||||
|
|
||||||
|
|
||||||
|
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||||
|
*/
|
||||||
|
String[] expectedFeatureVectors = new String[]{
|
||||||
|
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
"featureA1\\=1.0\\,featureB2\\=1.0",
|
||||||
|
null};
|
||||||
|
String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"};
|
||||||
|
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||||
|
|
||||||
|
String[] tests = new String[16];
|
||||||
|
tests[0] = "/response/numFound/==5";
|
||||||
|
for (int i = 1; i <= 5; i++) {
|
||||||
|
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||||
|
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||||
|
if (expectedFeatureVectors[(i - 1)] != null) {
|
||||||
|
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertJQ("/query" + query.toQueryString(), tests);
|
||||||
|
|
||||||
|
int[] nullFeatureVectorIndexes = new int[]{1, 2, 4};
|
||||||
|
for (int index : nullFeatureVectorIndexes) {
|
||||||
|
TeamDraftInterleaving.setRANDOM(new Random(10101010));
|
||||||
|
String[] nullFeatureVectorTests = new String[1];
|
||||||
|
try {
|
||||||
|
nullFeatureVectorTests[0] = "/response/docs/[" + index + "]/features==";
|
||||||
|
assertJQ("/query" + query.toQueryString(), nullFeatureVectorTests);
|
||||||
|
} catch (Exception e) {
|
||||||
|
assertEquals("Path not found: /response/docs/[" + index + "]/features", e.getMessage());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -38,6 +38,13 @@ A ranking model computes the scores used to rerank documents. Irrespective of an
|
||||||
* features that represent the document being scored
|
* features that represent the document being scored
|
||||||
* features that represent the query for which the document is being scored
|
* features that represent the query for which the document is being scored
|
||||||
|
|
||||||
|
==== Interleaving
|
||||||
|
|
||||||
|
Interleaving is an approach to Online Search Quality evaluation that makes it possible to compare two models by interleaving their results in the final ranked list returned to the user.
|
||||||
|
|
||||||
|
* currently only the Team Draft Interleaving algorithm is supported (and its implementation assumes all results are from the same shard)
|
||||||
|
|
||||||
|
|
||||||
==== Feature
|
==== Feature
|
||||||
|
|
||||||
A feature is a value, a number, that represents some quantity or quality of the document being scored or of the query for which documents are being scored. For example documents often have a 'recency' quality and 'number of past purchases' might be a quantity that is passed to Solr as part of the search query.
|
A feature is a value, a number, that represents some quantity or quality of the document being scored or of the query for which documents are being scored. For example documents often have a 'recency' quality and 'number of past purchases' might be a quantity that is passed to Solr as part of the search query.
|
||||||
|
@ -247,6 +254,81 @@ The output XML will include feature values as a comma-separated list, resembling
|
||||||
}}
|
}}
|
||||||
----
|
----
|
||||||
|
|
||||||
|
=== Running a Rerank Query Interleaving Two Models
|
||||||
|
|
||||||
|
To rerank the results of a query by interleaving two models (myModelA, myModelB), add the `rq` parameter to your search, passing both models as input, for example:
|
||||||
|
|
||||||
|
[source,text]
|
||||||
|
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score
|
||||||
|
|
||||||
|
To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example:
|
||||||
|
|
||||||
|
[source,text]
|
||||||
|
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score,[interleaving]
|
||||||
|
|
||||||
|
The output XML will include the model picked for each search result, resembling the output shown here:
|
||||||
|
|
||||||
|
[source,json]
|
||||||
|
----
|
||||||
|
{
|
||||||
|
"responseHeader":{
|
||||||
|
"status":0,
|
||||||
|
"QTime":0,
|
||||||
|
"params":{
|
||||||
|
"q":"test",
|
||||||
|
"fl":"id,score,[interleaving]",
|
||||||
|
"rq":"{!ltr model=myModelA model=myModelB reRankDocs=100}"}},
|
||||||
|
"response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[
|
||||||
|
{
|
||||||
|
"id":"GB18030TEST",
|
||||||
|
"score":1.0005897,
|
||||||
|
"[interleaving]":"myModelB"},
|
||||||
|
{
|
||||||
|
"id":"UTF8TEST",
|
||||||
|
"score":0.79656565,
|
||||||
|
"[interleaving]":"myModelA"}]
|
||||||
|
}}
|
||||||
|
----
|
||||||
|
|
||||||
|
=== Running a Rerank Query Interleaving a model with the original ranking
|
||||||
|
When approaching Search Quality Evaluation with interleaving, it may be useful to compare a model with the original ranking.
|
||||||
|
To rerank the results of a query, interleaving a model with the original ranking, add the `rq` parameter to your search, passing the special inbuilt `_OriginalRanking_` model identifier as one model and your comparison model as the other model, for example:
|
||||||
|
|
||||||
|
|
||||||
|
[source,text]
|
||||||
|
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score
|
||||||
|
|
||||||
|
The addition of the `rq` parameter will not change the output XML of the search.
|
||||||
|
|
||||||
|
To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example:
|
||||||
|
|
||||||
|
[source,text]
|
||||||
|
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score,[interleaving]
|
||||||
|
|
||||||
|
The output XML will include the model picked for each search result, resembling the output shown here:
|
||||||
|
|
||||||
|
[source,json]
|
||||||
|
----
|
||||||
|
{
|
||||||
|
"responseHeader":{
|
||||||
|
"status":0,
|
||||||
|
"QTime":0,
|
||||||
|
"params":{
|
||||||
|
"q":"test",
|
||||||
|
"fl":"id,score,[interleaving]",
|
||||||
|
"rq":"{!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}"}},
|
||||||
|
"response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[
|
||||||
|
{
|
||||||
|
"id":"GB18030TEST",
|
||||||
|
"score":1.0005897,
|
||||||
|
"[interleaving]":"_OriginalRanking_"},
|
||||||
|
{
|
||||||
|
"id":"UTF8TEST",
|
||||||
|
"score":0.79656565,
|
||||||
|
"[interleaving]":"myModel"}]
|
||||||
|
}}
|
||||||
|
----
|
||||||
|
|
||||||
=== External Feature Information
|
=== External Feature Information
|
||||||
|
|
||||||
The {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/ValueFeature.html[ValueFeature] and {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/SolrFeature.html[SolrFeature] classes support the use of external feature information, `efi` for short.
|
The {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/ValueFeature.html[ValueFeature] and {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/SolrFeature.html[SolrFeature] classes support the use of external feature information, `efi` for short.
|
||||||
|
@ -418,6 +500,13 @@ Learning-To-Rank is a contrib module and therefore its plugins must be configure
|
||||||
</transformer>
|
</transformer>
|
||||||
----
|
----
|
||||||
|
|
||||||
|
* Declaration of the `[interleaving]` transformer.
|
||||||
|
+
|
||||||
|
[source,xml]
|
||||||
|
----
|
||||||
|
<transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory"/>
|
||||||
|
----
|
||||||
|
|
||||||
=== Advanced Options
|
=== Advanced Options
|
||||||
|
|
||||||
==== LTRThreadModule
|
==== LTRThreadModule
|
||||||
|
@ -446,11 +535,12 @@ How does Solr Learning-To-Rank work under the hood?::
|
||||||
Please refer to the `ltr` {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/package-summary.html[javadocs] for an implementation overview.
|
Please refer to the `ltr` {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/package-summary.html[javadocs] for an implementation overview.
|
||||||
|
|
||||||
How could I write additional models and/or features?::
|
How could I write additional models and/or features?::
|
||||||
Contributions for further models, features and normalizers are welcome. Related links:
|
Contributions for further models, features, normalizers and interleaving algorithms are welcome. Related links:
|
||||||
+
|
+
|
||||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel javadocs]
|
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel javadocs]
|
||||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/Feature.html[Feature javadocs]
|
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/Feature.html[Feature javadocs]
|
||||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/norm/Normalizer.html[Normalizer javadocs]
|
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/norm/Normalizer.html[Normalizer javadocs]
|
||||||
|
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/interleaving/Interleaving.html[Interleaving javadocs]
|
||||||
* https://cwiki.apache.org/confluence/display/solr/HowToContribute
|
* https://cwiki.apache.org/confluence/display/solr/HowToContribute
|
||||||
* https://cwiki.apache.org/confluence/display/LUCENE/HowToContribute
|
* https://cwiki.apache.org/confluence/display/LUCENE/HowToContribute
|
||||||
|
|
||||||
|
@ -779,3 +869,7 @@ The feature store and the model store are both <<managed-resources.adoc#managed-
|
||||||
* "Learning to Rank in Solr" presentation at Lucene/Solr Revolution 2015 in Austin:
|
* "Learning to Rank in Solr" presentation at Lucene/Solr Revolution 2015 in Austin:
|
||||||
** Slides: http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp
|
** Slides: http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp
|
||||||
** Video: https://www.youtube.com/watch?v=M7BKwJoh96s
|
** Video: https://www.youtube.com/watch?v=M7BKwJoh96s
|
||||||
|
|
||||||
|
* The importance of Online Testing in Learning To Rank:
|
||||||
|
** Blog: https://sease.io/2020/04/the-importance-of-online-testing-in-learning-to-rank-part-1.html
|
||||||
|
** Blog: https://sease.io/2020/05/online-testing-for-learning-to-rank-interleaving.html
|
||||||
|
|
Loading…
Reference in New Issue