mirror of https://github.com/apache/lucene.git
SOLR-14560: Interleaving for Learning To Rank (#1571)
SOLR-14560: Add interleaving support in Learning To Rank
This commit is contained in:
parent
ea4dd0580f
commit
af0455ac83
|
@ -167,6 +167,8 @@ New Features
|
|||
|
||||
* SOLR-14907: Add v2 API for configSet upload, including single-file insertion. (Houston Putman)
|
||||
|
||||
* SOLR-14560: Add interleaving support in Learning To Rank. (Alessandro Benedetti, Christine Poerschke)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
* SOLR-14942: Reduce leader election time on node shutdown by removing election nodes before closing cores.
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.lucene.search.ScoreMode;
|
|||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.search.Weight;
|
||||
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
|
||||
import org.apache.solr.search.SolrIndexSearcher;
|
||||
|
||||
|
||||
|
@ -42,12 +43,17 @@ import org.apache.solr.search.SolrIndexSearcher;
|
|||
* */
|
||||
public class LTRRescorer extends Rescorer {
|
||||
|
||||
LTRScoringQuery scoringQuery;
|
||||
final private LTRScoringQuery scoringQuery;
|
||||
|
||||
public LTRRescorer() {
|
||||
this.scoringQuery = null;
|
||||
}
|
||||
|
||||
public LTRRescorer(LTRScoringQuery scoringQuery) {
|
||||
this.scoringQuery = scoringQuery;
|
||||
}
|
||||
|
||||
private void heapAdjust(ScoreDoc[] hits, int size, int root) {
|
||||
protected static void heapAdjust(ScoreDoc[] hits, int size, int root) {
|
||||
final ScoreDoc doc = hits[root];
|
||||
final float score = doc.score;
|
||||
int i = root;
|
||||
|
@ -82,7 +88,7 @@ public class LTRRescorer extends Rescorer {
|
|||
}
|
||||
}
|
||||
|
||||
private void heapify(ScoreDoc[] hits, int size) {
|
||||
protected static void heapify(ScoreDoc[] hits, int size) {
|
||||
for (int i = (size >> 1) - 1; i >= 0; i--) {
|
||||
heapAdjust(hits, size, i);
|
||||
}
|
||||
|
@ -104,23 +110,27 @@ public class LTRRescorer extends Rescorer {
|
|||
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
|
||||
return firstPassTopDocs;
|
||||
}
|
||||
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
|
||||
Arrays.sort(hits, new Comparator<ScoreDoc>() {
|
||||
@Override
|
||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||
return a.doc - b.doc;
|
||||
}
|
||||
});
|
||||
|
||||
assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
|
||||
final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
|
||||
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
|
||||
|
||||
final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);
|
||||
|
||||
return new TopDocs(firstPassTopDocs.totalHits, reranked);
|
||||
}
|
||||
|
||||
private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
|
||||
final ScoreDoc[] reranked = new ScoreDoc[topN];
|
||||
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||
final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher
|
||||
.createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1);
|
||||
|
||||
scoreFeatures(searcher, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked);
|
||||
scoreFeatures(searcher,topN, modelWeight, firstPassResults, leaves, reranked);
|
||||
// Must sort all documents that we reranked, and then select the top
|
||||
sortByScore(reranked);
|
||||
return reranked;
|
||||
}
|
||||
|
||||
protected static void sortByScore(ScoreDoc[] reranked) {
|
||||
Arrays.sort(reranked, new Comparator<ScoreDoc>() {
|
||||
@Override
|
||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||
|
@ -136,13 +146,24 @@ public class LTRRescorer extends Rescorer {
|
|||
}
|
||||
}
|
||||
});
|
||||
|
||||
return new TopDocs(firstPassTopDocs.totalHits, reranked);
|
||||
}
|
||||
|
||||
public void scoreFeatures(IndexSearcher indexSearcher, TopDocs firstPassTopDocs,
|
||||
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
|
||||
ScoreDoc[] reranked) throws IOException {
|
||||
protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
|
||||
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
|
||||
Arrays.sort(hits, new Comparator<ScoreDoc>() {
|
||||
@Override
|
||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||
return a.doc - b.doc;
|
||||
}
|
||||
});
|
||||
|
||||
assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
|
||||
return hits;
|
||||
}
|
||||
|
||||
public void scoreFeatures(IndexSearcher indexSearcher,
|
||||
int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
|
||||
ScoreDoc[] reranked) throws IOException {
|
||||
|
||||
int readerUpto = -1;
|
||||
int endDoc = 0;
|
||||
|
@ -150,7 +171,6 @@ public class LTRRescorer extends Rescorer {
|
|||
|
||||
LTRScoringQuery.ModelWeight.ModelScorer scorer = null;
|
||||
int hitUpto = 0;
|
||||
final FeatureLogger featureLogger = scoringQuery.getFeatureLogger();
|
||||
|
||||
while (hitUpto < hits.length) {
|
||||
final ScoreDoc hit = hits[hitUpto];
|
||||
|
@ -166,64 +186,77 @@ public class LTRRescorer extends Rescorer {
|
|||
docBase = readerContext.docBase;
|
||||
scorer = modelWeight.scorer(readerContext);
|
||||
}
|
||||
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
|
||||
// call score
|
||||
// even if no feature scorers match, since a model might use that info to
|
||||
// return a
|
||||
// non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer
|
||||
// past the target
|
||||
// doc since the model algorithm still needs to compute a potentially
|
||||
// non-zero score from blank features.
|
||||
assert (scorer != null);
|
||||
final int targetDoc = docID - docBase;
|
||||
scorer.docID();
|
||||
scorer.iterator().advance(targetDoc);
|
||||
scoreSingleHit(indexSearcher, topN, modelWeight, docBase, hitUpto, hit, docID, scoringQuery, scorer, reranked);
|
||||
hitUpto++;
|
||||
}
|
||||
}
|
||||
|
||||
scorer.getDocInfo().setOriginalDocScore(hit.score);
|
||||
hit.score = scorer.score();
|
||||
if (hitUpto < topN) {
|
||||
reranked[hitUpto] = hit;
|
||||
// if the heap is not full, maybe I want to log the features for this
|
||||
// document
|
||||
protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery rerankingQuery, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException {
|
||||
final FeatureLogger featureLogger = rerankingQuery.getFeatureLogger();
|
||||
// Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to
|
||||
// call score
|
||||
// even if no feature scorers match, since a model might use that info to
|
||||
// return a
|
||||
// non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer
|
||||
// past the target
|
||||
// doc since the model algorithm still needs to compute a potentially
|
||||
// non-zero score from blank features.
|
||||
assert (scorer != null);
|
||||
final int targetDoc = docID - docBase;
|
||||
scorer.docID();
|
||||
scorer.iterator().advance(targetDoc);
|
||||
|
||||
scorer.getDocInfo().setOriginalDocScore(hit.score);
|
||||
hit.score = scorer.score();
|
||||
if (hitUpto < topN) {
|
||||
reranked[hitUpto] = hit;
|
||||
// if the heap is not full, maybe I want to log the features for this
|
||||
// document
|
||||
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
||||
featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher,
|
||||
modelWeight.getFeaturesInfo());
|
||||
}
|
||||
} else if (hitUpto == topN) {
|
||||
// collected topN document, I create the heap
|
||||
heapify(reranked, topN);
|
||||
}
|
||||
if (hitUpto >= topN) {
|
||||
// once the heap is ready, if the score of this document is lower than
|
||||
// the minimum
|
||||
// i don't want to log the feature. Otherwise I replace it with the
|
||||
// minimum and fix the
|
||||
// heap.
|
||||
if (hit.score > reranked[0].score) {
|
||||
reranked[0] = hit;
|
||||
heapAdjust(reranked, topN, 0);
|
||||
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
||||
featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher,
|
||||
featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher,
|
||||
modelWeight.getFeaturesInfo());
|
||||
}
|
||||
} else if (hitUpto == topN) {
|
||||
// collected topN document, I create the heap
|
||||
heapify(reranked, topN);
|
||||
}
|
||||
if (hitUpto >= topN) {
|
||||
// once the heap is ready, if the score of this document is lower than
|
||||
// the minimum
|
||||
// i don't want to log the feature. Otherwise I replace it with the
|
||||
// minimum and fix the
|
||||
// heap.
|
||||
if (hit.score > reranked[0].score) {
|
||||
reranked[0] = hit;
|
||||
heapAdjust(reranked, topN, 0);
|
||||
if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) {
|
||||
featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher,
|
||||
modelWeight.getFeaturesInfo());
|
||||
}
|
||||
}
|
||||
}
|
||||
hitUpto++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(IndexSearcher searcher,
|
||||
Explanation firstPassExplanation, int docID) throws IOException {
|
||||
return getExplanation(searcher, docID, scoringQuery);
|
||||
}
|
||||
|
||||
protected static Explanation getExplanation(IndexSearcher searcher, int docID, LTRScoringQuery rerankingQuery) throws IOException {
|
||||
final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
|
||||
.leaves();
|
||||
final int n = ReaderUtil.subIndex(docID, leafContexts);
|
||||
final LeafReaderContext context = leafContexts.get(n);
|
||||
final int deBasedDoc = docID - context.docBase;
|
||||
final Weight modelWeight = searcher.createWeight(searcher.rewrite(scoringQuery),
|
||||
ScoreMode.COMPLETE, 1);
|
||||
return modelWeight.explain(context, deBasedDoc);
|
||||
final Weight rankingWeight;
|
||||
if (rerankingQuery instanceof OriginalRankingLTRScoringQuery) {
|
||||
rankingWeight = rerankingQuery.getOriginalQuery().createWeight(searcher, ScoreMode.COMPLETE, 1);
|
||||
} else {
|
||||
rankingWeight = searcher.createWeight(searcher.rewrite(rerankingQuery),
|
||||
ScoreMode.COMPLETE, 1);
|
||||
}
|
||||
return rankingWeight.explain(context, deBasedDoc);
|
||||
}
|
||||
|
||||
public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight,
|
||||
|
|
|
@ -102,6 +102,10 @@ public class LTRScoringQuery extends Query implements Accountable {
|
|||
return ltrScoringModel;
|
||||
}
|
||||
|
||||
/** Returns the name of the scoring model, as reported by {@code ltrScoringModel.getName()}. */
public String getScoringModelName() {
|
||||
return ltrScoringModel.getName();
|
||||
}
|
||||
|
||||
/** Stores the given {@link FeatureLogger} in this query's {@code fl} field. */
public void setFeatureLogger(FeatureLogger fl) {
|
||||
this.fl = fl;
|
||||
}
|
||||
|
|
|
@ -26,8 +26,8 @@ public class SolrQueryRequestContextUtils {
|
|||
/** key of the feature logger in the request context **/
|
||||
private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger";
|
||||
|
||||
/** key of the scoring query in the request context **/
|
||||
private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query";
|
||||
/** key of the scoring queries in the request context **/
|
||||
private static final String SCORING_QUERIES = LTR_PREFIX + "scoring_queries";
|
||||
|
||||
/** key of the isExtractingFeatures flag in the request context **/
|
||||
private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures";
|
||||
|
@ -47,12 +47,12 @@ public class SolrQueryRequestContextUtils {
|
|||
|
||||
/** scoring query accessors **/
|
||||
|
||||
public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) {
|
||||
req.getContext().put(SCORING_QUERY, scoringQuery);
|
||||
public static void setScoringQueries(SolrQueryRequest req, LTRScoringQuery[] scoringQueries) {
|
||||
req.getContext().put(SCORING_QUERIES, scoringQueries);
|
||||
}
|
||||
|
||||
public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) {
|
||||
return (LTRScoringQuery) req.getContext().get(SCORING_QUERY);
|
||||
public static LTRScoringQuery[] getScoringQueries(SolrQueryRequest req) {
|
||||
return (LTRScoringQuery[]) req.getContext().get(SCORING_QUERIES);
|
||||
}
|
||||
|
||||
/** isExtractingFeatures flag accessors **/
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.interleaving;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||
|
||||
/**
|
||||
* Interleaving considers two ranking models: modelA and modelB.
|
||||
* For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...).
|
||||
* An Interleaving algorithm creates a unique ranked list I = (i1, i2, ...).
|
||||
* This list is created by interleaving elements from the two lists la and lb as described by the implementation algorithm.
|
||||
* Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb.
|
||||
*/
|
||||
public interface Interleaving {
|
||||
|
||||
String TEAM_DRAFT = "TeamDraft";
|
||||
|
||||
InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB);
|
||||
|
||||
static Interleaving getImplementation(String algorithm) {
|
||||
switch(algorithm) {
|
||||
case TEAM_DRAFT:
|
||||
default:
|
||||
return new TeamDraftInterleaving();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.interleaving;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
||||
public class InterleavingResult {
|
||||
final private ScoreDoc[] interleavedResults;
|
||||
final private ArrayList<Set<Integer>> interleavingPicks;
|
||||
|
||||
public InterleavingResult(ScoreDoc[] interleavedResults, ArrayList<Set<Integer>> interleavingPicks) {
|
||||
this.interleavedResults = interleavedResults;
|
||||
this.interleavingPicks = interleavingPicks;
|
||||
}
|
||||
|
||||
public ScoreDoc[] getInterleavedResults() {
|
||||
return interleavedResults;
|
||||
}
|
||||
|
||||
public ArrayList<Set<Integer>> getInterleavingPicks() {
|
||||
return interleavingPicks;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.interleaving;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.solr.ltr.search.LTRQuery;
|
||||
import org.apache.solr.search.RankQuery;
|
||||
|
||||
/**
|
||||
* A learning to rank Query with Interleaving, will incapsulate two models, and delegate to it the rescoring
|
||||
* of the documents.
|
||||
**/
|
||||
public class LTRInterleavingQuery extends LTRQuery {
|
||||
private final LTRInterleavingScoringQuery[] rerankingQueries;
|
||||
private final Interleaving interlavingAlgorithm;
|
||||
|
||||
public LTRInterleavingQuery(Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries, int rerankDocs) {
|
||||
super(null, rerankDocs, new LTRInterleavingRescorer(interleavingAlgorithm, rerankingQueries));
|
||||
this.rerankingQueries = rerankingQueries;
|
||||
this.interlavingAlgorithm = interleavingAlgorithm;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + (mainQuery.hashCode() + rerankingQueries.hashCode() + reRankDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return sameClassAs(o) && equalsTo(getClass().cast(o));
|
||||
}
|
||||
|
||||
private boolean equalsTo(LTRInterleavingQuery other) {
|
||||
return (mainQuery.equals(other.mainQuery)
|
||||
&& rerankingQueries.equals(other.rerankingQueries) && (reRankDocs == other.reRankDocs));
|
||||
}
|
||||
|
||||
@Override
|
||||
public RankQuery wrap(Query _mainQuery) {
|
||||
super.wrap(_mainQuery);
|
||||
for(LTRInterleavingScoringQuery rerankingQuery: rerankingQueries){
|
||||
rerankingQuery.setOriginalQuery(_mainQuery);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "{!ltr mainQuery='" + mainQuery.toString() + "' rerankingQueries='"
|
||||
+ Arrays.toString(rerankingQueries) + "' reRankDocs=" + reRankDocs + "}";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
|
||||
return new LTRInterleavingQuery(interlavingAlgorithm, rerankingQueries, reRankDocs).wrap(rewrittenMainQuery);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,162 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.ltr.interleaving;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.solr.ltr.LTRRescorer;
|
||||
import org.apache.solr.ltr.LTRScoringQuery;
|
||||
|
||||
/**
|
||||
* Implements the rescoring logic. The top documents returned by solr with their
|
||||
* original scores, will be processed by a {@link LTRScoringQuery} that will assign a
|
||||
* new score to each document. The top documents will be resorted based on the
|
||||
* new score.
|
||||
* */
|
||||
public class LTRInterleavingRescorer extends LTRRescorer {
|
||||
|
||||
final private LTRInterleavingScoringQuery[] rerankingQueries;
|
||||
private Integer originalRankingIndex = null;
|
||||
final private Interleaving interleavingAlgorithm;
|
||||
|
||||
public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries) {
|
||||
this.rerankingQueries = rerankingQueries;
|
||||
this.interleavingAlgorithm = interleavingAlgorithm;
|
||||
for(int i=0;i<this.rerankingQueries.length;i++){
|
||||
if(this.rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery){
|
||||
this.originalRankingIndex = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* rescores the documents:
|
||||
*
|
||||
* @param searcher
|
||||
* current IndexSearcher
|
||||
* @param firstPassTopDocs
|
||||
* documents to rerank;
|
||||
* @param topN
|
||||
* documents to return;
|
||||
*/
|
||||
@Override
|
||||
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
|
||||
int topN) throws IOException {
|
||||
if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) {
|
||||
return firstPassTopDocs;
|
||||
}
|
||||
|
||||
ScoreDoc[] firstPassResults = null;
|
||||
if(originalRankingIndex != null) {
|
||||
firstPassResults = new ScoreDoc[firstPassTopDocs.scoreDocs.length];
|
||||
System.arraycopy(firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length);
|
||||
}
|
||||
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
|
||||
|
||||
ScoreDoc[][] reRankedPerModel = rerank(searcher,topN,getFirstPassDocsRanked(firstPassTopDocs));
|
||||
if (originalRankingIndex != null) {
|
||||
reRankedPerModel[originalRankingIndex] = firstPassResults;
|
||||
}
|
||||
InterleavingResult interleaved = interleavingAlgorithm.interleave(reRankedPerModel[0], reRankedPerModel[1]);
|
||||
ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();
|
||||
|
||||
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
|
||||
rerankingQueries[0].setPickedInterleavingDocIds(interleavingPicks.get(0));
|
||||
rerankingQueries[1].setPickedInterleavingDocIds(interleavingPicks.get(1));
|
||||
|
||||
return new TopDocs(firstPassTopDocs.totalHits, interleavedResults);
|
||||
}
|
||||
|
||||
private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException {
|
||||
ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN];
|
||||
final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||
LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
|
||||
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||
if (originalRankingIndex == null || originalRankingIndex != i) {
|
||||
modelWeights[i] = (LTRScoringQuery.ModelWeight) searcher
|
||||
.createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1);
|
||||
}
|
||||
}
|
||||
scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel);
|
||||
|
||||
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||
if (originalRankingIndex == null || originalRankingIndex != i) {
|
||||
sortByScore(reRankedPerModel[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return reRankedPerModel;
|
||||
}
|
||||
|
||||
public void scoreFeatures(IndexSearcher indexSearcher,
|
||||
int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List<LeafReaderContext> leaves,
|
||||
ScoreDoc[][] rerankedPerModel) throws IOException {
|
||||
|
||||
int readerUpto = -1;
|
||||
int endDoc = 0;
|
||||
int docBase = 0;
|
||||
int hitUpto = 0;
|
||||
LTRScoringQuery.ModelWeight.ModelScorer[] scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length];
|
||||
while (hitUpto < hits.length) {
|
||||
final ScoreDoc hit = hits[hitUpto];
|
||||
final int docID = hit.doc;
|
||||
LeafReaderContext readerContext = null;
|
||||
while (docID >= endDoc) {
|
||||
readerUpto++;
|
||||
readerContext = leaves.get(readerUpto);
|
||||
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
|
||||
}
|
||||
|
||||
// We advanced to another segment
|
||||
if (readerContext != null) {
|
||||
docBase = readerContext.docBase;
|
||||
for (int i = 0; i < modelWeights.length; i++) {
|
||||
if (modelWeights[i] != null) {
|
||||
scorers[i] = modelWeights[i].scorer(readerContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||
if (modelWeights[i] != null) {
|
||||
scoreSingleHit(indexSearcher, topN, modelWeights[i], docBase, hitUpto, new ScoreDoc(hit.doc, hit.score, hit.shardIndex), docID, rerankingQueries[i], scorers[i], rerankedPerModel[i]);
|
||||
}
|
||||
}
|
||||
hitUpto++;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(IndexSearcher searcher,
|
||||
Explanation firstPassExplanation, int docID) throws IOException {
|
||||
LTRScoringQuery pickedRerankModel = rerankingQueries[0];
|
||||
if (rerankingQueries[1].getPickedInterleavingDocIds().contains(docID)) {
|
||||
pickedRerankModel = rerankingQueries[1];
|
||||
}
|
||||
return getExplanation(searcher, docID, pickedRerankModel);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.ltr.interleaving;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.solr.ltr.LTRScoringQuery;
|
||||
import org.apache.solr.ltr.LTRThreadModule;
|
||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||
|
||||
public class LTRInterleavingScoringQuery extends LTRScoringQuery {
|
||||
|
||||
// Model was picked for this Docs
|
||||
private Set<Integer> pickedInterleavingDocIds;
|
||||
|
||||
public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel) {
|
||||
super(ltrScoringModel);
|
||||
}
|
||||
|
||||
public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
|
||||
super(ltrScoringModel, extractAllFeatures);
|
||||
}
|
||||
|
||||
public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel,
|
||||
Map<String, String[]> externalFeatureInfo,
|
||||
boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
|
||||
super(ltrScoringModel, externalFeatureInfo, extractAllFeatures, ltrThreadMgr);
|
||||
}
|
||||
|
||||
public Set<Integer> getPickedInterleavingDocIds() {
|
||||
return pickedInterleavingDocIds;
|
||||
}
|
||||
|
||||
public void setPickedInterleavingDocIds(Set<Integer> pickedInterleavingDocIds) {
|
||||
this.pickedInterleavingDocIds = pickedInterleavingDocIds;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.ltr.interleaving;
|
||||
|
||||
public final class OriginalRankingLTRScoringQuery extends LTRInterleavingScoringQuery {
|
||||
|
||||
private final String originalRankingModelName;
|
||||
|
||||
public OriginalRankingLTRScoringQuery(String originalRankingModelName) {
|
||||
super(null /* LTRScoringModel */);
|
||||
this.originalRankingModelName = originalRankingModelName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getScoringModelName() {
|
||||
return this.originalRankingModelName;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.interleaving.algorithms;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.solr.ltr.interleaving.Interleaving;
|
||||
import org.apache.solr.ltr.interleaving.InterleavingResult;
|
||||
|
||||
/**
|
||||
* Interleaving was introduced the first time by Joachims in [1, 2].
|
||||
* Team Draft Interleaving is among the most successful and used interleaving approaches[3].
|
||||
* Team Draft Interleaving implements a method similar to the way in which captains select their players in team-matches.
|
||||
* Team Draft Interleaving produces a fair distribution of ranking models’ elements in the final interleaved list.
|
||||
* "Team draft interleaving" has also proved to overcome an issue of the "Balanced interleaving" approach, in determining the winning model[4].
|
||||
* <p>
|
||||
* [1] T. Joachims. Optimizing search engines using clickthrough data. KDD (2002)
|
||||
* [2] T. Joachims. Evaluating retrieval performance using clickthrough data. In J. Franke, G. Nakhaeizadeh, and I. Renz, editors,
|
||||
* Text Mining, pages 79–96. Physica/Springer (2003)
|
||||
* [3] F. Radlinski, M. Kurup, and T. Joachims. How does clickthrough data reflect
|
||||
* retrieval quality? In CIKM, pages 43–52. ACM Press (2008)
|
||||
* [4] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue.
|
||||
* Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1–41, Feb. (2012)
|
||||
*/
|
||||
public class TeamDraftInterleaving implements Interleaving {
|
||||
public static Random RANDOM;
|
||||
|
||||
static {
|
||||
// We try to make things reproducible in the context of our tests by initializing the random instance
|
||||
// based on the current seed
|
||||
String seed = System.getProperty("tests.seed");
|
||||
if (seed == null) {
|
||||
RANDOM = new Random();
|
||||
} else {
|
||||
RANDOM = new Random(seed.hashCode());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Team Draft Interleaving considers two ranking models: modelA and modelB.
|
||||
* For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...).
|
||||
* The algorithm creates a unique ranked list I = (i1, i2, ...).
|
||||
* This list is created by interleaving elements from the two lists la and lb as described by Chapelle et al.[1].
|
||||
* Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb.
|
||||
* <p>
|
||||
* [1] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue.
|
||||
* Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1–41, Feb. (2012)
|
||||
* <p>
|
||||
* Assumptions:
|
||||
* - rerankedA and rerankedB has the same length.
|
||||
* They contains the same search results, ranked differently by two ranking models
|
||||
* - each reranked list can not contain the same search result more than once.
|
||||
* - results are all from the same shard
|
||||
*
|
||||
* @param rerankedA a ranked list of search results produced by a ranking model A
|
||||
* @param rerankedB a ranked list of search results produced by a ranking model B
|
||||
* @return the interleaved ranking list
|
||||
*/
|
||||
public InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB) {
|
||||
LinkedList<ScoreDoc> interleavedResults = new LinkedList<>();
|
||||
HashSet<Integer> alreadyAdded = new HashSet<>();
|
||||
ScoreDoc[] interleavedResultArray = new ScoreDoc[rerankedA.length];
|
||||
ArrayList<Set<Integer>> interleavingPicks = new ArrayList<>(2);
|
||||
Set<Integer> teamA = new HashSet<>();
|
||||
Set<Integer> teamB = new HashSet<>();
|
||||
int topN = rerankedA.length;
|
||||
int indexA = 0, indexB = 0;
|
||||
|
||||
while (interleavedResults.size() < topN && indexA < rerankedA.length && indexB < rerankedB.length) {
|
||||
if(teamA.size()<teamB.size() || (teamA.size()==teamB.size() && !RANDOM.nextBoolean())){
|
||||
indexA = updateIndex(alreadyAdded, indexA, rerankedA);
|
||||
interleavedResults.add(rerankedA[indexA]);
|
||||
alreadyAdded.add(rerankedA[indexA].doc);
|
||||
teamA.add(rerankedA[indexA].doc);
|
||||
indexA++;
|
||||
} else{
|
||||
indexB = updateIndex(alreadyAdded,indexB,rerankedB);
|
||||
interleavedResults.add(rerankedB[indexB]);
|
||||
alreadyAdded.add(rerankedB[indexB].doc);
|
||||
teamB.add(rerankedB[indexB].doc);
|
||||
indexB++;
|
||||
}
|
||||
}
|
||||
interleavingPicks.add(teamA);
|
||||
interleavingPicks.add(teamB);
|
||||
interleavedResultArray = interleavedResults.toArray(interleavedResultArray);
|
||||
|
||||
return new InterleavingResult(interleavedResultArray,interleavingPicks);
|
||||
}
|
||||
|
||||
private int updateIndex(HashSet<Integer> alreadyAdded, int index, ScoreDoc[] reranked) {
|
||||
boolean foundElementToAdd = false;
|
||||
while (index < reranked.length && !foundElementToAdd) {
|
||||
ScoreDoc elementToCheck = reranked[index];
|
||||
if (alreadyAdded.contains(elementToCheck.doc)) {
|
||||
index++;
|
||||
} else {
|
||||
foundElementToAdd = true;
|
||||
}
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
public static void setRANDOM(Random RANDOM) {
|
||||
TeamDraftInterleaving.RANDOM = RANDOM;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Contains Various Interleaving Algorithms
|
||||
*/
|
||||
package org.apache.solr.ltr.interleaving.algorithms;
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Contains Various Interleaving auxiliary classes
|
||||
*/
|
||||
package org.apache.solr.ltr.interleaving;
|
|
@ -36,6 +36,8 @@ import org.apache.solr.ltr.LTRScoringQuery;
|
|||
import org.apache.solr.ltr.LTRThreadModule;
|
||||
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
||||
import org.apache.solr.ltr.feature.Feature;
|
||||
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
|
||||
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
|
||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||
import org.apache.solr.ltr.norm.Normalizer;
|
||||
import org.apache.solr.ltr.search.LTRQParserPlugin;
|
||||
|
@ -126,14 +128,15 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
|||
SolrQueryRequestContextUtils.setIsExtractingFeatures(req);
|
||||
|
||||
// Communicate which feature store we are requesting features for
|
||||
SolrQueryRequestContextUtils.setFvStoreName(req, localparams.get(FV_STORE, defaultStore));
|
||||
final String fvStoreName = localparams.get(FV_STORE);
|
||||
SolrQueryRequestContextUtils.setFvStoreName(req, (fvStoreName == null ? defaultStore : fvStoreName));
|
||||
|
||||
// Create and supply the feature logger to be used
|
||||
SolrQueryRequestContextUtils.setFeatureLogger(req,
|
||||
createFeatureLogger(
|
||||
localparams.get(FV_FORMAT)));
|
||||
|
||||
return new FeatureTransformer(name, localparams, req);
|
||||
return new FeatureTransformer(name, localparams, req, (fvStoreName != null) /* hasExplicitFeatureStore */);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -163,11 +166,18 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
|||
final private String name;
|
||||
final private SolrParams localparams;
|
||||
final private SolrQueryRequest req;
|
||||
final private boolean hasExplicitFeatureStore;
|
||||
|
||||
private List<LeafReaderContext> leafContexts;
|
||||
private SolrIndexSearcher searcher;
|
||||
private LTRScoringQuery scoringQuery;
|
||||
private LTRScoringQuery.ModelWeight modelWeight;
|
||||
/**
|
||||
* rerankingQueries, modelWeights have:
|
||||
* length=1 - [Classic LTR] When reranking with a single model
|
||||
* length=2 - [Interleaving] When reranking with interleaving (two ranking models are involved)
|
||||
*/
|
||||
private LTRScoringQuery[] rerankingQueriesFromContext;
|
||||
private LTRScoringQuery[] rerankingQueries;
|
||||
private LTRScoringQuery.ModelWeight[] modelWeights;
|
||||
private FeatureLogger featureLogger;
|
||||
private boolean docsWereNotReranked;
|
||||
|
||||
|
@ -177,10 +187,11 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
|||
* feature vectors
|
||||
*/
|
||||
public FeatureTransformer(String name, SolrParams localparams,
|
||||
SolrQueryRequest req) {
|
||||
SolrQueryRequest req, boolean hasExplicitFeatureStore) {
|
||||
this.name = name;
|
||||
this.localparams = localparams;
|
||||
this.req = req;
|
||||
this.hasExplicitFeatureStore = hasExplicitFeatureStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -209,51 +220,102 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
|||
threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
|
||||
}
|
||||
|
||||
// Setup LTRScoringQuery
|
||||
scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req);
|
||||
docsWereNotReranked = (scoringQuery == null);
|
||||
String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
|
||||
if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) {
|
||||
// if store is set in the transformer we should overwrite the logger
|
||||
rerankingQueriesFromContext = SolrQueryRequestContextUtils.getScoringQueries(req);
|
||||
docsWereNotReranked = (rerankingQueriesFromContext == null || rerankingQueriesFromContext.length == 0);
|
||||
String transformerFeatureStore = SolrQueryRequestContextUtils.getFvStoreName(req);
|
||||
Map<String, String[]> transformerExternalFeatureInfo = LTRQParserPlugin.extractEFIParams(localparams);
|
||||
|
||||
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
|
||||
final LoggingModel loggingModel = createLoggingModel(transformerFeatureStore);
|
||||
setupRerankingQueriesForLogging(transformerFeatureStore, transformerExternalFeatureInfo, loggingModel);
|
||||
setupRerankingWeightsForLogging(context);
|
||||
}
|
||||
|
||||
final FeatureStore store = fr.getFeatureStore(featureStoreName);
|
||||
featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name
|
||||
/**
|
||||
* The loggingModel is an empty model that is just used to extract the features
|
||||
* and log them
|
||||
* @param transformerFeatureStore the explicit transformer feature store
|
||||
*/
|
||||
private LoggingModel createLoggingModel(String transformerFeatureStore) {
|
||||
final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
|
||||
|
||||
try {
|
||||
final LoggingModel lm = new LoggingModel(loggingModelName,
|
||||
featureStoreName, store.getFeatures());
|
||||
final FeatureStore store = fr.getFeatureStore(transformerFeatureStore);
|
||||
transformerFeatureStore = store.getName(); // if transformerFeatureStore was null before this gets actual name
|
||||
|
||||
scoringQuery = new LTRScoringQuery(lm,
|
||||
LTRQParserPlugin.extractEFIParams(localparams),
|
||||
true,
|
||||
threadManager); // request feature weights to be created for all features
|
||||
return new LoggingModel(loggingModelName,
|
||||
transformerFeatureStore, store.getFeatures());
|
||||
}
|
||||
|
||||
}catch (final Exception e) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"retrieving the feature store "+featureStoreName, e);
|
||||
/**
|
||||
* When preparing the reranking queries for logging features various scenarios apply:
|
||||
*
|
||||
* No Reranking
|
||||
* There is the need of a logger model from the default feature store or the explicit feature store passed
|
||||
* to extract the feature vector
|
||||
*
|
||||
* Re Ranking
|
||||
* 1) If no explicit feature store is passed, the models for each reranking query can be safely re-used
|
||||
* the feature vector can be fetched from the feature vector cache.
|
||||
* 2) If an explicit feature store is passed, and no reranking query uses a model with that feature store,
|
||||
* There is the need of a logger model to extract the feature vector
|
||||
* 3) If an explicit feature store is passed, and there is a reranking query that uses a model with that feature store,
|
||||
* the model can be re-used and there is no need for a logging model
|
||||
*
|
||||
* @param transformerFeatureStore explicit feature store for the transformer
|
||||
* @param transformerExternalFeatureInfo explicit efi for the transformer
|
||||
*/
|
||||
private void setupRerankingQueriesForLogging(String transformerFeatureStore, Map<String, String[]> transformerExternalFeatureInfo, LoggingModel loggingModel) {
|
||||
if (docsWereNotReranked) { //no reranking query
|
||||
LTRScoringQuery loggingQuery = new LTRScoringQuery(loggingModel,
|
||||
transformerExternalFeatureInfo,
|
||||
true /* extractAllFeatures */,
|
||||
threadManager);
|
||||
rerankingQueries = new LTRScoringQuery[]{loggingQuery};
|
||||
} else {
|
||||
rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length];
|
||||
System.arraycopy(rerankingQueriesFromContext, 0, rerankingQueries, 0, rerankingQueriesFromContext.length);
|
||||
|
||||
if (transformerFeatureStore != null) {// explicit feature store for the transformer
|
||||
LTRScoringModel matchingRerankingModel = loggingModel;
|
||||
for (LTRScoringQuery rerankingQuery : rerankingQueries) {
|
||||
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) &&
|
||||
transformerFeatureStore.equals(rerankingQuery.getScoringModel().getFeatureStoreName())) {
|
||||
matchingRerankingModel = rerankingQuery.getScoringModel();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||
rerankingQueries[i] = new LTRScoringQuery(
|
||||
matchingRerankingModel,
|
||||
(!transformerExternalFeatureInfo.isEmpty() ? transformerExternalFeatureInfo : rerankingQueries[i].getExternalFeatureInfo()),
|
||||
true /* extractAllFeatures */,
|
||||
threadManager);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (scoringQuery.getOriginalQuery() == null) {
|
||||
scoringQuery.setOriginalQuery(context.getQuery());
|
||||
}
|
||||
if (scoringQuery.getFeatureLogger() == null){
|
||||
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
||||
}
|
||||
scoringQuery.setRequest(req);
|
||||
|
||||
featureLogger = scoringQuery.getFeatureLogger();
|
||||
|
||||
try {
|
||||
modelWeight = scoringQuery.createWeight(searcher, ScoreMode.COMPLETE, 1f);
|
||||
} catch (final IOException e) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
|
||||
}
|
||||
if (modelWeight == null) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"error logging the features, model weight is null");
|
||||
private void setupRerankingWeightsForLogging(ResultContext context) {
|
||||
modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length];
|
||||
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||
if (rerankingQueries[i].getOriginalQuery() == null) {
|
||||
rerankingQueries[i].setOriginalQuery(context.getQuery());
|
||||
}
|
||||
rerankingQueries[i].setRequest(req);
|
||||
if (!(rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
|
||||
if (rerankingQueries[i].getFeatureLogger() == null) {
|
||||
rerankingQueries[i].setFeatureLogger(SolrQueryRequestContextUtils.getFeatureLogger(req));
|
||||
}
|
||||
featureLogger = rerankingQueries[i].getFeatureLogger();
|
||||
try {
|
||||
modelWeights[i] = rerankingQueries[i].createWeight(searcher, ScoreMode.COMPLETE, 1f);
|
||||
} catch (final IOException e) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
|
||||
}
|
||||
if (modelWeights[i] == null) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"error logging the features, model weight is null");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -271,17 +333,26 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
|
|||
|
||||
private void implTransform(SolrDocument doc, int docid, Float score)
|
||||
throws IOException {
|
||||
Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher);
|
||||
if (fv == null) { // FV for this document was not in the cache
|
||||
fv = featureLogger.makeFeatureVector(
|
||||
LTRRescorer.extractFeaturesInfo(
|
||||
modelWeight,
|
||||
docid,
|
||||
(docsWereNotReranked ? score : null),
|
||||
leafContexts));
|
||||
LTRScoringQuery rerankingQuery = rerankingQueries[0];
|
||||
LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0];
|
||||
for (int i = 1; i < rerankingQueries.length; i++) {
|
||||
if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) {
|
||||
rerankingQuery = rerankingQueries[i];
|
||||
rerankingModelWeight = modelWeights[i];
|
||||
}
|
||||
}
|
||||
if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) {
|
||||
Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher);
|
||||
if (featureVector == null) { // FV for this document was not in the cache
|
||||
featureVector = featureLogger.makeFeatureVector(
|
||||
LTRRescorer.extractFeaturesInfo(
|
||||
rerankingModelWeight,
|
||||
docid,
|
||||
(docsWereNotReranked ? score : null),
|
||||
leafContexts));
|
||||
}
|
||||
doc.addField(name, featureVector);
|
||||
}
|
||||
|
||||
doc.addField(name, fv);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.ltr.response.transform;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.solr.common.SolrDocument;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
|
||||
import org.apache.solr.ltr.LTRScoringQuery;
|
||||
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.response.ResultContext;
|
||||
import org.apache.solr.response.transform.DocTransformer;
|
||||
import org.apache.solr.response.transform.TransformerFactory;
|
||||
import org.apache.solr.util.SolrPluginUtils;
|
||||
|
||||
public class LTRInterleavingTransformerFactory extends TransformerFactory {
|
||||
|
||||
@Override
|
||||
@SuppressWarnings({"unchecked"})
|
||||
public void init(@SuppressWarnings("rawtypes") NamedList args) {
|
||||
super.init(args);
|
||||
SolrPluginUtils.invokeSetters(this, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocTransformer create(String name, SolrParams localparams,
|
||||
SolrQueryRequest req) {
|
||||
return new InterleavingTransformer(name, req);
|
||||
}
|
||||
|
||||
class InterleavingTransformer extends DocTransformer {
|
||||
|
||||
final private String name;
|
||||
final private SolrQueryRequest req;
|
||||
|
||||
private LTRInterleavingScoringQuery[] rerankingQueries;
|
||||
|
||||
/**
|
||||
* @param name
|
||||
* Name of the field to be added in a document representing the
|
||||
* model picked by the interleaving process
|
||||
*/
|
||||
public InterleavingTransformer(String name,
|
||||
SolrQueryRequest req) {
|
||||
this.name = name;
|
||||
this.req = req;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setContext(ResultContext context) {
|
||||
super.setContext(context);
|
||||
if (context == null) {
|
||||
return;
|
||||
}
|
||||
if (context.getRequest() == null) {
|
||||
return;
|
||||
}
|
||||
rerankingQueries = (LTRInterleavingScoringQuery[])SolrQueryRequestContextUtils.getScoringQueries(req);
|
||||
for (int i = 0; i < rerankingQueries.length; i++) {
|
||||
LTRScoringQuery scoringQuery = rerankingQueries[i];
|
||||
|
||||
if (scoringQuery.getOriginalQuery() == null) {
|
||||
scoringQuery.setOriginalQuery(context.getQuery());
|
||||
}
|
||||
if (scoringQuery.getFeatureLogger() == null) {
|
||||
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
||||
}
|
||||
scoringQuery.setRequest(req);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void transform(SolrDocument doc, int docid, float score)
|
||||
throws IOException {
|
||||
implTransform(doc, docid);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void transform(SolrDocument doc, int docid)
|
||||
throws IOException {
|
||||
implTransform(doc, docid);
|
||||
}
|
||||
|
||||
private void implTransform(SolrDocument doc, int docid) {
|
||||
LTRScoringQuery rerankingQuery = rerankingQueries[0];
|
||||
if (rerankingQueries.length > 1 && rerankingQueries[1].getPickedInterleavingDocIds().contains(docid)) {
|
||||
rerankingQuery = rerankingQueries[1];
|
||||
}
|
||||
doc.addField(name, rerankingQuery.getScoringModelName());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -23,26 +23,26 @@ import java.util.Map;
|
|||
|
||||
import org.apache.lucene.util.ResourceLoader;
|
||||
import org.apache.lucene.util.ResourceLoaderAware;
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.solr.common.params.SolrParams;
|
||||
import org.apache.solr.common.util.NamedList;
|
||||
import org.apache.solr.core.SolrResourceLoader;
|
||||
import org.apache.solr.ltr.LTRRescorer;
|
||||
import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery;
|
||||
import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery;
|
||||
import org.apache.solr.ltr.LTRScoringQuery;
|
||||
import org.apache.solr.ltr.LTRThreadModule;
|
||||
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
|
||||
import org.apache.solr.ltr.interleaving.Interleaving;
|
||||
import org.apache.solr.ltr.interleaving.LTRInterleavingQuery;
|
||||
import org.apache.solr.ltr.model.LTRScoringModel;
|
||||
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
|
||||
import org.apache.solr.ltr.store.rest.ManagedModelStore;
|
||||
import org.apache.solr.request.SolrQueryRequest;
|
||||
import org.apache.solr.rest.ManagedResource;
|
||||
import org.apache.solr.rest.ManagedResourceObserver;
|
||||
import org.apache.solr.search.AbstractReRankQuery;
|
||||
import org.apache.solr.search.QParser;
|
||||
import org.apache.solr.search.QParserPlugin;
|
||||
import org.apache.solr.search.RankQuery;
|
||||
import org.apache.solr.search.SyntaxError;
|
||||
import org.apache.solr.util.SolrPluginUtils;
|
||||
|
||||
|
@ -55,7 +55,7 @@ import org.apache.solr.util.SolrPluginUtils;
|
|||
*/
|
||||
public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver {
|
||||
public static final String NAME = "ltr";
|
||||
private static Query defaultQuery = new MatchAllDocsQuery();
|
||||
private static final String ORIGINAL_RANKING = "_OriginalRanking_";
|
||||
|
||||
// params for setting custom external info that features can use, like query
|
||||
// intent
|
||||
|
@ -145,94 +145,73 @@ public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAwa
|
|||
|
||||
@Override
|
||||
public Query parse() throws SyntaxError {
|
||||
// ReRanking Model
|
||||
final String modelName = localParams.get(LTRQParserPlugin.MODEL);
|
||||
if ((modelName == null) || modelName.isEmpty()) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"Must provide model in the request");
|
||||
}
|
||||
|
||||
final LTRScoringModel ltrScoringModel = mr.getModel(modelName);
|
||||
if (ltrScoringModel == null) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"cannot find " + LTRQParserPlugin.MODEL + " " + modelName);
|
||||
}
|
||||
|
||||
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
|
||||
final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
|
||||
final String fvStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
|
||||
// Check if features are requested and if the model feature store and feature-transform feature store are the same
|
||||
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(fvStoreName) || fvStoreName == null) ? extractFeatures:false;
|
||||
if (threadManager != null) {
|
||||
threadManager.setExecutor(req.getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor());
|
||||
}
|
||||
final LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel,
|
||||
extractEFIParams(localParams),
|
||||
featuresRequestedFromSameStore, threadManager);
|
||||
|
||||
// Enable the feature vector caching if we are extracting features, and the features
|
||||
// we requested are the same ones we are reranking with
|
||||
if (featuresRequestedFromSameStore) {
|
||||
scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
||||
// ReRanking Model
|
||||
final String[] modelNames = localParams.getParams(LTRQParserPlugin.MODEL);
|
||||
if ((modelNames == null) || (modelNames.length!=1 && modelNames.length!=2)) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"Must provide one or two models in the request");
|
||||
}
|
||||
final boolean isInterleaving = (modelNames.length > 1);
|
||||
final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req);
|
||||
final String tranformerFeatureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
|
||||
final Map<String,String[]> externalFeatureInfo = extractEFIParams(localParams);
|
||||
|
||||
LTRScoringQuery rerankingQuery = null;
|
||||
LTRInterleavingScoringQuery[] rerankingQueries = new LTRInterleavingScoringQuery[modelNames.length];
|
||||
for (int i = 0; i < modelNames.length; i++) {
|
||||
if (modelNames[i].isEmpty()) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"the " + LTRQParserPlugin.MODEL + " "+ i +" is empty");
|
||||
}
|
||||
if (!ORIGINAL_RANKING.equals(modelNames[i])) {
|
||||
final LTRScoringModel ltrScoringModel = mr.getModel(modelNames[i]);
|
||||
if (ltrScoringModel == null) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"cannot find " + LTRQParserPlugin.MODEL + " " + modelNames[i]);
|
||||
}
|
||||
final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName();
|
||||
// Check if features are requested and if the model feature store and feature-transform feature store are the same
|
||||
final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(tranformerFeatureStoreName) || tranformerFeatureStoreName == null) ? extractFeatures : false;
|
||||
|
||||
if (isInterleaving) {
|
||||
rerankingQuery = rerankingQueries[i] = new LTRInterleavingScoringQuery(ltrScoringModel,
|
||||
externalFeatureInfo,
|
||||
featuresRequestedFromSameStore, threadManager);
|
||||
} else {
|
||||
rerankingQuery = new LTRScoringQuery(ltrScoringModel,
|
||||
externalFeatureInfo,
|
||||
featuresRequestedFromSameStore, threadManager);
|
||||
rerankingQueries[i] = null;
|
||||
}
|
||||
|
||||
// Enable the feature vector caching if we are extracting features, and the features
|
||||
// we requested are the same ones we are reranking with
|
||||
if (featuresRequestedFromSameStore) {
|
||||
rerankingQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) );
|
||||
}
|
||||
}else{
|
||||
rerankingQuery = rerankingQueries[i] = new OriginalRankingLTRScoringQuery(ORIGINAL_RANKING);
|
||||
}
|
||||
|
||||
// External features
|
||||
rerankingQuery.setRequest(req);
|
||||
}
|
||||
SolrQueryRequestContextUtils.setScoringQuery(req, scoringQuery);
|
||||
|
||||
int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS);
|
||||
if (reRankDocs <= 0) {
|
||||
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
|
||||
"Must rerank at least 1 document");
|
||||
"Must rerank at least 1 document");
|
||||
}
|
||||
if (!isInterleaving) {
|
||||
SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] { rerankingQuery });
|
||||
return new LTRQuery(rerankingQuery, reRankDocs);
|
||||
} else {
|
||||
SolrQueryRequestContextUtils.setScoringQueries(req, rerankingQueries);
|
||||
return new LTRInterleavingQuery(Interleaving.getImplementation(Interleaving.TEAM_DRAFT),rerankingQueries, reRankDocs);
|
||||
}
|
||||
|
||||
// External features
|
||||
scoringQuery.setRequest(req);
|
||||
|
||||
return new LTRQuery(scoringQuery, reRankDocs);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A learning to rank Query, will incapsulate a learning to rank model, and delegate to it the rescoring
|
||||
* of the documents.
|
||||
**/
|
||||
public class LTRQuery extends AbstractReRankQuery {
|
||||
private final LTRScoringQuery scoringQuery;
|
||||
|
||||
public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
|
||||
super(defaultQuery, reRankDocs, new LTRRescorer(scoringQuery));
|
||||
this.scoringQuery = scoringQuery;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return sameClassAs(o) && equalsTo(getClass().cast(o));
|
||||
}
|
||||
|
||||
private boolean equalsTo(LTRQuery other) {
|
||||
return (mainQuery.equals(other.mainQuery)
|
||||
&& scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs));
|
||||
}
|
||||
|
||||
@Override
|
||||
public RankQuery wrap(Query _mainQuery) {
|
||||
super.wrap(_mainQuery);
|
||||
scoringQuery.setOriginalQuery(_mainQuery);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
|
||||
+ scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
|
||||
return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.search;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.solr.ltr.LTRRescorer;
|
||||
import org.apache.solr.ltr.LTRScoringQuery;
|
||||
import org.apache.solr.search.AbstractReRankQuery;
|
||||
import org.apache.solr.search.RankQuery;
|
||||
|
||||
/**
|
||||
* A learning to rank Query, will incapsulate a learning to rank model, and delegate to it the rescoring
|
||||
* of the documents.
|
||||
**/
|
||||
public class LTRQuery extends AbstractReRankQuery {
|
||||
private static final Query defaultQuery = new MatchAllDocsQuery();
|
||||
private final LTRScoringQuery scoringQuery;
|
||||
|
||||
public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) {
|
||||
this(scoringQuery, reRankDocs, new LTRRescorer(scoringQuery));
|
||||
}
|
||||
|
||||
protected LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs, LTRRescorer rescorer) {
|
||||
super(defaultQuery, reRankDocs, rescorer);
|
||||
this.scoringQuery = scoringQuery;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return sameClassAs(o) && equalsTo(getClass().cast(o));
|
||||
}
|
||||
|
||||
private boolean equalsTo(LTRQuery other) {
|
||||
return (mainQuery.equals(other.mainQuery)
|
||||
&& scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs));
|
||||
}
|
||||
|
||||
@Override
|
||||
public RankQuery wrap(Query _mainQuery) {
|
||||
super.wrap(_mainQuery);
|
||||
if (scoringQuery != null) {
|
||||
scoringQuery.setOriginalQuery(_mainQuery);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='"
|
||||
+ scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query rewrite(Query rewrittenMainQuery) throws IOException {
|
||||
return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery);
|
||||
}
|
||||
}
|
|
@ -39,7 +39,7 @@ A Learning to Rank model is plugged into the ranking through the {@link org.apac
|
|||
a {@link org.apache.solr.search.QParserPlugin}. The plugin will
|
||||
read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel})
|
||||
used to perform the request plus other
|
||||
parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQParserPlugin.LTRQuery LTRQuery}:
|
||||
parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQuery LTRQuery}:
|
||||
a particular {@link org.apache.solr.search.RankQuery}
|
||||
that will encapsulate the given model and use it to
|
||||
rescore and rerank the document (by using an {@link org.apache.solr.ltr.LTRRescorer}).
|
||||
|
|
|
@ -46,6 +46,15 @@
|
|||
<str name="fvCacheName">QUERY_DOC_FV</str>
|
||||
</transformer>
|
||||
|
||||
<!-- add a transformer that will encode the model the interleaving process chose the search result from.
|
||||
For each document the transformer will add an extra field in the response with the model picked.
|
||||
The name of the field will be the name of the transformer
|
||||
enclosed between brackets (in this case [interleaving]).
|
||||
In order to get the model chosen for the search result
|
||||
you will have to specify that you want the field (e.g., fl="*,[interleaving]") -->
|
||||
<transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory">
|
||||
</transformer>
|
||||
|
||||
<updateHandler class="solr.DirectUpdateHandler2">
|
||||
<autoCommit>
|
||||
<maxTime>15000</maxTime>
|
||||
|
|
|
@ -16,7 +16,11 @@
|
|||
*/
|
||||
package org.apache.solr.ltr;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.solr.client.solrj.SolrQuery;
|
||||
import org.apache.solr.ltr.feature.SolrFeature;
|
||||
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||
import org.apache.solr.ltr.model.LinearModel;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@ -149,4 +153,160 @@ public class TestLTRQParserExplain extends TestRerankBase {
|
|||
" 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingModels_shouldReturnExplainForTheModelPicked() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||
|
||||
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}1\"]}");
|
||||
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=description}bloomberg\"]}");
|
||||
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}2\"]}");
|
||||
loadFeature("featureB1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}5\"]}");
|
||||
loadFeature("featureB2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=title}different\"]}");
|
||||
|
||||
loadModel("modelA", LinearModel.class.getName(),
|
||||
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||
|
||||
loadModel("modelB", LinearModel.class.getName(),
|
||||
new String[]{"featureB1", "featureB2", "featureAB"},
|
||||
"{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("title:bloomberg");
|
||||
query.setParam("debugQuery", "on");
|
||||
query.add("rows", "10");
|
||||
query.add("rq", "{!ltr reRankDocs=10 model=modelA model=modelB}");
|
||||
query.add("fl", "*,score");
|
||||
|
||||
/*
|
||||
Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12), ScoreB(4)
|
||||
Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36), ScoreB(8)
|
||||
Doc8 = "featureA2=1.0", ScoreA(9), ScoreB(0)
|
||||
Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9), ScoreB(2)
|
||||
|
||||
ModelARerankedList = [7,6,8,9]
|
||||
ModelBRerankedList = [7,6,9,8]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [1,0]
|
||||
|
||||
*/
|
||||
|
||||
int[] expectedInterleaved = new int[]{7, 6, 8, 9};
|
||||
String[] expectedExplains = new String[]{
|
||||
"\n8.0 = LinearModel(name=modelB," +
|
||||
"featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " +
|
||||
"model applied to features, sum of:\n " +
|
||||
"0.0 = prod of:\n 2.0 = weight on feature\n 0.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " +
|
||||
"0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " +
|
||||
"8.0 = prod of:\n 8.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||
"\n12.0 = LinearModel(name=modelA," +
|
||||
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||
"model applied to features, sum of:\n " +
|
||||
"3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||
"0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||
"\n9.0 = LinearModel(name=modelA," +
|
||||
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||
"model applied to features, sum of:\n " +
|
||||
"0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||
"0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||
"\n2.0 = LinearModel(name=modelB," +
|
||||
"featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " +
|
||||
"model applied to features, sum of:\n " +
|
||||
"2.0 = prod of:\n 2.0 = weight on feature\n 1.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " +
|
||||
"0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " +
|
||||
"0.0 = prod of:\n 8.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n"};
|
||||
|
||||
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==4";
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 4] = "/debug/explain/" + expectedInterleaved[(i - 1)] + "=='" + expectedExplains[(i - 1)]+"'}";
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingModelsWithOriginalRanking_shouldReturnExplainForTheModelPicked() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||
|
||||
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}1\"]}");
|
||||
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=description}bloomberg\"]}");
|
||||
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}2\"]}");
|
||||
|
||||
loadModel("modelA", LinearModel.class.getName(),
|
||||
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("title:bloomberg");
|
||||
query.setParam("debugQuery", "on");
|
||||
query.add("rows", "10");
|
||||
query.add("rq", "{!ltr reRankDocs=10 model=modelA model=_OriginalRanking_}");
|
||||
query.add("fl", "*,score");
|
||||
|
||||
/*
|
||||
Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12)
|
||||
Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36)
|
||||
Doc8 = "featureA2=1.0", ScoreA(9)
|
||||
Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9)
|
||||
|
||||
ModelARerankedList = [7,6,8,9]
|
||||
OriginalRanking = [9,8,7,6]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [1,0]
|
||||
|
||||
*/
|
||||
|
||||
int[] expectedInterleaved = new int[]{9, 7, 6, 8};
|
||||
String[] expectedExplains = new String[]{
|
||||
"\n0.07662583 = weight(title:bloomberg in 3) [SchemaSimilarity], result of:\n " +
|
||||
"0.07662583 = score(freq=4.0), computed as boost * idf * tf from:\n " +
|
||||
"0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " +
|
||||
"0.72727275 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 4.0 = freq, occurrences of term within document\n " +
|
||||
"1.2 = k1, term saturation parameter\n " +
|
||||
"0.75 = b, length normalization parameter\n " +
|
||||
"4.0 = dl, length of field\n " +
|
||||
"3.0 = avgdl, average length of field\n",
|
||||
"\n36.0 = LinearModel(name=modelA," +
|
||||
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||
"model applied to features, sum of:\n " +
|
||||
"0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||
"27.0 = prod of:\n 27.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||
"\n12.0 = LinearModel(name=modelA," +
|
||||
"featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " +
|
||||
"model applied to features, sum of:\n " +
|
||||
"3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " +
|
||||
"9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " +
|
||||
"0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n",
|
||||
"\n0.07525751 = weight(title:bloomberg in 2) [SchemaSimilarity], result of:\n " +
|
||||
"0.07525751 = score(freq=3.0), computed as boost * idf * tf from:\n " +
|
||||
"0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " +
|
||||
"0.71428573 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 3.0 = freq, occurrences of term within document\n " +
|
||||
"1.2 = k1, term saturation parameter\n " +
|
||||
"0.75 = b, length normalization parameter\n " +
|
||||
"3.0 = dl, length of field\n " +
|
||||
"3.0 = avgdl, average length of field\n"};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==4";
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 4] = "/debug/explain/" + expectedInterleaved[(i - 1)] + "=='" + expectedExplains[(i - 1)]+"'}";
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -48,7 +48,21 @@ public class TestLTRQParserPlugin extends TestRerankBase {
|
|||
query.add("rq", "{!ltr reRankDocs=100}");
|
||||
|
||||
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||
assert (res.contains("Must provide model in the request"));
|
||||
assert (res.contains("Must provide one or two models in the request"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingLtrTooManyModelsTest() throws Exception {
|
||||
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery(solrQuery);
|
||||
query.add("fl", "*, score");
|
||||
query.add("rows", "4");
|
||||
query.add("fv", "true");
|
||||
query.add("rq", "{!ltr model=modelA model=modelB model=C reRankDocs=100}");
|
||||
|
||||
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||
assert (res.contains("Must provide one or two models in the request"));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -65,6 +79,34 @@ public class TestLTRQParserPlugin extends TestRerankBase {
|
|||
assert (res.contains("cannot find model"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ltrModelIsEmptyTest() throws Exception {
|
||||
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery(solrQuery);
|
||||
query.add("fl", "*, score");
|
||||
query.add("rows", "4");
|
||||
query.add("fv", "true");
|
||||
query.add("rq", "{!ltr model=\"\" reRankDocs=100}");
|
||||
|
||||
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||
assert (res.contains("the model 0 is empty"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingLtrModelIsEmptyTest() throws Exception {
|
||||
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery(solrQuery);
|
||||
query.add("fl", "*, score");
|
||||
query.add("rows", "4");
|
||||
query.add("fv", "true");
|
||||
query.add("rq", "{!ltr model=6029760550880411648 model=\"\" reRankDocs=100}");
|
||||
|
||||
final String res = restTestHarness.query("/query" + query.toQueryString());
|
||||
assert (res.contains("the model 1 is empty"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ltrBadRerankDocsTest() throws Exception {
|
||||
final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
|
||||
|
|
|
@ -17,8 +17,11 @@
|
|||
|
||||
package org.apache.solr.ltr;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.solr.client.solrj.SolrQuery;
|
||||
import org.apache.solr.ltr.feature.SolrFeature;
|
||||
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||
import org.apache.solr.ltr.model.LinearModel;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@ -97,4 +100,104 @@ public class TestLTRWithSort extends TestRerankBase {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingTwoModelsWithSort_shouldInterleave() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||
|
||||
loadFeature("featureA", SolrFeature.class.getName(),
|
||||
"{\"q\":\"{!func}pow(popularity,2)\"}");
|
||||
|
||||
loadFeature("featureB", SolrFeature.class.getName(),
|
||||
"{\"q\":\"{!func}pow(popularity,-2)\"}");
|
||||
|
||||
loadModel("modelA", LinearModel.class.getName(),
|
||||
new String[] {"featureA"}, "{\"weights\":{\"featureA\":1.0}}");
|
||||
|
||||
loadModel("modelB", LinearModel.class.getName(),
|
||||
new String[] {"featureB"}, "{\"weights\":{\"featureB\":1.0}}");
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("title:a1");
|
||||
query.add("rows", "10");
|
||||
query.add("rq", "{!ltr reRankDocs=4 model=modelA model=modelB}");
|
||||
query.add("fl", "*,score");
|
||||
query.add("sort", "description desc");
|
||||
|
||||
/*
|
||||
Doc1 = "popularity=1", ScoreA(1) ScoreB(1)
|
||||
Doc5 = "popularity=5", ScoreA(25) ScoreB(0.04)
|
||||
Doc7 = "popularity=7", ScoreA(49) ScoreB(0.02)
|
||||
Doc8 = "popularity=8", ScoreA(64) ScoreB(0.01)
|
||||
|
||||
ModelARerankedList = [8,7,5,1]
|
||||
ModelBRerankedList = [1,5,7,8]
|
||||
|
||||
OriginalRanking = [1,5,8,7]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [1,0]
|
||||
*/
|
||||
|
||||
int[] expectedInterleaved = new int[]{1, 8, 7, 5};
|
||||
|
||||
String[] tests = new String[5];
|
||||
tests[0] = "/response/numFound/==8";
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingModelsWithOriginalRankingSort_shouldInterleave() throws Exception {
|
||||
|
||||
loadFeature("powpularityS", SolrFeature.class.getName(),
|
||||
"{\"q\":\"{!func}pow(popularity,2)\"}");
|
||||
|
||||
loadModel("powpularityS-model", LinearModel.class.getName(),
|
||||
new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
|
||||
|
||||
for (boolean originalRankingLast : new boolean[] { true, false }) {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0]
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("title:a1");
|
||||
query.add("rows", "10");
|
||||
if (originalRankingLast) {
|
||||
query.add("rq", "{!ltr reRankDocs=4 model=powpularityS-model model=_OriginalRanking_}");
|
||||
} else {
|
||||
query.add("rq", "{!ltr reRankDocs=4 model=_OriginalRanking_ model=powpularityS-model}");
|
||||
}
|
||||
query.add("fl", "*,score");
|
||||
query.add("sort", "description desc");
|
||||
|
||||
/*
|
||||
Doc1 = "popularity=1", ScorePowpularityS(1)
|
||||
Doc5 = "popularity=5", ScorePowpularityS(25)
|
||||
Doc7 = "popularity=7", ScorePowpularityS(49)
|
||||
Doc8 = "popularity=8", ScorePowpularityS(64)
|
||||
|
||||
PowpularitySRerankedList = [8,7,5,1]
|
||||
OriginalRanking = [1,5,8,7]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [1,0]
|
||||
*/
|
||||
|
||||
final int[] expectedInterleaved;
|
||||
if (originalRankingLast) {
|
||||
expectedInterleaved = new int[]{1, 8, 7, 5};
|
||||
} else {
|
||||
expectedInterleaved = new int[]{8, 1, 5, 7};
|
||||
}
|
||||
|
||||
String[] tests = new String[5];
|
||||
tests[0] = "/response/numFound/==8";
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.solr.ltr.interleaving.algorithms;
|
||||
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.solr.ltr.interleaving.InterleavingResult;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class TeamDraftInterleavingTest {
|
||||
private TeamDraftInterleaving toTest;
|
||||
private ScoreDoc[] rerankedA,rerankedB;
|
||||
private ScoreDoc a1,a2,a3,a4,a5;
|
||||
private ScoreDoc b1,b2,b3,b4,b5;
|
||||
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
toTest = new TeamDraftInterleaving();
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
}
|
||||
|
||||
protected void initDifferentOrderRerankLists() {
|
||||
a1 = new ScoreDoc(1,10,1);
|
||||
a2 = new ScoreDoc(5,7,1);
|
||||
a3 = new ScoreDoc(4,6,1);
|
||||
a4 = new ScoreDoc(2,5,1);
|
||||
a5 = new ScoreDoc(3,4,1);
|
||||
rerankedA = new ScoreDoc[]{a1,a2,a3,a4,a5};
|
||||
|
||||
b1 = new ScoreDoc(1,10,1);
|
||||
b2 = new ScoreDoc(4,7,1);
|
||||
b3 = new ScoreDoc(5,6,1);
|
||||
b4 = new ScoreDoc(3,5,1);
|
||||
b5 = new ScoreDoc(2,4,1);
|
||||
rerankedB = new ScoreDoc[]{b1,b2,b3,b4,b5};
|
||||
}
|
||||
|
||||
/**
|
||||
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
@Test
|
||||
public void interleaving_twoDifferentLists_shouldInterleaveTeamDraft() {
|
||||
initDifferentOrderRerankLists();
|
||||
|
||||
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||
ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();
|
||||
|
||||
assertThat(interleavedResults.length,is(5));
|
||||
|
||||
assertThat(interleavedResults[0],is(a1));
|
||||
assertThat(interleavedResults[1],is(b2));
|
||||
|
||||
assertThat(interleavedResults[2],is(b3));
|
||||
assertThat(interleavedResults[3],is(a4));
|
||||
|
||||
assertThat(interleavedResults[4],is(b4));
|
||||
}
|
||||
|
||||
/**
|
||||
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
@Test
|
||||
public void interleaving_twoDifferentLists_shouldBuildCorrectInterleavingPicks() {
|
||||
initDifferentOrderRerankLists();
|
||||
|
||||
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||
|
||||
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
|
||||
Set<Integer> modelAPicks = interleavingPicks.get(0);
|
||||
Set<Integer> modelBPicks = interleavingPicks.get(1);
|
||||
|
||||
assertThat(modelAPicks.size(),is(2));
|
||||
assertThat(modelBPicks.size(),is(3));
|
||||
|
||||
assertTrue(modelAPicks.contains(a1.doc));
|
||||
assertTrue(modelAPicks.contains(a4.doc));
|
||||
|
||||
assertTrue(modelBPicks.contains(b2.doc));
|
||||
assertTrue(modelBPicks.contains(b3.doc));
|
||||
assertTrue(modelBPicks.contains(b4.doc));
|
||||
}
|
||||
|
||||
protected void initIdenticalOrderRerankLists() {
|
||||
a1 = new ScoreDoc(1,10,1);
|
||||
a2 = new ScoreDoc(5,7,1);
|
||||
a3 = new ScoreDoc(4,6,1);
|
||||
a4 = new ScoreDoc(2,5,1);
|
||||
a5 = new ScoreDoc(3,4,1);
|
||||
rerankedA = new ScoreDoc[]{a1,a2,a3,a4,a5};
|
||||
|
||||
b1 = new ScoreDoc(1,10,1);
|
||||
b2 = new ScoreDoc(5,7,1);
|
||||
b3 = new ScoreDoc(4,6,1);
|
||||
b4 = new ScoreDoc(2,5,1);
|
||||
b5 = new ScoreDoc(3,4,1);
|
||||
rerankedB = new ScoreDoc[]{b1,b2,b3,b4,b5};
|
||||
}
|
||||
|
||||
/**
|
||||
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
@Test
|
||||
public void interleaving_identicalRerankLists_shouldInterleaveTeamDraft() {
|
||||
initIdenticalOrderRerankLists();
|
||||
|
||||
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||
ScoreDoc[] interleavedResults = interleaved.getInterleavedResults();
|
||||
|
||||
assertThat(interleavedResults.length,is(5));
|
||||
|
||||
assertThat(interleavedResults[0],is(a1));
|
||||
assertThat(interleavedResults[1],is(b2));
|
||||
|
||||
assertThat(interleavedResults[2],is(b3));
|
||||
assertThat(interleavedResults[3],is(a4));
|
||||
|
||||
assertThat(interleavedResults[4],is(b5));
|
||||
}
|
||||
|
||||
/**
|
||||
* Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
@Test
|
||||
public void interleaving_identicalRerankLists_shouldBuildCorrectInterleavingPicks() {
|
||||
initIdenticalOrderRerankLists();
|
||||
|
||||
InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB);
|
||||
|
||||
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks();
|
||||
Set<Integer> modelAPicks = interleavingPicks.get(0);
|
||||
Set<Integer> modelBPicks = interleavingPicks.get(1);
|
||||
|
||||
assertThat(modelAPicks.size(),is(2));
|
||||
assertThat(modelBPicks.size(),is(3));
|
||||
|
||||
assertTrue(modelAPicks.contains(a1.doc));
|
||||
assertTrue(modelAPicks.contains(a4.doc));
|
||||
|
||||
assertTrue(modelBPicks.contains(b2.doc));
|
||||
assertTrue(modelBPicks.contains(b3.doc));
|
||||
assertTrue(modelBPicks.contains(b5.doc));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,400 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.ltr.response.transform;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.solr.client.solrj.SolrQuery;
|
||||
import org.apache.solr.ltr.TestRerankBase;
|
||||
import org.apache.solr.ltr.feature.SolrFeature;
|
||||
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||
import org.apache.solr.ltr.model.LinearModel;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestFeatureLoggerTransformer extends TestRerankBase {
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
setuptest(false);
|
||||
|
||||
assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity",
|
||||
"1"));
|
||||
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
|
||||
"w2 2asd asdd didid", "popularity", "2"));
|
||||
assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity",
|
||||
"3"));
|
||||
assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity",
|
||||
"6"));
|
||||
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
|
||||
"5"));
|
||||
assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2",
|
||||
"popularity", "6"));
|
||||
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
|
||||
"w6 w2 w3 w4 w5 w8", "popularity", "88888"));
|
||||
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
|
||||
"w1 w1 w1 w2 w2 w5", "popularity", "88888"));
|
||||
assertU(commit());
|
||||
|
||||
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
aftertest();
|
||||
}
|
||||
|
||||
protected void loadFeaturesAndModels() throws Exception {
|
||||
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}88888\"]}");
|
||||
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||
loadFeature("featureB1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||
loadFeature("featureB2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||
loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2",
|
||||
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||
loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2",
|
||||
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||
loadFeature("featureC3", SolrFeature.class.getName(),"featureStore2",
|
||||
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||
|
||||
loadModel("modelA", LinearModel.class.getName(),
|
||||
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||
|
||||
loadModel("modelB", LinearModel.class.getName(),
|
||||
new String[]{"featureB1", "featureB2", "featureAB"},
|
||||
"{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
|
||||
|
||||
loadModel("modelC", LinearModel.class.getName(),
|
||||
new String[]{"featureC1", "featureC2"},"featureStore2",
|
||||
"{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_featureTransformer_shouldWorkInSparseFormat() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelBRerankedList = [7,1,3,8,4]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0", "featureB2\\=1.0", "featureB2\\=1.0", "featureA1\\=1.0\\,featureB2\\=1.0", "featureB1\\=1.0"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[11];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_featureTransformer_shouldWorkInDenseFormat() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv format=dense]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelBRerankedList = [7,1,3,8,4]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||
"featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||
"featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||
"featureA1\\=1.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0",
|
||||
"featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=1.0\\,featureB2\\=0.0"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[11];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_explicitNewFeatureStore_shouldExtractAllFeaturesFromNewStore() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=dense]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelBRerankedList = [7,1,3,8,4]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC1\\=1.0\\,featureC2\\=0.0\\,featureC3\\=0.0"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_withOriginalRankingAndExplicitFeatureStore_shouldReturnNewCalculatedFeatureVector() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
|
||||
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
_OriginalRanking_ = [1,3,4,7,8]
|
||||
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC1\\=1.0"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
if (expectedFeatureVectors[(i - 1)] != null) {
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
|
||||
int[] nullFeatureVectorIndexes = new int[]{1, 2, 4};
|
||||
for (int index : nullFeatureVectorIndexes) {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));
|
||||
String[] nullFeatureVectorTests = new String[1];
|
||||
try {
|
||||
nullFeatureVectorTests[0] = "/response/docs/[" + index + "]/features==";
|
||||
assertJQ("/query" + query.toQueryString(), nullFeatureVectorTests);
|
||||
} catch (Exception e) {
|
||||
assertEquals("Path not found: /response/docs/[" + index + "]/features", e.getMessage());
|
||||
continue;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_modelsFromDifferentFeatureStores_shouldLogFeaturesCorrectly() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||
Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||
Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0)
|
||||
Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelCRerankedList = [1,3,8,4,7]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureA1\\=1.0\\,featureB2\\=1.0",
|
||||
"featureC1\\=1.0"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_featureLoggerFromNewFeatureStoreWithDifferentEfi_shouldReturnNewCalculatedFeatureVector() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w8' format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelBRerankedList = [7,1,3,8,4]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
""};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleaving_explicitFeatureStoreReusableFromModel_shouldLogFeaturesCorrectly() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModels();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,features:[fv store=featureStore2 format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||
Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25)
|
||||
Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0)
|
||||
Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelCRerankedList = [1,3,8,4,7]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC2\\=1.0\\,featureC3\\=1.0",
|
||||
"featureC1\\=1.0"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,277 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.solr.ltr.response.transform;
|
||||
|
||||
import java.util.Random;
|
||||
|
||||
import org.apache.solr.client.solrj.SolrQuery;
|
||||
import org.apache.solr.ltr.TestRerankBase;
|
||||
import org.apache.solr.ltr.feature.SolrFeature;
|
||||
import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving;
|
||||
import org.apache.solr.ltr.model.LinearModel;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
public class TestInterleavingTransformer extends TestRerankBase {
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
setuptest(false);
|
||||
|
||||
assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity",
|
||||
"1"));
|
||||
assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description",
|
||||
"w2 2asd asdd didid", "popularity", "2"));
|
||||
assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity",
|
||||
"3"));
|
||||
assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity",
|
||||
"6"));
|
||||
assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
|
||||
"5"));
|
||||
assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2",
|
||||
"popularity", "6"));
|
||||
assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description",
|
||||
"w6 w2 w3 w4 w5 w8", "popularity", "88888"));
|
||||
assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description",
|
||||
"w1 w1 w1 w2 w2 w5", "popularity", "88888"));
|
||||
assertU(commit());
|
||||
|
||||
|
||||
}
|
||||
|
||||
@After
|
||||
public void after() throws Exception {
|
||||
aftertest();
|
||||
}
|
||||
|
||||
protected void loadFeaturesAndModelsForInterleaving() throws Exception {
|
||||
loadFeature("featureA1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}88888\"]}");
|
||||
loadFeature("featureA2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||
loadFeature("featureAB", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=title}${user_query}\"]}");
|
||||
loadFeature("featureB1", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||
loadFeature("featureB2", SolrFeature.class.getName(),
|
||||
"{\"fq\":[\"{!terms f=description}${user_query}\"]}");
|
||||
loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2",
|
||||
"{\"fq\":[\"{!terms f=popularity}6\"]}");
|
||||
loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2",
|
||||
"{\"fq\":[\"{!terms f=popularity}1\"]}");
|
||||
|
||||
loadModel("modelA", LinearModel.class.getName(),
|
||||
new String[]{"featureA1", "featureA2", "featureAB"},
|
||||
"{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}");
|
||||
|
||||
loadModel("modelB", LinearModel.class.getName(),
|
||||
new String[]{"featureB1", "featureB2", "featureAB"},
|
||||
"{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}");
|
||||
|
||||
loadModel("modelC", LinearModel.class.getName(),
|
||||
new String[]{"featureC1", "featureC2"},"featureStore2",
|
||||
"{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingTransformer_shouldReturnInterleavingPickInTheResults() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModelsForInterleaving();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,interleavingPick:[interleaving]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelBRerankedList = [7,1,3,8,4]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[11];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingTransformerWithOriginalRanking_shouldReturnInterleavingPickInTheResults() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModelsForInterleaving();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,interleavingPick:[interleaving]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
|
||||
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
OriginalRanking = [1,3,4,7,8]
|
||||
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[11];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void interleavingTransformer_shouldBeCompatibleWithFeatureTransformer() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModelsForInterleaving();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,interleavingPick:[interleaving],features:[fv format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4)
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
ModelBRerankedList = [7,1,3,8,4]
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
|
||||
"featureB2\\=1.0",
|
||||
"featureB2\\=1.0",
|
||||
"featureA1\\=1.0\\,featureB2\\=1.0",
|
||||
"featureB1\\=1.0"};
|
||||
String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void interleavingTransformerWithOriginalRanking_shouldBeCompatibleWithFeatureTransformer() throws Exception {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
loadFeaturesAndModelsForInterleaving();
|
||||
|
||||
final SolrQuery query = new SolrQuery();
|
||||
query.setQuery("*:*");
|
||||
query.add("fl", "*, score,interleavingPick:[interleaving],features:[fv format=sparse]");
|
||||
query.add("rows", "10");
|
||||
query.add("debugQuery", "true");
|
||||
query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
|
||||
query.add("rq",
|
||||
"{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}");
|
||||
|
||||
/*
|
||||
Doc1 = "featureB2=1.0", ScoreA(0)
|
||||
Doc3 = "featureB2=1.0", ScoreA(0)
|
||||
Doc4 = "featureB1=1.0", ScoreA(0)
|
||||
Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39)
|
||||
Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3)
|
||||
|
||||
ModelARerankedList = [7,8,1,3,4]
|
||||
_OriginalRanking_ = [1,3,4,7,8]
|
||||
|
||||
|
||||
Random Boolean Choices Generation from Seed: [0,1,1]
|
||||
*/
|
||||
String[] expectedFeatureVectors = new String[]{
|
||||
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0",
|
||||
null,
|
||||
null,
|
||||
"featureA1\\=1.0\\,featureB2\\=1.0",
|
||||
null};
|
||||
String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"};
|
||||
int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4};
|
||||
|
||||
String[] tests = new String[16];
|
||||
tests[0] = "/response/numFound/==5";
|
||||
for (int i = 1; i <= 5; i++) {
|
||||
tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
|
||||
tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)];
|
||||
if (expectedFeatureVectors[(i - 1)] != null) {
|
||||
tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
|
||||
}
|
||||
}
|
||||
assertJQ("/query" + query.toQueryString(), tests);
|
||||
|
||||
int[] nullFeatureVectorIndexes = new int[]{1, 2, 4};
|
||||
for (int index : nullFeatureVectorIndexes) {
|
||||
TeamDraftInterleaving.setRANDOM(new Random(10101010));
|
||||
String[] nullFeatureVectorTests = new String[1];
|
||||
try {
|
||||
nullFeatureVectorTests[0] = "/response/docs/[" + index + "]/features==";
|
||||
assertJQ("/query" + query.toQueryString(), nullFeatureVectorTests);
|
||||
} catch (Exception e) {
|
||||
assertEquals("Path not found: /response/docs/[" + index + "]/features", e.getMessage());
|
||||
continue;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -38,6 +38,13 @@ A ranking model computes the scores used to rerank documents. Irrespective of an
|
|||
* features that represent the document being scored
|
||||
* features that represent the query for which the document is being scored
|
||||
|
||||
==== Interleaving
|
||||
|
||||
Interleaving is an approach to Online Search Quality evaluation that makes it possible to compare two models by interleaving their results in the final ranked list returned to the user.
|
||||
|
||||
* currently only the Team Draft Interleaving algorithm is supported (and its implementation assumes all results are from the same shard)
|
||||
|
||||
|
||||
==== Feature
|
||||
|
||||
A feature is a value, a number, that represents some quantity or quality of the document being scored or of the query for which documents are being scored. For example documents often have a 'recency' quality and 'number of past purchases' might be a quantity that is passed to Solr as part of the search query.
|
||||
|
@ -247,6 +254,81 @@ The output XML will include feature values as a comma-separated list, resembling
|
|||
}}
|
||||
----
|
||||
|
||||
=== Running a Rerank Query Interleaving Two Models
|
||||
|
||||
To rerank the results of a query by interleaving two models (myModelA, myModelB), add the `rq` parameter to your search, passing both models as input, for example:
|
||||
|
||||
[source,text]
|
||||
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score
|
||||
|
||||
To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example:
|
||||
|
||||
[source,text]
|
||||
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score,[interleaving]
|
||||
|
||||
The output will include the model picked for each search result, resembling the output shown here:
|
||||
|
||||
[source,json]
|
||||
----
|
||||
{
|
||||
"responseHeader":{
|
||||
"status":0,
|
||||
"QTime":0,
|
||||
"params":{
|
||||
"q":"test",
|
||||
"fl":"id,score,[interleaving]",
|
||||
"rq":"{!ltr model=myModelA model=myModelB reRankDocs=100}"}},
|
||||
"response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[
|
||||
{
|
||||
"id":"GB18030TEST",
|
||||
"score":1.0005897,
|
||||
"[interleaving]":"myModelB"},
|
||||
{
|
||||
"id":"UTF8TEST",
|
||||
"score":0.79656565,
|
||||
"[interleaving]":"myModelA"}]
|
||||
}}
|
||||
----
|
||||
|
||||
=== Running a Rerank Query Interleaving a Model with the Original Ranking
|
||||
When approaching Search Quality Evaluation with interleaving it may be useful to compare a model with the original ranking.
|
||||
To rerank the results of a query, interleaving a model with the original ranking, add the `rq` parameter to your search, passing the special inbuilt `_OriginalRanking_` model identifier as one model and your comparison model as the other model, for example:
|
||||
|
||||
|
||||
[source,text]
|
||||
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score
|
||||
|
||||
The addition of the `rq` parameter will not change the output XML of the search.
|
||||
|
||||
To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example:
|
||||
|
||||
[source,text]
|
||||
http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score,[interleaving]
|
||||
|
||||
The output will include the model picked for each search result, resembling the output shown here:
|
||||
|
||||
[source,json]
|
||||
----
|
||||
{
|
||||
"responseHeader":{
|
||||
"status":0,
|
||||
"QTime":0,
|
||||
"params":{
|
||||
"q":"test",
|
||||
"fl":"id,score,[interleaving]",
|
||||
"rq":"{!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}"}},
|
||||
"response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[
|
||||
{
|
||||
"id":"GB18030TEST",
|
||||
"score":1.0005897,
|
||||
"[interleaving]":"_OriginalRanking_"},
|
||||
{
|
||||
"id":"UTF8TEST",
|
||||
"score":0.79656565,
|
||||
"[interleaving]":"myModel"}]
|
||||
}}
|
||||
----
|
||||
|
||||
=== External Feature Information
|
||||
|
||||
The {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/ValueFeature.html[ValueFeature] and {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/SolrFeature.html[SolrFeature] classes support the use of external feature information, `efi` for short.
|
||||
|
@ -418,6 +500,13 @@ Learning-To-Rank is a contrib module and therefore its plugins must be configure
|
|||
</transformer>
|
||||
----
|
||||
|
||||
* Declaration of the `[interleaving]` transformer.
|
||||
+
|
||||
[source,xml]
|
||||
----
|
||||
<transformer name="interleaving" class="org.apache.solr.ltr.response.transform.LTRInterleavingTransformerFactory"/>
|
||||
----
|
||||
|
||||
=== Advanced Options
|
||||
|
||||
==== LTRThreadModule
|
||||
|
@ -446,11 +535,12 @@ How does Solr Learning-To-Rank work under the hood?::
|
|||
Please refer to the `ltr` {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/package-summary.html[javadocs] for an implementation overview.
|
||||
|
||||
How could I write additional models and/or features?::
|
||||
Contributions for further models, features and normalizers are welcome. Related links:
|
||||
Contributions for further models, features, normalizers and interleaving algorithms are welcome. Related links:
|
||||
+
|
||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel javadocs]
|
||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/Feature.html[Feature javadocs]
|
||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/norm/Normalizer.html[Normalizer javadocs]
|
||||
* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/interleaving/Interleaving.html[Interleaving javadocs]
|
||||
* https://cwiki.apache.org/confluence/display/solr/HowToContribute
|
||||
* https://cwiki.apache.org/confluence/display/LUCENE/HowToContribute
|
||||
|
||||
|
@ -779,3 +869,7 @@ The feature store and the model store are both <<managed-resources.adoc#managed-
|
|||
* "Learning to Rank in Solr" presentation at Lucene/Solr Revolution 2015 in Austin:
|
||||
** Slides: http://www.slideshare.net/lucidworks/learning-to-rank-in-solr-presented-by-michael-nilsson-diego-ceccarelli-bloomberg-lp
|
||||
** Video: https://www.youtube.com/watch?v=M7BKwJoh96s
|
||||
|
||||
* The importance of Online Testing in Learning To Rank:
|
||||
** Blog: https://sease.io/2020/04/the-importance-of-online-testing-in-learning-to-rank-part-1.html
|
||||
** Blog: https://sease.io/2020/05/online-testing-for-learning-to-rank-interleaving.html
|
||||
|
|
Loading…
Reference in New Issue