diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 53aff7c0a48..7b470931c6f 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -167,6 +167,8 @@ New Features * SOLR-14907: Add v2 API for configSet upload, including single-file insertion. (Houston Putman) +* SOLR-14560: Add interleaving support in Learning To Rank. (Alessandro Benedetti, Christine Poerschke) + Improvements --------------------- * SOLR-14942: Reduce leader election time on node shutdown by removing election nodes before closing cores. diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java index 5412715518e..c388c269e31 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java @@ -31,6 +31,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; +import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; import org.apache.solr.search.SolrIndexSearcher; @@ -42,12 +43,17 @@ import org.apache.solr.search.SolrIndexSearcher; * */ public class LTRRescorer extends Rescorer { - LTRScoringQuery scoringQuery; + final private LTRScoringQuery scoringQuery; + + public LTRRescorer() { + this.scoringQuery = null; + } + public LTRRescorer(LTRScoringQuery scoringQuery) { this.scoringQuery = scoringQuery; } - private void heapAdjust(ScoreDoc[] hits, int size, int root) { + protected static void heapAdjust(ScoreDoc[] hits, int size, int root) { final ScoreDoc doc = hits[root]; final float score = doc.score; int i = root; @@ -82,7 +88,7 @@ public class LTRRescorer extends Rescorer { } } - private void heapify(ScoreDoc[] hits, int size) { + protected static void heapify(ScoreDoc[] hits, int size) { for (int i = (size >> 1) - 1; i >= 0; i--) { heapAdjust(hits, size, i); } @@ -104,23 +110,27 @@ public class LTRRescorer 
extends Rescorer { if ((topN == 0) || (firstPassTopDocs.scoreDocs.length == 0)) { return firstPassTopDocs; } - final ScoreDoc[] hits = firstPassTopDocs.scoreDocs; - Arrays.sort(hits, new Comparator() { - @Override - public int compare(ScoreDoc a, ScoreDoc b) { - return a.doc - b.doc; - } - }); - - assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO; + final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs); topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value)); + + final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults); + + return new TopDocs(firstPassTopDocs.totalHits, reranked); + } + + private ScoreDoc[] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException { final ScoreDoc[] reranked = new ScoreDoc[topN]; final List leaves = searcher.getIndexReader().leaves(); final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) searcher .createWeight(searcher.rewrite(scoringQuery), ScoreMode.COMPLETE, 1); - scoreFeatures(searcher, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked); + scoreFeatures(searcher,topN, modelWeight, firstPassResults, leaves, reranked); // Must sort all documents that we reranked, and then select the top + sortByScore(reranked); + return reranked; + } + + protected static void sortByScore(ScoreDoc[] reranked) { Arrays.sort(reranked, new Comparator() { @Override public int compare(ScoreDoc a, ScoreDoc b) { @@ -136,13 +146,24 @@ public class LTRRescorer extends Rescorer { } } }); - - return new TopDocs(firstPassTopDocs.totalHits, reranked); } - public void scoreFeatures(IndexSearcher indexSearcher, TopDocs firstPassTopDocs, - int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List leaves, - ScoreDoc[] reranked) throws IOException { + protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) { + final ScoreDoc[] hits = firstPassTopDocs.scoreDocs; + Arrays.sort(hits, new 
Comparator() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + return a.doc - b.doc; + } + }); + + assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO; + return hits; + } + + public void scoreFeatures(IndexSearcher indexSearcher, + int topN, LTRScoringQuery.ModelWeight modelWeight, ScoreDoc[] hits, List leaves, + ScoreDoc[] reranked) throws IOException { int readerUpto = -1; int endDoc = 0; @@ -150,7 +171,6 @@ public class LTRRescorer extends Rescorer { LTRScoringQuery.ModelWeight.ModelScorer scorer = null; int hitUpto = 0; - final FeatureLogger featureLogger = scoringQuery.getFeatureLogger(); while (hitUpto < hits.length) { final ScoreDoc hit = hits[hitUpto]; @@ -166,64 +186,77 @@ public class LTRRescorer extends Rescorer { docBase = readerContext.docBase; scorer = modelWeight.scorer(readerContext); } - // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to - // call score - // even if no feature scorers match, since a model might use that info to - // return a - // non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer - // past the target - // doc since the model algorithm still needs to compute a potentially - // non-zero score from blank features. 
- assert (scorer != null); - final int targetDoc = docID - docBase; - scorer.docID(); - scorer.iterator().advance(targetDoc); + scoreSingleHit(indexSearcher, topN, modelWeight, docBase, hitUpto, hit, docID, scoringQuery, scorer, reranked); + hitUpto++; + } + } - scorer.getDocInfo().setOriginalDocScore(hit.score); - hit.score = scorer.score(); - if (hitUpto < topN) { - reranked[hitUpto] = hit; - // if the heap is not full, maybe I want to log the features for this - // document + protected static void scoreSingleHit(IndexSearcher indexSearcher, int topN, LTRScoringQuery.ModelWeight modelWeight, int docBase, int hitUpto, ScoreDoc hit, int docID, LTRScoringQuery rerankingQuery, LTRScoringQuery.ModelWeight.ModelScorer scorer, ScoreDoc[] reranked) throws IOException { + final FeatureLogger featureLogger = rerankingQuery.getFeatureLogger(); + // Scorer for a LTRScoringQuery.ModelWeight should never be null since we always have to + // call score + // even if no feature scorers match, since a model might use that info to + // return a + // non-zero score. Same applies for the case of advancing a LTRScoringQuery.ModelWeight.ModelScorer + // past the target + // doc since the model algorithm still needs to compute a potentially + // non-zero score from blank features. 
+ assert (scorer != null); + final int targetDoc = docID - docBase; + scorer.docID(); + scorer.iterator().advance(targetDoc); + + scorer.getDocInfo().setOriginalDocScore(hit.score); + hit.score = scorer.score(); + if (hitUpto < topN) { + reranked[hitUpto] = hit; + // if the heap is not full, maybe I want to log the features for this + // document + if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { + featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher, + modelWeight.getFeaturesInfo()); + } + } else if (hitUpto == topN) { + // collected topN document, I create the heap + heapify(reranked, topN); + } + if (hitUpto >= topN) { + // once that heap is ready, if the score of this document is lower that + // the minimum + // i don't want to log the feature. Otherwise I replace it with the + // minimum and fix the + // heap. + if (hit.score > reranked[0].score) { + reranked[0] = hit; + heapAdjust(reranked, topN, 0); if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { - featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher, + featureLogger.log(hit.doc, rerankingQuery, (SolrIndexSearcher) indexSearcher, modelWeight.getFeaturesInfo()); } - } else if (hitUpto == topN) { - // collected topN document, I create the heap - heapify(reranked, topN); } - if (hitUpto >= topN) { - // once that heap is ready, if the score of this document is lower that - // the minimum - // i don't want to log the feature. Otherwise I replace it with the - // minimum and fix the - // heap. 
- if (hit.score > reranked[0].score) { - reranked[0] = hit; - heapAdjust(reranked, topN, 0); - if (featureLogger != null && indexSearcher instanceof SolrIndexSearcher) { - featureLogger.log(hit.doc, scoringQuery, (SolrIndexSearcher)indexSearcher, - modelWeight.getFeaturesInfo()); - } - } - } - hitUpto++; } } @Override public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException { + return getExplanation(searcher, docID, scoringQuery); + } + protected static Explanation getExplanation(IndexSearcher searcher, int docID, LTRScoringQuery rerankingQuery) throws IOException { final List leafContexts = searcher.getTopReaderContext() .leaves(); final int n = ReaderUtil.subIndex(docID, leafContexts); final LeafReaderContext context = leafContexts.get(n); final int deBasedDoc = docID - context.docBase; - final Weight modelWeight = searcher.createWeight(searcher.rewrite(scoringQuery), - ScoreMode.COMPLETE, 1); - return modelWeight.explain(context, deBasedDoc); + final Weight rankingWeight; + if (rerankingQuery instanceof OriginalRankingLTRScoringQuery) { + rankingWeight = rerankingQuery.getOriginalQuery().createWeight(searcher, ScoreMode.COMPLETE, 1); + } else { + rankingWeight = searcher.createWeight(searcher.rewrite(rerankingQuery), + ScoreMode.COMPLETE, 1); + } + return rankingWeight.explain(context, deBasedDoc); } public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(LTRScoringQuery.ModelWeight modelWeight, diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 20b1fef59c6..dcffbd6b26b 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -102,6 +102,10 @@ public class LTRScoringQuery extends Query implements Accountable { return ltrScoringModel; } + public String getScoringModelName() { + return 
ltrScoringModel.getName(); + } + public void setFeatureLogger(FeatureLogger fl) { this.fl = fl; } diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java index 2cff28f6b78..295597b1bc4 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java @@ -26,8 +26,8 @@ public class SolrQueryRequestContextUtils { /** key of the feature logger in the request context **/ private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger"; - /** key of the scoring query in the request context **/ - private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query"; + /** key of the scoring queries in the request context **/ + private static final String SCORING_QUERIES = LTR_PREFIX + "scoring_queries"; /** key of the isExtractingFeatures flag in the request context **/ private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures"; @@ -47,12 +47,12 @@ public class SolrQueryRequestContextUtils { /** scoring query accessors **/ - public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) { - req.getContext().put(SCORING_QUERY, scoringQuery); + public static void setScoringQueries(SolrQueryRequest req, LTRScoringQuery[] scoringQueries) { + req.getContext().put(SCORING_QUERIES, scoringQueries); } - public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) { - return (LTRScoringQuery) req.getContext().get(SCORING_QUERY); + public static LTRScoringQuery[] getScoringQueries(SolrQueryRequest req) { + return (LTRScoringQuery[]) req.getContext().get(SCORING_QUERIES); } /** isExtractingFeatures flag accessors **/ diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/Interleaving.java 
b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/Interleaving.java new file mode 100644 index 00000000000..1038acad77f --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/Interleaving.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.ltr.interleaving; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving; + +/** + * Interleaving considers two ranking models: modelA and modelB. + * For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...). + * An Interleaving algorithm creates a unique ranked list I = (i1, i2, ...). + * This list is created by interleaving elements from the two lists La and Lb as described by the implementation algorithm. + * Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb. 
+ */ +public interface Interleaving { + + String TEAM_DRAFT = "TeamDraft"; + + InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB); + + static Interleaving getImplementation(String algorithm) { + switch(algorithm) { + case TEAM_DRAFT: + default: + return new TeamDraftInterleaving(); + } + } +} \ No newline at end of file diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/InterleavingResult.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/InterleavingResult.java new file mode 100644 index 00000000000..aeaac2997a6 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/InterleavingResult.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.ltr.interleaving; + +import java.util.ArrayList; +import java.util.Set; + +import org.apache.lucene.search.ScoreDoc; + +public class InterleavingResult { + final private ScoreDoc[] interleavedResults; + final private ArrayList> interleavingPicks; + + public InterleavingResult(ScoreDoc[] interleavedResults, ArrayList> interleavingPicks) { + this.interleavedResults = interleavedResults; + this.interleavingPicks = interleavingPicks; + } + + public ScoreDoc[] getInterleavedResults() { + return interleavedResults; + } + + public ArrayList> getInterleavingPicks() { + return interleavingPicks; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingQuery.java new file mode 100644 index 00000000000..7c00195ceb6 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingQuery.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.ltr.interleaving; + +import java.io.IOException; +import java.util.Arrays; + +import org.apache.lucene.search.Query; +import org.apache.solr.ltr.search.LTRQuery; +import org.apache.solr.search.RankQuery; + +/** + * A learning to rank Query with Interleaving: it encapsulates the models to interleave and delegates the rescoring + * of the documents to a {@link LTRInterleavingRescorer}. + **/ +public class LTRInterleavingQuery extends LTRQuery { + private final LTRInterleavingScoringQuery[] rerankingQueries; + private final Interleaving interlavingAlgorithm; + + public LTRInterleavingQuery(Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries, int rerankDocs) { + super(null, rerankDocs, new LTRInterleavingRescorer(interleavingAlgorithm, rerankingQueries)); + this.rerankingQueries = rerankingQueries; + this.interlavingAlgorithm = interleavingAlgorithm; + } + + @Override + public int hashCode() { + return 31 * classHash() + (mainQuery.hashCode() + rerankingQueries.hashCode() + reRankDocs); /* NOTE(review): hashCode()/equals() on an array are identity-based, so two queries only match when they share the same rerankingQueries array instance; confirm this is intended rather than Arrays.hashCode/Arrays.equals */ + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) && equalsTo(getClass().cast(o)); + } + + private boolean equalsTo(LTRInterleavingQuery other) { + return (mainQuery.equals(other.mainQuery) + && rerankingQueries.equals(other.rerankingQueries) && (reRankDocs == other.reRankDocs)); + } + + @Override + public RankQuery wrap(Query _mainQuery) { + super.wrap(_mainQuery); + for(LTRInterleavingScoringQuery rerankingQuery: rerankingQueries){ + rerankingQuery.setOriginalQuery(_mainQuery); + } + return this; + } + + @Override + public String toString(String field) { + return "{!ltr mainQuery='" + mainQuery.toString() + "' rerankingQueries='" + + Arrays.toString(rerankingQueries) + "' reRankDocs=" + reRankDocs + "}"; + } + + @Override + protected Query rewrite(Query rewrittenMainQuery) throws IOException { + return new LTRInterleavingQuery(interlavingAlgorithm, rerankingQueries, reRankDocs).wrap(rewrittenMainQuery); + } +} diff --git 
a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java new file mode 100644 index 00000000000..f71317cf946 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingRescorer.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.interleaving; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.solr.ltr.LTRRescorer; +import org.apache.solr.ltr.LTRScoringQuery; + +/** + * Implements the rescoring logic. The top documents returned by solr with their + * original scores, will be processed by a {@link LTRScoringQuery} that will assign a + * new score to each document. The top documents will be resorted based on the + * new score. 
+ * */ +public class LTRInterleavingRescorer extends LTRRescorer { + + final private LTRInterleavingScoringQuery[] rerankingQueries; + private Integer originalRankingIndex = null; + final private Interleaving interleavingAlgorithm; + + public LTRInterleavingRescorer( Interleaving interleavingAlgorithm, LTRInterleavingScoringQuery[] rerankingQueries) { + this.rerankingQueries = rerankingQueries; + this.interleavingAlgorithm = interleavingAlgorithm; + for(int i=0;i> interleavingPicks = interleaved.getInterleavingPicks(); + rerankingQueries[0].setPickedInterleavingDocIds(interleavingPicks.get(0)); + rerankingQueries[1].setPickedInterleavingDocIds(interleavingPicks.get(1)); + + return new TopDocs(firstPassTopDocs.totalHits, interleavedResults); + } + + private ScoreDoc[][] rerank(IndexSearcher searcher, int topN, ScoreDoc[] firstPassResults) throws IOException { + ScoreDoc[][] reRankedPerModel = new ScoreDoc[rerankingQueries.length][topN]; + final List leaves = searcher.getIndexReader().leaves(); + LTRScoringQuery.ModelWeight[] modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length]; + for (int i = 0; i < rerankingQueries.length; i++) { + if (originalRankingIndex == null || originalRankingIndex != i) { + modelWeights[i] = (LTRScoringQuery.ModelWeight) searcher + .createWeight(searcher.rewrite(rerankingQueries[i]), ScoreMode.COMPLETE, 1); + } + } + scoreFeatures(searcher, topN, modelWeights, firstPassResults, leaves, reRankedPerModel); + + for (int i = 0; i < rerankingQueries.length; i++) { + if (originalRankingIndex == null || originalRankingIndex != i) { + sortByScore(reRankedPerModel[i]); + } + } + + return reRankedPerModel; + } + + public void scoreFeatures(IndexSearcher indexSearcher, + int topN, LTRScoringQuery.ModelWeight[] modelWeights, ScoreDoc[] hits, List leaves, + ScoreDoc[][] rerankedPerModel) throws IOException { + + int readerUpto = -1; + int endDoc = 0; + int docBase = 0; + int hitUpto = 0; + LTRScoringQuery.ModelWeight.ModelScorer[] 
scorers = new LTRScoringQuery.ModelWeight.ModelScorer[rerankingQueries.length]; + while (hitUpto < hits.length) { + final ScoreDoc hit = hits[hitUpto]; + final int docID = hit.doc; + LeafReaderContext readerContext = null; + while (docID >= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + + // We advanced to another segment + if (readerContext != null) { + docBase = readerContext.docBase; + for (int i = 0; i < modelWeights.length; i++) { + if (modelWeights[i] != null) { + scorers[i] = modelWeights[i].scorer(readerContext); + } + } + } + for (int i = 0; i < rerankingQueries.length; i++) { + if (modelWeights[i] != null) { + scoreSingleHit(indexSearcher, topN, modelWeights[i], docBase, hitUpto, new ScoreDoc(hit.doc, hit.score, hit.shardIndex), docID, rerankingQueries[i], scorers[i], rerankedPerModel[i]); + } + } + hitUpto++; + } + + } + + @Override + public Explanation explain(IndexSearcher searcher, + Explanation firstPassExplanation, int docID) throws IOException { + LTRScoringQuery pickedRerankModel = rerankingQueries[0]; + if (rerankingQueries[1].getPickedInterleavingDocIds().contains(docID)) { + pickedRerankModel = rerankingQueries[1]; + } + return getExplanation(searcher, docID, pickedRerankModel); + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingScoringQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingScoringQuery.java new file mode 100644 index 00000000000..cb2da2c45d1 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/LTRInterleavingScoringQuery.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.interleaving; + +import java.util.Map; +import java.util.Set; + +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.LTRThreadModule; +import org.apache.solr.ltr.model.LTRScoringModel; + +public class LTRInterleavingScoringQuery extends LTRScoringQuery { + + // Ids of the documents for which this model was picked by the interleaving algorithm + private Set pickedInterleavingDocIds; + + public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel) { + super(ltrScoringModel); + } + + public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) { + super(ltrScoringModel, extractAllFeatures); + } + + public LTRInterleavingScoringQuery(LTRScoringModel ltrScoringModel, + Map externalFeatureInfo, + boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) { + super(ltrScoringModel, externalFeatureInfo, extractAllFeatures, ltrThreadMgr); + } + + public Set getPickedInterleavingDocIds() { + return pickedInterleavingDocIds; + } + + public void setPickedInterleavingDocIds(Set pickedInterleavingDocIds) { + this.pickedInterleavingDocIds = pickedInterleavingDocIds; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/OriginalRankingLTRScoringQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/OriginalRankingLTRScoringQuery.java new file mode 100644 index 00000000000..9c4db056e8b --- /dev/null +++ 
b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/OriginalRankingLTRScoringQuery.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.interleaving; + +public final class OriginalRankingLTRScoringQuery extends LTRInterleavingScoringQuery { + + private final String originalRankingModelName; + + public OriginalRankingLTRScoringQuery(String originalRankingModelName) { + super(null /* LTRScoringModel */); + this.originalRankingModelName = originalRankingModelName; + } + + @Override + public String getScoringModelName() { + return this.originalRankingModelName; + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/algorithms/TeamDraftInterleaving.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/algorithms/TeamDraftInterleaving.java new file mode 100644 index 00000000000..a9cd38ad8db --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/algorithms/TeamDraftInterleaving.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.ltr.interleaving.algorithms; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.Random; +import java.util.Set; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.solr.ltr.interleaving.Interleaving; +import org.apache.solr.ltr.interleaving.InterleavingResult; + +/** + * Interleaving was introduced the first time by Joachims in [1, 2]. + * Team Draft Interleaving is among the most successful and used interleaving approaches[3]. + * Team Draft Interleaving implements a method similar to the way in which captains select their players in team-matches. + * Team Draft Interleaving produces a fair distribution of ranking models’ elements in the final interleaved list. + * "Team draft interleaving" has also proved to overcome an issue of the "Balanced interleaving" approach, in determining the winning model[4]. + *

+ * [1] T. Joachims. Optimizing search engines using clickthrough data. KDD (2002) + * [2] T. Joachims. Evaluating retrieval performance using clickthrough data. In J. Franke, G. Nakhaeizadeh, and I. Renz, editors, + * Text Mining, pages 79–96. Physica/Springer (2003) + * [3] F. Radlinski, M. Kurup, and T. Joachims. How does clickthrough data reflect + * retrieval quality? In CIKM, pages 43–52. ACM Press (2008) + * [4] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue. + * Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1–41, Feb. (2012) + */ +public class TeamDraftInterleaving implements Interleaving { + public static Random RANDOM; + + static { + // We try to make things reproducible in the context of our tests by initializing the random instance + // based on the current seed + String seed = System.getProperty("tests.seed"); + if (seed == null) { + RANDOM = new Random(); + } else { + RANDOM = new Random(seed.hashCode()); + } + } + + /** + * Team Draft Interleaving considers two ranking models: modelA and modelB. + * For a given query, each model returns its ranked list of documents La = (a1,a2,...) and Lb = (b1, b2, ...). + * The algorithm creates a unique ranked list I = (i1, i2, ...). + * This list is created by interleaving elements from the two lists La and Lb as described by Chapelle et al.[1]. + * Each element Ij is labelled TeamA if it is selected from La and TeamB if it is selected from Lb. + *

+ * [1] O. Chapelle, T. Joachims, F. Radlinski, and Y. Yue. + * Large-scale validation and analysis of interleaved search evaluation. ACM TOIS, 30(1):1–41, Feb. (2012) + *

+ * Assumptions: + * - rerankedA and rerankedB has the same length. + * They contains the same search results, ranked differently by two ranking models + * - each reranked list can not contain the same search result more than once. + * - results are all from the same shard + * + * @param rerankedA a ranked list of search results produced by a ranking model A + * @param rerankedB a ranked list of search results produced by a ranking model B + * @return the interleaved ranking list + */ + public InterleavingResult interleave(ScoreDoc[] rerankedA, ScoreDoc[] rerankedB) { + LinkedList interleavedResults = new LinkedList<>(); + HashSet alreadyAdded = new HashSet<>(); + ScoreDoc[] interleavedResultArray = new ScoreDoc[rerankedA.length]; + ArrayList> interleavingPicks = new ArrayList<>(2); + Set teamA = new HashSet<>(); + Set teamB = new HashSet<>(); + int topN = rerankedA.length; + int indexA = 0, indexB = 0; + + while (interleavedResults.size() < topN && indexA < rerankedA.length && indexB < rerankedB.length) { + if(teamA.size() alreadyAdded, int index, ScoreDoc[] reranked) { + boolean foundElementToAdd = false; + while (index < reranked.length && !foundElementToAdd) { + ScoreDoc elementToCheck = reranked[index]; + if (alreadyAdded.contains(elementToCheck.doc)) { + index++; + } else { + foundElementToAdd = true; + } + } + return index; + } + + public static void setRANDOM(Random RANDOM) { + TeamDraftInterleaving.RANDOM = RANDOM; + } +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/algorithms/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/algorithms/package-info.java new file mode 100644 index 00000000000..d66eedf6841 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/algorithms/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains Various Interleaving Algorithms + */ +package org.apache.solr.ltr.interleaving.algorithms; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/package-info.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/package-info.java new file mode 100644 index 00000000000..82f59758717 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/interleaving/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * Contains Various Interleaving auxiliary classes + */ +package org.apache.solr.ltr.interleaving; diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java index 3e76b4e9797..3e301858c65 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRFeatureLoggerTransformerFactory.java @@ -36,6 +36,8 @@ import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.LTRThreadModule; import org.apache.solr.ltr.SolrQueryRequestContextUtils; import org.apache.solr.ltr.feature.Feature; +import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery; +import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.norm.Normalizer; import org.apache.solr.ltr.search.LTRQParserPlugin; @@ -126,14 +128,15 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { SolrQueryRequestContextUtils.setIsExtractingFeatures(req); // Communicate which feature store we are requesting features for - SolrQueryRequestContextUtils.setFvStoreName(req, localparams.get(FV_STORE, defaultStore)); + final String fvStoreName = localparams.get(FV_STORE); + SolrQueryRequestContextUtils.setFvStoreName(req, (fvStoreName == null ? 
defaultStore : fvStoreName)); // Create and supply the feature logger to be used SolrQueryRequestContextUtils.setFeatureLogger(req, createFeatureLogger( localparams.get(FV_FORMAT))); - return new FeatureTransformer(name, localparams, req); + return new FeatureTransformer(name, localparams, req, (fvStoreName != null) /* hasExplicitFeatureStore */); } /** @@ -163,11 +166,18 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { final private String name; final private SolrParams localparams; final private SolrQueryRequest req; + final private boolean hasExplicitFeatureStore; private List leafContexts; private SolrIndexSearcher searcher; - private LTRScoringQuery scoringQuery; - private LTRScoringQuery.ModelWeight modelWeight; + /** + * rerankingQueries, modelWeights have: + * length=1 - [Classic LTR] When reranking with a single model + * length=2 - [Interleaving] When reranking with interleaving (two ranking models are involved) + */ + private LTRScoringQuery[] rerankingQueriesFromContext; + private LTRScoringQuery[] rerankingQueries; + private LTRScoringQuery.ModelWeight[] modelWeights; private FeatureLogger featureLogger; private boolean docsWereNotReranked; @@ -177,10 +187,11 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { * feature vectors */ public FeatureTransformer(String name, SolrParams localparams, - SolrQueryRequest req) { + SolrQueryRequest req, boolean hasExplicitFeatureStore) { this.name = name; this.localparams = localparams; this.req = req; + this.hasExplicitFeatureStore = hasExplicitFeatureStore; } @Override @@ -208,55 +219,106 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { if (threadManager != null) { threadManager.setExecutor(context.getRequest().getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor()); } - - // Setup LTRScoringQuery - scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req); - docsWereNotReranked = (scoringQuery == 
null); - String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req); - if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) { - // if store is set in the transformer we should overwrite the logger - final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore()); + rerankingQueriesFromContext = SolrQueryRequestContextUtils.getScoringQueries(req); + docsWereNotReranked = (rerankingQueriesFromContext == null || rerankingQueriesFromContext.length == 0); + String transformerFeatureStore = SolrQueryRequestContextUtils.getFvStoreName(req); + Map transformerExternalFeatureInfo = LTRQParserPlugin.extractEFIParams(localparams); - final FeatureStore store = fr.getFeatureStore(featureStoreName); - featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name - - try { - final LoggingModel lm = new LoggingModel(loggingModelName, - featureStoreName, store.getFeatures()); - - scoringQuery = new LTRScoringQuery(lm, - LTRQParserPlugin.extractEFIParams(localparams), - true, - threadManager); // request feature weights to be created for all features - - }catch (final Exception e) { - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, - "retrieving the feature store "+featureStoreName, e); - } - } - - if (scoringQuery.getOriginalQuery() == null) { - scoringQuery.setOriginalQuery(context.getQuery()); - } - if (scoringQuery.getFeatureLogger() == null){ - scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); - } - scoringQuery.setRequest(req); - - featureLogger = scoringQuery.getFeatureLogger(); - - try { - modelWeight = scoringQuery.createWeight(searcher, ScoreMode.COMPLETE, 1f); - } catch (final IOException e) { - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e); - } - if (modelWeight == null) { - throw new 
SolrException(SolrException.ErrorCode.BAD_REQUEST, - "error logging the features, model weight is null"); - } + final LoggingModel loggingModel = createLoggingModel(transformerFeatureStore); + setupRerankingQueriesForLogging(transformerFeatureStore, transformerExternalFeatureInfo, loggingModel); + setupRerankingWeightsForLogging(context); } + /** + * The loggingModel is an empty model that is just used to extract the features + * and log them + * @param transformerFeatureStore the explicit transformer feature store + */ + private LoggingModel createLoggingModel(String transformerFeatureStore) { + final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore()); + + final FeatureStore store = fr.getFeatureStore(transformerFeatureStore); + transformerFeatureStore = store.getName(); // if transformerFeatureStore was null before this gets actual name + + return new LoggingModel(loggingModelName, + transformerFeatureStore, store.getFeatures()); + } + + /** + * When preparing the reranking queries for logging features various scenarios apply: + * + * No Reranking + * There is the need of a logger model from the default feature store or the explicit feature store passed + * to extract the feature vector + * + * Re Ranking + * 1) If no explicit feature store is passed, the models for each reranking query can be safely re-used + * the feature vector can be fetched from the feature vector cache. 
+ * 2) If an explicit feature store is passed, and no reranking query uses a model with that feature store, + * There is the need of a logger model to extract the feature vector + * 3) If an explicit feature store is passed, and there is a reranking query that uses a model with that feature store, + * the model can be re-used and there is no need for a logging model + * + * @param transformerFeatureStore explicit feature store for the transformer + * @param transformerExternalFeatureInfo explicit efi for the transformer + */ + private void setupRerankingQueriesForLogging(String transformerFeatureStore, Map transformerExternalFeatureInfo, LoggingModel loggingModel) { + if (docsWereNotReranked) { //no reranking query + LTRScoringQuery loggingQuery = new LTRScoringQuery(loggingModel, + transformerExternalFeatureInfo, + true /* extractAllFeatures */, + threadManager); + rerankingQueries = new LTRScoringQuery[]{loggingQuery}; + } else { + rerankingQueries = new LTRScoringQuery[rerankingQueriesFromContext.length]; + System.arraycopy(rerankingQueriesFromContext, 0, rerankingQueries, 0, rerankingQueriesFromContext.length); + + if (transformerFeatureStore != null) {// explicit feature store for the transformer + LTRScoringModel matchingRerankingModel = loggingModel; + for (LTRScoringQuery rerankingQuery : rerankingQueries) { + if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) && + transformerFeatureStore.equals(rerankingQuery.getScoringModel().getFeatureStoreName())) { + matchingRerankingModel = rerankingQuery.getScoringModel(); + } + } + + for (int i = 0; i < rerankingQueries.length; i++) { + rerankingQueries[i] = new LTRScoringQuery( + matchingRerankingModel, + (!transformerExternalFeatureInfo.isEmpty() ? 
transformerExternalFeatureInfo : rerankingQueries[i].getExternalFeatureInfo()), + true /* extractAllFeatures */, + threadManager); + } + } + } + } + + private void setupRerankingWeightsForLogging(ResultContext context) { + modelWeights = new LTRScoringQuery.ModelWeight[rerankingQueries.length]; + for (int i = 0; i < rerankingQueries.length; i++) { + if (rerankingQueries[i].getOriginalQuery() == null) { + rerankingQueries[i].setOriginalQuery(context.getQuery()); + } + rerankingQueries[i].setRequest(req); + if (!(rerankingQueries[i] instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { + if (rerankingQueries[i].getFeatureLogger() == null) { + rerankingQueries[i].setFeatureLogger(SolrQueryRequestContextUtils.getFeatureLogger(req)); + } + featureLogger = rerankingQueries[i].getFeatureLogger(); + try { + modelWeights[i] = rerankingQueries[i].createWeight(searcher, ScoreMode.COMPLETE, 1f); + } catch (final IOException e) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e); + } + if (modelWeights[i] == null) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "error logging the features, model weight is null"); + } + } + } + } + @Override public void transform(SolrDocument doc, int docid, float score) throws IOException { @@ -271,17 +333,26 @@ public class LTRFeatureLoggerTransformerFactory extends TransformerFactory { private void implTransform(SolrDocument doc, int docid, Float score) throws IOException { - Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher); - if (fv == null) { // FV for this document was not in the cache - fv = featureLogger.makeFeatureVector( - LTRRescorer.extractFeaturesInfo( - modelWeight, - docid, - (docsWereNotReranked ? 
score : null), - leafContexts)); + LTRScoringQuery rerankingQuery = rerankingQueries[0]; + LTRScoringQuery.ModelWeight rerankingModelWeight = modelWeights[0]; + for (int i = 1; i < rerankingQueries.length; i++) { + if (((LTRInterleavingScoringQuery)rerankingQueriesFromContext[i]).getPickedInterleavingDocIds().contains(docid)) { + rerankingQuery = rerankingQueries[i]; + rerankingModelWeight = modelWeights[i]; + } + } + if (!(rerankingQuery instanceof OriginalRankingLTRScoringQuery) || hasExplicitFeatureStore) { + Object featureVector = featureLogger.getFeatureVector(docid, rerankingQuery, searcher); + if (featureVector == null) { // FV for this document was not in the cache + featureVector = featureLogger.makeFeatureVector( + LTRRescorer.extractFeaturesInfo( + rerankingModelWeight, + docid, + (docsWereNotReranked ? score : null), + leafContexts)); + } + doc.addField(name, featureVector); } - - doc.addField(name, fv); } } diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRInterleavingTransformerFactory.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRInterleavingTransformerFactory.java new file mode 100644 index 00000000000..f126e3ee2bc --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/response/transform/LTRInterleavingTransformerFactory.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.response.transform; + +import java.io.IOException; +import org.apache.solr.common.SolrDocument; +import org.apache.solr.common.params.SolrParams; +import org.apache.solr.common.util.NamedList; +import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.ltr.SolrQueryRequestContextUtils; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.ResultContext; +import org.apache.solr.response.transform.DocTransformer; +import org.apache.solr.response.transform.TransformerFactory; +import org.apache.solr.util.SolrPluginUtils; + +public class LTRInterleavingTransformerFactory extends TransformerFactory { + + @Override + @SuppressWarnings({"unchecked"}) + public void init(@SuppressWarnings("rawtypes") NamedList args) { + super.init(args); + SolrPluginUtils.invokeSetters(this, args); + } + + @Override + public DocTransformer create(String name, SolrParams localparams, + SolrQueryRequest req) { + return new InterleavingTransformer(name, req); + } + + class InterleavingTransformer extends DocTransformer { + + final private String name; + final private SolrQueryRequest req; + + private LTRInterleavingScoringQuery[] rerankingQueries; + + /** + * @param name + * Name of the field to be added in a document representing the + * model picked by the interleaving process + */ + public InterleavingTransformer(String name, + SolrQueryRequest req) { + this.name = name; + this.req = req; + } + + @Override + public String getName() { + 
return name; + } + + @Override + public void setContext(ResultContext context) { + super.setContext(context); + if (context == null) { + return; + } + if (context.getRequest() == null) { + return; + } + rerankingQueries = (LTRInterleavingScoringQuery[])SolrQueryRequestContextUtils.getScoringQueries(req); + for (int i = 0; i < rerankingQueries.length; i++) { + LTRScoringQuery scoringQuery = rerankingQueries[i]; + + if (scoringQuery.getOriginalQuery() == null) { + scoringQuery.setOriginalQuery(context.getQuery()); + } + if (scoringQuery.getFeatureLogger() == null) { + scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); + } + scoringQuery.setRequest(req); + } + } + + @Override + public void transform(SolrDocument doc, int docid, float score) + throws IOException { + implTransform(doc, docid); + } + + @Override + public void transform(SolrDocument doc, int docid) + throws IOException { + implTransform(doc, docid); + } + + private void implTransform(SolrDocument doc, int docid) { + LTRScoringQuery rerankingQuery = rerankingQueries[0]; + if (rerankingQueries.length > 1 && rerankingQueries[1].getPickedInterleavingDocIds().contains(docid)) { + rerankingQuery = rerankingQueries[1]; + } + doc.addField(name, rerankingQuery.getScoringModelName()); + } + } + +} diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java index 3e0931022e0..51009f7dcf2 100644 --- a/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQParserPlugin.java @@ -23,26 +23,26 @@ import java.util.Map; import org.apache.lucene.util.ResourceLoader; import org.apache.lucene.util.ResourceLoaderAware; -import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.solr.common.SolrException; import org.apache.solr.common.params.SolrParams; import 
org.apache.solr.common.util.NamedList; import org.apache.solr.core.SolrResourceLoader; -import org.apache.solr.ltr.LTRRescorer; +import org.apache.solr.ltr.interleaving.LTRInterleavingScoringQuery; +import org.apache.solr.ltr.interleaving.OriginalRankingLTRScoringQuery; import org.apache.solr.ltr.LTRScoringQuery; import org.apache.solr.ltr.LTRThreadModule; import org.apache.solr.ltr.SolrQueryRequestContextUtils; +import org.apache.solr.ltr.interleaving.Interleaving; +import org.apache.solr.ltr.interleaving.LTRInterleavingQuery; import org.apache.solr.ltr.model.LTRScoringModel; import org.apache.solr.ltr.store.rest.ManagedFeatureStore; import org.apache.solr.ltr.store.rest.ManagedModelStore; import org.apache.solr.request.SolrQueryRequest; import org.apache.solr.rest.ManagedResource; import org.apache.solr.rest.ManagedResourceObserver; -import org.apache.solr.search.AbstractReRankQuery; import org.apache.solr.search.QParser; import org.apache.solr.search.QParserPlugin; -import org.apache.solr.search.RankQuery; import org.apache.solr.search.SyntaxError; import org.apache.solr.util.SolrPluginUtils; @@ -55,7 +55,7 @@ import org.apache.solr.util.SolrPluginUtils; */ public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAware, ManagedResourceObserver { public static final String NAME = "ltr"; - private static Query defaultQuery = new MatchAllDocsQuery(); + private static final String ORIGINAL_RANKING = "_OriginalRanking_"; // params for setting custom external info that features can use, like query // intent @@ -145,95 +145,74 @@ public class LTRQParserPlugin extends QParserPlugin implements ResourceLoaderAwa @Override public Query parse() throws SyntaxError { - // ReRanking Model - final String modelName = localParams.get(LTRQParserPlugin.MODEL); - if ((modelName == null) || modelName.isEmpty()) { - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, - "Must provide model in the request"); - } - - final LTRScoringModel ltrScoringModel = 
mr.getModel(modelName); - if (ltrScoringModel == null) { - throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, - "cannot find " + LTRQParserPlugin.MODEL + " " + modelName); - } - - final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName(); - final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req); - final String fvStoreName = SolrQueryRequestContextUtils.getFvStoreName(req); - // Check if features are requested and if the model feature store and feature-transform feature store are the same - final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(fvStoreName) || fvStoreName == null) ? extractFeatures:false; if (threadManager != null) { threadManager.setExecutor(req.getCore().getCoreContainer().getUpdateShardHandler().getUpdateExecutor()); } - final LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel, - extractEFIParams(localParams), - featuresRequestedFromSameStore, threadManager); - - // Enable the feature vector caching if we are extracting features, and the features - // we requested are the same ones we are reranking with - if (featuresRequestedFromSameStore) { - scoringQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); + // ReRanking Model + final String[] modelNames = localParams.getParams(LTRQParserPlugin.MODEL); + if ((modelNames == null) || (modelNames.length!=1 && modelNames.length!=2)) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "Must provide one or two models in the request"); + } + final boolean isInterleaving = (modelNames.length > 1); + final boolean extractFeatures = SolrQueryRequestContextUtils.isExtractingFeatures(req); + final String tranformerFeatureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req); + final Map externalFeatureInfo = extractEFIParams(localParams); + + LTRScoringQuery rerankingQuery = null; + LTRInterleavingScoringQuery[] rerankingQueries = new 
LTRInterleavingScoringQuery[modelNames.length]; + for (int i = 0; i < modelNames.length; i++) { + if (modelNames[i].isEmpty()) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "the " + LTRQParserPlugin.MODEL + " "+ i +" is empty"); + } + if (!ORIGINAL_RANKING.equals(modelNames[i])) { + final LTRScoringModel ltrScoringModel = mr.getModel(modelNames[i]); + if (ltrScoringModel == null) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "cannot find " + LTRQParserPlugin.MODEL + " " + modelNames[i]); + } + final String modelFeatureStoreName = ltrScoringModel.getFeatureStoreName(); + // Check if features are requested and if the model feature store and feature-transform feature store are the same + final boolean featuresRequestedFromSameStore = (modelFeatureStoreName.equals(tranformerFeatureStoreName) || tranformerFeatureStoreName == null) ? extractFeatures : false; + + if (isInterleaving) { + rerankingQuery = rerankingQueries[i] = new LTRInterleavingScoringQuery(ltrScoringModel, + externalFeatureInfo, + featuresRequestedFromSameStore, threadManager); + } else { + rerankingQuery = new LTRScoringQuery(ltrScoringModel, + externalFeatureInfo, + featuresRequestedFromSameStore, threadManager); + rerankingQueries[i] = null; + } + + // Enable the feature vector caching if we are extracting features, and the features + // we requested are the same ones we are reranking with + if (featuresRequestedFromSameStore) { + rerankingQuery.setFeatureLogger( SolrQueryRequestContextUtils.getFeatureLogger(req) ); + } + }else{ + rerankingQuery = rerankingQueries[i] = new OriginalRankingLTRScoringQuery(ORIGINAL_RANKING); + } + + // External features + rerankingQuery.setRequest(req); } - SolrQueryRequestContextUtils.setScoringQuery(req, scoringQuery); int reRankDocs = localParams.getInt(RERANK_DOCS, DEFAULT_RERANK_DOCS); if (reRankDocs <= 0) { throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, - "Must rerank at least 1 document"); + "Must rerank at 
least 1 document"); + } + if (!isInterleaving) { + SolrQueryRequestContextUtils.setScoringQueries(req, new LTRScoringQuery[] { rerankingQuery }); + return new LTRQuery(rerankingQuery, reRankDocs); + } else { + SolrQueryRequestContextUtils.setScoringQueries(req, rerankingQueries); + return new LTRInterleavingQuery(Interleaving.getImplementation(Interleaving.TEAM_DRAFT),rerankingQueries, reRankDocs); } - - // External features - scoringQuery.setRequest(req); - - return new LTRQuery(scoringQuery, reRankDocs); } } - - /** - * A learning to rank Query, will incapsulate a learning to rank model, and delegate to it the rescoring - * of the documents. - **/ - public class LTRQuery extends AbstractReRankQuery { - private final LTRScoringQuery scoringQuery; - - public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) { - super(defaultQuery, reRankDocs, new LTRRescorer(scoringQuery)); - this.scoringQuery = scoringQuery; - } - - @Override - public int hashCode() { - return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs); - } - - @Override - public boolean equals(Object o) { - return sameClassAs(o) && equalsTo(getClass().cast(o)); - } - - private boolean equalsTo(LTRQuery other) { - return (mainQuery.equals(other.mainQuery) - && scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs)); - } - - @Override - public RankQuery wrap(Query _mainQuery) { - super.wrap(_mainQuery); - scoringQuery.setOriginalQuery(_mainQuery); - return this; - } - - @Override - public String toString(String field) { - return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='" - + scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}"; - } - - @Override - protected Query rewrite(Query rewrittenMainQuery) throws IOException { - return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery); - } - } - + } diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQuery.java 
b/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQuery.java new file mode 100644 index 00000000000..e86fb7ec042 --- /dev/null +++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/search/LTRQuery.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.ltr.search; + +import java.io.IOException; + +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.solr.ltr.LTRRescorer; +import org.apache.solr.ltr.LTRScoringQuery; +import org.apache.solr.search.AbstractReRankQuery; +import org.apache.solr.search.RankQuery; + +/** + * A learning to rank Query, will incapsulate a learning to rank model, and delegate to it the rescoring + * of the documents. 
+ **/ +public class LTRQuery extends AbstractReRankQuery { + private static final Query defaultQuery = new MatchAllDocsQuery(); + private final LTRScoringQuery scoringQuery; + + public LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs) { + this(scoringQuery, reRankDocs, new LTRRescorer(scoringQuery)); + } + + protected LTRQuery(LTRScoringQuery scoringQuery, int reRankDocs, LTRRescorer rescorer) { + super(defaultQuery, reRankDocs, rescorer); + this.scoringQuery = scoringQuery; + } + + @Override + public int hashCode() { + return 31 * classHash() + (mainQuery.hashCode() + scoringQuery.hashCode() + reRankDocs); + } + + @Override + public boolean equals(Object o) { + return sameClassAs(o) && equalsTo(getClass().cast(o)); + } + + private boolean equalsTo(LTRQuery other) { + return (mainQuery.equals(other.mainQuery) + && scoringQuery.equals(other.scoringQuery) && (reRankDocs == other.reRankDocs)); + } + + @Override + public RankQuery wrap(Query _mainQuery) { + super.wrap(_mainQuery); + if (scoringQuery != null) { + scoringQuery.setOriginalQuery(_mainQuery); + } + return this; + } + + @Override + public String toString(String field) { + return "{!ltr mainQuery='" + mainQuery.toString() + "' scoringQuery='" + + scoringQuery.toString() + "' reRankDocs=" + reRankDocs + "}"; + } + + @Override + protected Query rewrite(Query rewrittenMainQuery) throws IOException { + return new LTRQuery(scoringQuery, reRankDocs).wrap(rewrittenMainQuery); + } +} diff --git a/solr/contrib/ltr/src/java/overview.html b/solr/contrib/ltr/src/java/overview.html index fa602824e8c..5e9b7b28b94 100644 --- a/solr/contrib/ltr/src/java/overview.html +++ b/solr/contrib/ltr/src/java/overview.html @@ -39,7 +39,7 @@ A Learning to Rank model is plugged into the ranking through the {@link org.apac a {@link org.apache.solr.search.QParserPlugin}. The plugin will read from the request the model (instance of {@link org.apache.solr.ltr.model.LTRScoringModel}) used to perform the request plus other -parameters. 
The plugin will generate a {@link org.apache.solr.ltr.search.LTRQParserPlugin.LTRQuery LTRQuery}: +parameters. The plugin will generate a {@link org.apache.solr.ltr.search.LTRQuery LTRQuery}: a particular {@link org.apache.solr.search.RankQuery} that will encapsulate the given model and use it to rescore and rerank the document (by using an {@link org.apache.solr.ltr.LTRRescorer}). diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml index 057718aae30..31ed2631ea0 100644 --- a/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml +++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/solrconfig-ltr.xml @@ -46,6 +46,15 @@ QUERY_DOC_FV + + + + 15000 diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java index 14fb7e1a0f9..0b154a28b15 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java @@ -16,7 +16,11 @@ */ package org.apache.solr.ltr; +import java.util.Random; + import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.feature.SolrFeature; +import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving; import org.apache.solr.ltr.model.LinearModel; import org.junit.After; import org.junit.Before; @@ -149,4 +153,160 @@ public class TestLTRQParserExplain extends TestRerankBase { " 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}"); } + @Test + public void interleavingModels_shouldReturnExplainForTheModelPicked() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0] + + loadFeature("featureA1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}1\"]}"); + loadFeature("featureA2", 
SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=description}bloomberg\"]}"); + loadFeature("featureAB", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}2\"]}"); + loadFeature("featureB1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}5\"]}"); + loadFeature("featureB2", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=title}different\"]}"); + + loadModel("modelA", LinearModel.class.getName(), + new String[]{"featureA1", "featureA2", "featureAB"}, + "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}"); + + loadModel("modelB", LinearModel.class.getName(), + new String[]{"featureB1", "featureB2", "featureAB"}, + "{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "10"); + query.add("rq", "{!ltr reRankDocs=10 model=modelA model=modelB}"); + query.add("fl", "*,score"); + + /* + Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12), ScoreB(4) + Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36), ScoreB(8) + Doc8 = "featureA2=1.0", ScoreA(9), ScoreB(0) + Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9), ScoreB(2) + + ModelARerankedList = [7,6,8,9] + ModelBRerankedList = [7,6,9,8] + + Random Boolean Choices Generation from Seed: [1,0] + + */ + + int[] expectedInterleaved = new int[]{7, 6, 8, 9}; + String[] expectedExplains = new String[]{ + "\n8.0 = LinearModel(name=modelB," + + "featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " + + "model applied to features, sum of:\n " + + "0.0 = prod of:\n 2.0 = weight on feature\n 0.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " + + "0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " + + "8.0 = prod of:\n 8.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms 
f=popularity}2]}]\n", + "\n12.0 = LinearModel(name=modelA," + + "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " + + "model applied to features, sum of:\n " + + "3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " + + "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " + + "0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n", + "\n9.0 = LinearModel(name=modelA," + + "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " + + "model applied to features, sum of:\n " + + "0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " + + "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " + + "0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n", + "\n2.0 = LinearModel(name=modelB," + + "featureWeights=[featureB1=2.0,featureB2=4.0,featureAB=8.0]) " + + "model applied to features, sum of:\n " + + "2.0 = prod of:\n 2.0 = weight on feature\n 1.0 = SolrFeature [name=featureB1, params={fq=[{!terms f=popularity}5]}]\n " + + "0.0 = prod of:\n 4.0 = weight on feature\n 0.0 = SolrFeature [name=featureB2, params={fq=[{!terms f=title}different]}]\n " + + "0.0 = prod of:\n 8.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n"}; + + + + String[] tests = new String[16]; + tests[0] = "/response/numFound/==4"; + for (int i = 1; i <= 4; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 4] = "/debug/explain/" + expectedInterleaved[(i - 1)] + "=='" + expectedExplains[(i - 1)]+"'}"; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + 
@Test + public void interleavingModelsWithOriginalRanking_shouldReturnExplainForTheModelPicked() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0] + + loadFeature("featureA1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}1\"]}"); + loadFeature("featureA2", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=description}bloomberg\"]}"); + loadFeature("featureAB", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}2\"]}"); + + loadModel("modelA", LinearModel.class.getName(), + new String[]{"featureA1", "featureA2", "featureAB"}, + "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:bloomberg"); + query.setParam("debugQuery", "on"); + query.add("rows", "10"); + query.add("rq", "{!ltr reRankDocs=10 model=modelA model=_OriginalRanking_}"); + query.add("fl", "*,score"); + + /* + Doc6 = "featureA1=1.0 featureA2=1.0 featureB2=1.0", ScoreA(12) + Doc7 = "featureA2=1.0 featureAB=1.0", ScoreA(36) + Doc8 = "featureA2=1.0", ScoreA(9) + Doc9 = "featureA2=1.0 featureB1=1.0", ScoreA(9) + + ModelARerankedList = [7,6,8,9] + OriginalRanking = [9,8,7,6] + + Random Boolean Choices Generation from Seed: [1,0] + + */ + + int[] expectedInterleaved = new int[]{9, 7, 6, 8}; + String[] expectedExplains = new String[]{ + "\n0.07662583 = weight(title:bloomberg in 3) [SchemaSimilarity], result of:\n " + + "0.07662583 = score(freq=4.0), computed as boost * idf * tf from:\n " + + "0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " + + "0.72727275 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 4.0 = freq, occurrences of term within document\n " + + "1.2 = k1, term saturation parameter\n " + + "0.75 = b, length normalization parameter\n " + + "4.0 = dl, length of 
field\n " + + "3.0 = avgdl, average length of field\n", + "\n36.0 = LinearModel(name=modelA," + + "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " + + "model applied to features, sum of:\n " + + "0.0 = prod of:\n 3.0 = weight on feature\n 0.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " + + "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " + + "27.0 = prod of:\n 27.0 = weight on feature\n 1.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n", + "\n12.0 = LinearModel(name=modelA," + + "featureWeights=[featureA1=3.0,featureA2=9.0,featureAB=27.0]) " + + "model applied to features, sum of:\n " + + "3.0 = prod of:\n 3.0 = weight on feature\n 1.0 = SolrFeature [name=featureA1, params={fq=[{!terms f=popularity}1]}]\n " + + "9.0 = prod of:\n 9.0 = weight on feature\n 1.0 = SolrFeature [name=featureA2, params={fq=[{!terms f=description}bloomberg]}]\n " + + "0.0 = prod of:\n 27.0 = weight on feature\n 0.0 = SolrFeature [name=featureAB, params={fq=[{!terms f=popularity}2]}]\n", + "\n0.07525751 = weight(title:bloomberg in 2) [SchemaSimilarity], result of:\n " + + "0.07525751 = score(freq=3.0), computed as boost * idf * tf from:\n " + + "0.105360515 = idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:\n 4 = n, number of documents containing term\n 4 = N, total number of documents with field\n " + + "0.71428573 = tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:\n 3.0 = freq, occurrences of term within document\n " + + "1.2 = k1, term saturation parameter\n " + + "0.75 = b, length normalization parameter\n " + + "3.0 = dl, length of field\n " + + "3.0 = avgdl, average length of field\n"}; + + String[] tests = new String[16]; + tests[0] = "/response/numFound/==4"; + for (int i = 1; i <= 4; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 4] = 
"/debug/explain/" + expectedInterleaved[(i - 1)] + "=='" + expectedExplains[(i - 1)]+"'}"; + } + assertJQ("/query" + query.toQueryString(), tests); + } + } diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java index decb1c0888b..e6132102ac0 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java @@ -48,7 +48,21 @@ public class TestLTRQParserPlugin extends TestRerankBase { query.add("rq", "{!ltr reRankDocs=100}"); final String res = restTestHarness.query("/query" + query.toQueryString()); - assert (res.contains("Must provide model in the request")); + assert (res.contains("Must provide one or two models in the request")); + } + + @Test + public void interleavingLtrTooManyModelsTest() throws Exception { + final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=modelA model=modelB model=C reRankDocs=100}"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + assert (res.contains("Must provide one or two models in the request")); } @Test @@ -65,6 +79,34 @@ public class TestLTRQParserPlugin extends TestRerankBase { assert (res.contains("cannot find model")); } + @Test + public void ltrModelIsEmptyTest() throws Exception { + final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=\"\" reRankDocs=100}"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + assert 
(res.contains("the model 0 is empty")); + } + + @Test + public void interleavingLtrModelIsEmptyTest() throws Exception { + final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; + final SolrQuery query = new SolrQuery(); + query.setQuery(solrQuery); + query.add("fl", "*, score"); + query.add("rows", "4"); + query.add("fv", "true"); + query.add("rq", "{!ltr model=6029760550880411648 model=\"\" reRankDocs=100}"); + + final String res = restTestHarness.query("/query" + query.toQueryString()); + assert (res.contains("the model 1 is empty")); + } + @Test public void ltrBadRerankDocsTest() throws Exception { final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}"; diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java index 708fdc8105b..c0157130bcf 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java @@ -17,8 +17,11 @@ package org.apache.solr.ltr; +import java.util.Random; + import org.apache.solr.client.solrj.SolrQuery; import org.apache.solr.ltr.feature.SolrFeature; +import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving; import org.apache.solr.ltr.model.LinearModel; import org.junit.After; import org.junit.Before; @@ -97,4 +100,104 @@ public class TestLTRWithSort extends TestRerankBase { } + @Test + public void interleavingTwoModelsWithSort_shouldInterleave() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0] + + loadFeature("featureA", SolrFeature.class.getName(), + "{\"q\":\"{!func}pow(popularity,2)\"}"); + + loadFeature("featureB", SolrFeature.class.getName(), + "{\"q\":\"{!func}pow(popularity,-2)\"}"); + + loadModel("modelA", LinearModel.class.getName(), + new String[] {"featureA"}, "{\"weights\":{\"featureA\":1.0}}"); + + 
loadModel("modelB", LinearModel.class.getName(), + new String[] {"featureB"}, "{\"weights\":{\"featureB\":1.0}}"); + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:a1"); + query.add("rows", "10"); + query.add("rq", "{!ltr reRankDocs=4 model=modelA model=modelB}"); + query.add("fl", "*,score"); + query.add("sort", "description desc"); + + /* + Doc1 = "popularity=1", ScoreA(1) ScoreB(1) + Doc5 = "popularity=5", ScoreA(25) ScoreB(0.04) + Doc7 = "popularity=7", ScoreA(49) ScoreB(0.02) + Doc8 = "popularity=8", ScoreA(64) ScoreB(0.01) + + ModelARerankedList = [8,7,5,1] + ModelBRerankedList = [1,5,7,8] + + OriginalRanking = [1,5,8,7] + + Random Boolean Choices Generation from Seed: [1,0] + */ + + int[] expectedInterleaved = new int[]{1, 8, 7, 5}; + + String[] tests = new String[5]; + tests[0] = "/response/numFound/==8"; + for (int i = 1; i <= 4; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + } + assertJQ("/query" + query.toQueryString(), tests); + + } + + @Test + public void interleavingModelsWithOriginalRankingSort_shouldInterleave() throws Exception { + + loadFeature("powpularityS", SolrFeature.class.getName(), + "{\"q\":\"{!func}pow(popularity,2)\"}"); + + loadModel("powpularityS-model", LinearModel.class.getName(), + new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}"); + + for (boolean originalRankingLast : new boolean[] { true, false }) { + TeamDraftInterleaving.setRANDOM(new Random(10));//Random Boolean Choices Generation from Seed: [1,0] + + final SolrQuery query = new SolrQuery(); + query.setQuery("title:a1"); + query.add("rows", "10"); + if (originalRankingLast) { + query.add("rq", "{!ltr reRankDocs=4 model=powpularityS-model model=_OriginalRanking_}"); + } else { + query.add("rq", "{!ltr reRankDocs=4 model=_OriginalRanking_ model=powpularityS-model}"); + } + query.add("fl", "*,score"); + query.add("sort", "description desc"); + + /* + Doc1 = "popularity=1", 
ScorePowpularityS(1) + Doc5 = "popularity=5", ScorePowpularityS(25) + Doc7 = "popularity=7", ScorePowpularityS(49) + Doc8 = "popularity=8", ScorePowpularityS(64) + + PowpularitySRerankedList = [8,7,5,1] + OriginalRanking = [1,5,8,7] + + Random Boolean Choices Generation from Seed: [1,0] + */ + + final int[] expectedInterleaved; + if (originalRankingLast) { + expectedInterleaved = new int[]{1, 8, 7, 5}; + } else { + expectedInterleaved = new int[]{8, 1, 5, 7}; + } + + String[] tests = new String[5]; + tests[0] = "/response/numFound/==8"; + for (int i = 1; i <= 4; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + } + } diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/interleaving/algorithms/TeamDraftInterleavingTest.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/interleaving/algorithms/TeamDraftInterleavingTest.java new file mode 100644 index 00000000000..0f484a5df76 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/interleaving/algorithms/TeamDraftInterleavingTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.ltr.interleaving.algorithms; + + +import java.util.ArrayList; +import java.util.Random; +import java.util.Set; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.solr.ltr.interleaving.InterleavingResult; +import org.junit.Before; +import org.junit.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertTrue; + +public class TeamDraftInterleavingTest { /* exercises TeamDraftInterleaving.interleave with the fixed random seed installed in setup() */ + private TeamDraftInterleaving toTest; + private ScoreDoc[] rerankedA,rerankedB; + private ScoreDoc a1,a2,a3,a4,a5; + private ScoreDoc b1,b2,b3,b4,b5; + + + @Before + public void setup() { + toTest = new TeamDraftInterleaving(); + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + } + + protected void initDifferentOrderRerankLists() { /* both lists hold doc ids 1-5, ranked in different orders */ + a1 = new ScoreDoc(1,10,1); + a2 = new ScoreDoc(5,7,1); + a3 = new ScoreDoc(4,6,1); + a4 = new ScoreDoc(2,5,1); + a5 = new ScoreDoc(3,4,1); + rerankedA = new ScoreDoc[]{a1,a2,a3,a4,a5}; + + b1 = new ScoreDoc(1,10,1); + b2 = new ScoreDoc(4,7,1); + b3 = new ScoreDoc(5,6,1); + b4 = new ScoreDoc(3,5,1); + b5 = new ScoreDoc(2,4,1); + rerankedB = new ScoreDoc[]{b1,b2,b3,b4,b5}; + } + + /** + * Random Boolean Choices Generation from Seed: [0,1,1] + */ + @Test + public void interleaving_twoDifferentLists_shouldInterleaveTeamDraft() { + initDifferentOrderRerankLists(); + + InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB); + ScoreDoc[] interleavedResults = interleaved.getInterleavedResults(); + + assertThat(interleavedResults.length,is(5)); + + assertThat(interleavedResults[0],is(a1)); + assertThat(interleavedResults[1],is(b2)); + + assertThat(interleavedResults[2],is(b3)); + assertThat(interleavedResults[3],is(a4)); + + assertThat(interleavedResults[4],is(b4)); + } + + /** + * Random Boolean Choices Generation from Seed: [0,1,1] + */ + @Test + public void
interleaving_twoDifferentLists_shouldBuildCorrectInterleavingPicks() { + initDifferentOrderRerankLists(); + + InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB); + + ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks(); + Set<Integer> modelAPicks = interleavingPicks.get(0); + Set<Integer> modelBPicks = interleavingPicks.get(1); + + assertThat(modelAPicks.size(),is(2)); + assertThat(modelBPicks.size(),is(3)); + + assertTrue(modelAPicks.contains(a1.doc)); + assertTrue(modelAPicks.contains(a4.doc)); + + assertTrue(modelBPicks.contains(b2.doc)); + assertTrue(modelBPicks.contains(b3.doc)); + assertTrue(modelBPicks.contains(b4.doc)); + } + + protected void initIdenticalOrderRerankLists() { /* both lists hold doc ids 1-5 in the same order with the same scores */ + a1 = new ScoreDoc(1,10,1); + a2 = new ScoreDoc(5,7,1); + a3 = new ScoreDoc(4,6,1); + a4 = new ScoreDoc(2,5,1); + a5 = new ScoreDoc(3,4,1); + rerankedA = new ScoreDoc[]{a1,a2,a3,a4,a5}; + + b1 = new ScoreDoc(1,10,1); + b2 = new ScoreDoc(5,7,1); + b3 = new ScoreDoc(4,6,1); + b4 = new ScoreDoc(2,5,1); + b5 = new ScoreDoc(3,4,1); + rerankedB = new ScoreDoc[]{b1,b2,b3,b4,b5}; + } + + /** + * Random Boolean Choices Generation from Seed: [0,1,1] + */ + @Test + public void interleaving_identicalRerankLists_shouldInterleaveTeamDraft() { + initIdenticalOrderRerankLists(); + + InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB); + ScoreDoc[] interleavedResults = interleaved.getInterleavedResults(); + + assertThat(interleavedResults.length,is(5)); + + assertThat(interleavedResults[0],is(a1)); + assertThat(interleavedResults[1],is(b2)); + + assertThat(interleavedResults[2],is(b3)); + assertThat(interleavedResults[3],is(a4)); + + assertThat(interleavedResults[4],is(b5)); + } + + /** + * Random Boolean Choices Generation from Seed: [0,1,1] + */ + @Test + public void interleaving_identicalRerankLists_shouldBuildCorrectInterleavingPicks() { + initIdenticalOrderRerankLists(); + + InterleavingResult interleaved = toTest.interleave(rerankedA, rerankedB); + +
ArrayList<Set<Integer>> interleavingPicks = interleaved.getInterleavingPicks(); + Set<Integer> modelAPicks = interleavingPicks.get(0); + Set<Integer> modelBPicks = interleavingPicks.get(1); + + assertThat(modelAPicks.size(),is(2)); + assertThat(modelBPicks.size(),is(3)); + + assertTrue(modelAPicks.contains(a1.doc)); + assertTrue(modelAPicks.contains(a4.doc)); + + assertTrue(modelBPicks.contains(b2.doc)); + assertTrue(modelBPicks.contains(b3.doc)); + assertTrue(modelBPicks.contains(b5.doc)); + } + +} + diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java new file mode 100644 index 00000000000..c7da81c7169 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.solr.ltr.response.transform; + +import java.util.Random; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.SolrFeature; +import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestFeatureLoggerTransformer extends TestRerankBase { /* checks the "features" field produced by the [fv] transformer for interleaved {!ltr} requests; every test reseeds TeamDraftInterleaving first */ + @Before + public void before() throws Exception { /* indexes docs 1-8 (title, description, popularity) and commits */ + setuptest(false); + + assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity", + "6")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w6 w2 w3 w4 w5 w8", "popularity", "88888")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2 w5", "popularity", "88888")); + assertU(commit()); + + + } + + @After + public void after() throws Exception { + aftertest(); + } + + protected void loadFeaturesAndModels() throws Exception { /* featureC1-C3 and modelC are loaded with the extra "featureStore2" argument, everything else without it */ + loadFeature("featureA1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}88888\"]}"); + loadFeature("featureA2", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=title}${user_query}\"]}"); + loadFeature("featureAB", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=title}${user_query}\"]}"); + loadFeature("featureB1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}6\"]}"); + loadFeature("featureB2", SolrFeature.class.getName(), +
"{\"fq\":[\"{!terms f=description}${user_query}\"]}"); + loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2", + "{\"fq\":[\"{!terms f=popularity}6\"]}"); + loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2", + "{\"fq\":[\"{!terms f=description}${user_query}\"]}"); + loadFeature("featureC3", SolrFeature.class.getName(),"featureStore2", + "{\"fq\":[\"{!terms f=description}${user_query}\"]}"); + + loadModel("modelA", LinearModel.class.getName(), + new String[]{"featureA1", "featureA2", "featureAB"}, + "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}"); + + loadModel("modelB", LinearModel.class.getName(), + new String[]{"featureB1", "featureB2", "featureAB"}, + "{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}"); + + loadModel("modelC", LinearModel.class.getName(), + new String[]{"featureC1", "featureC2"},"featureStore2", + "{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}"); + } + + @Test + public void interleaving_featureTransformer_shouldWorkInSparseFormat() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4) + ModelARerankedList = [7,8,1,3,4] + ModelBRerankedList = [7,1,3,8,4] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[]
expectedFeatureVectors = new String[]{"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0", "featureB2\\=1.0", "featureB2\\=1.0", "featureA1\\=1.0\\,featureB2\\=1.0", "featureB1\\=1.0"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + @Test + public void interleaving_featureTransformer_shouldWorkInDenseFormat() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv format=dense]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4) + ModelARerankedList = [7,8,1,3,4] + ModelBRerankedList = [7,1,3,8,4] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB1\\=0.0\\,featureB2\\=1.0", + "featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0", + "featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0", +
"featureA1\\=1.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=0.0\\,featureB2\\=1.0", + "featureA1\\=0.0\\,featureA2\\=0.0\\,featureAB\\=0.0\\,featureB1\\=1.0\\,featureB2\\=0.0"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + @Test + public void interleaving_explicitNewFeatureStore_shouldExtractAllFeaturesFromNewStore() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=dense]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4) + ModelARerankedList = [7,8,1,3,4] + ModelBRerankedList = [7,1,3,8,4] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0", + "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0", + "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0", + "featureC1\\=0.0\\,featureC2\\=1.0\\,featureC3\\=1.0", +
"featureC1\\=1.0\\,featureC2\\=0.0\\,featureC3\\=0.0"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; /* 1 numFound check + 5 id checks + 5 feature checks */ + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + @Test + public void interleaving_withOriginalRankingAndExplicitFeatureStore_shouldReturnNewCalculatedFeatureVector() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w5' format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0) + Doc3 = "featureB2=1.0", ScoreA(0) + Doc4 = "featureB1=1.0", ScoreA(0) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3) + + ModelARerankedList = [7,8,1,3,4] + _OriginalRanking_ = [1,3,4,7,8] + + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC1\\=1.0"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; /* 1 numFound check + 5 id checks + up to 5 feature checks */ + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] +
"\""; + if (expectedFeatureVectors[(i - 1)] != null) { + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + } + assertJQ("/query" + query.toQueryString(), tests); + + int[] nullFeatureVectorIndexes = new int[]{1, 2, 4}; + for (int index : nullFeatureVectorIndexes) { + TeamDraftInterleaving.setRANDOM(new Random(10101010)); + String[] nullFeatureVectorTests = new String[1]; + try { + nullFeatureVectorTests[0] = "/response/docs/[" + index + "]/features=="; + assertJQ("/query" + query.toQueryString(), nullFeatureVectorTests); + } catch (Exception e) { /* NOTE(review): nothing fails when assertJQ does not throw - confirm that is intended */ + assertEquals("Path not found: /response/docs/[" + index + "]/features", e.getMessage()); + continue; + } + + } + + } + + @Test + public void interleaving_modelsFromDifferentFeatureStores_shouldLogFeaturesCorrectly() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25) + Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25) + Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0) + Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25) + ModelARerankedList = [7,8,1,3,4] + ModelCRerankedList = [1,3,8,4,7] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureA1\\=1.0\\,featureB2\\=1.0", + "featureC1\\=1.0"};
+ int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; /* 1 numFound check + 5 id checks + 5 feature checks */ + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + @Test + public void interleaving_featureLoggerFromNewFeatureStoreWithDifferentEfi_shouldReturnNewCalculatedFeatureVector() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv store=featureStore2 efi.user_query='w8' format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4) + ModelARerankedList = [7,8,1,3,4] + ModelBRerankedList = [7,1,3,8,4] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + "featureC2\\=1.0\\,featureC3\\=1.0", + "", + "", + "", + ""}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; /* 1 numFound check + 5 id checks + 5 feature checks */ + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } +
assertJQ("/query" + query.toQueryString(), tests); + } + + @Test + public void interleaving_explicitFeatureStoreReusableFromModel_shouldLogFeaturesCorrectly() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModels(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,features:[fv store=featureStore2 format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelC reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureC2=1.0", ScoreA(0), ScoreC(25) + Doc3 = "featureC2=1.0", ScoreA(0), ScoreC(25) + Doc4 = "featureC1=1.0", ScoreA(0), ScoreC(5) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0", ScoreA(39), ScoreC(0) + Doc8 = "featureA1=1.0,featureC2=1.0", ScoreA(3), ScoreC(25) + ModelARerankedList = [7,8,1,3,4] + ModelCRerankedList = [1,3,8,4,7] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC2\\=1.0\\,featureC3\\=1.0", + "featureC1\\=1.0"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; /* 1 numFound check + 5 id checks + 5 feature checks */ + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + +} diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/response/transform/TestInterleavingTransformer.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/response/transform/TestInterleavingTransformer.java new file mode 100644 index
00000000000..4f048d5afb1 --- /dev/null +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/response/transform/TestInterleavingTransformer.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.ltr.response.transform; + +import java.util.Random; + +import org.apache.solr.client.solrj.SolrQuery; +import org.apache.solr.ltr.TestRerankBase; +import org.apache.solr.ltr.feature.SolrFeature; +import org.apache.solr.ltr.interleaving.algorithms.TeamDraftInterleaving; +import org.apache.solr.ltr.model.LinearModel; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestInterleavingTransformer extends TestRerankBase { + @Before + public void before() throws Exception { + setuptest(false); + + assertU(adoc("id", "1", "title", "w1", "description", "w5", "popularity", + "1")); + assertU(adoc("id", "2", "title", "w2 2asd asdd didid", "description", + "w2 2asd asdd didid", "popularity", "2")); + assertU(adoc("id", "3", "title", "w1", "description", "w5", "popularity", + "3")); + assertU(adoc("id", "4", "title", "w1", "description", "w1", "popularity", + "6")); + assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity", + "5")); + 
assertU(adoc("id", "6", "title", "w6 w2", "description", "w1 w2", + "popularity", "6")); + assertU(adoc("id", "7", "title", "w1 w2 w3 w4 w5", "description", + "w6 w2 w3 w4 w5 w8", "popularity", "88888")); + assertU(adoc("id", "8", "title", "w1 w1 w1 w2 w2 w8", "description", + "w1 w1 w1 w2 w2 w5", "popularity", "88888")); + assertU(commit()); + + + } + + @After + public void after() throws Exception { + aftertest(); + } + + protected void loadFeaturesAndModelsForInterleaving() throws Exception { + loadFeature("featureA1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}88888\"]}"); + loadFeature("featureA2", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=title}${user_query}\"]}"); + loadFeature("featureAB", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=title}${user_query}\"]}"); + loadFeature("featureB1", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=popularity}6\"]}"); + loadFeature("featureB2", SolrFeature.class.getName(), + "{\"fq\":[\"{!terms f=description}${user_query}\"]}"); + loadFeature("featureC1", SolrFeature.class.getName(),"featureStore2", + "{\"fq\":[\"{!terms f=popularity}6\"]}"); + loadFeature("featureC2", SolrFeature.class.getName(),"featureStore2", + "{\"fq\":[\"{!terms f=popularity}1\"]}"); + + loadModel("modelA", LinearModel.class.getName(), + new String[]{"featureA1", "featureA2", "featureAB"}, + "{\"weights\":{\"featureA1\":3.0, \"featureA2\":9.0, \"featureAB\":27.0}}"); + + loadModel("modelB", LinearModel.class.getName(), + new String[]{"featureB1", "featureB2", "featureAB"}, + "{\"weights\":{\"featureB1\":2.0, \"featureB2\":4.0, \"featureAB\":8.0}}"); + + loadModel("modelC", LinearModel.class.getName(), + new String[]{"featureC1", "featureC2"},"featureStore2", + "{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}"); + } + + @Test + public void interleavingTransformer_shouldReturnInterleavingPickInTheResults() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean 
Choices Generation from Seed: [0,1,1] + loadFeaturesAndModelsForInterleaving(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,interleavingPick:[interleaving]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4) + ModelARerankedList = [7,8,1,3,4] + ModelBRerankedList = [7,1,3,8,4] + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + + } + + @Test + public void interleavingTransformerWithOriginalRanking_shouldReturnInterleavingPickInTheResults() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModelsForInterleaving(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,interleavingPick:[interleaving]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 
efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0) + Doc3 = "featureB2=1.0", ScoreA(0) + Doc4 = "featureB1=1.0", ScoreA(0) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3) + + ModelARerankedList = [7,8,1,3,4] + OriginalRanking = [1,3,4,7,8] + + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[11]; + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + + } + + + + @Test + public void interleavingTransformer_shouldBeCompatibleWithFeatureTransformer() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModelsForInterleaving(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,interleavingPick:[interleaving],features:[fv format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc3 = "featureB2=1.0", ScoreA(0), ScoreB(4) + Doc4 = "featureB1=1.0", ScoreA(0), ScoreB(2) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39), ScoreB(12) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3), ScoreB(4) + ModelARerankedList = [7,8,1,3,4] + ModelBRerankedList = [7,1,3,8,4] + + Random Boolean Choices 
Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + "featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0", + "featureB2\\=1.0", + "featureB2\\=1.0", + "featureA1\\=1.0\\,featureB2\\=1.0", + "featureB1\\=1.0"}; + String[] expectedInterleavingPicks = new String[]{"modelA", "modelB", "modelB", "modelA", "modelB"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[16]; + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)]; + tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + assertJQ("/query" + query.toQueryString(), tests); + } + + @Test + public void interleavingTransformerWithOriginalRanking_shouldBeCompatibleWithFeatureTransformer() throws Exception { + TeamDraftInterleaving.setRANDOM(new Random(10101010));//Random Boolean Choices Generation from Seed: [0,1,1] + loadFeaturesAndModelsForInterleaving(); + + final SolrQuery query = new SolrQuery(); + query.setQuery("*:*"); + query.add("fl", "*, score,interleavingPick:[interleaving],features:[fv format=sparse]"); + query.add("rows", "10"); + query.add("debugQuery", "true"); + query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8 + query.add("rq", + "{!ltr model=modelA model=_OriginalRanking_ reRankDocs=10 efi.user_query='w5'}"); + + /* + Doc1 = "featureB2=1.0", ScoreA(0) + Doc3 = "featureB2=1.0", ScoreA(0) + Doc4 = "featureB1=1.0", ScoreA(0) + Doc7 ="featureA1=1.0,featureA2=1.0,featureAB=1.0,featureB2=1.0", ScoreA(39) + Doc8 = "featureA1=1.0,featureB2=1.0", ScoreA(3) + + ModelARerankedList = [7,8,1,3,4] + _OriginalRanking_ = [1,3,4,7,8] + + + Random Boolean Choices Generation from Seed: [0,1,1] + */ + String[] expectedFeatureVectors = new String[]{ + 
"featureA1\\=1.0\\,featureA2\\=1.0\\,featureAB\\=1.0\\,featureB2\\=1.0", + null, + null, + "featureA1\\=1.0\\,featureB2\\=1.0", + null}; + String[] expectedInterleavingPicks = new String[]{"modelA", "_OriginalRanking_", "_OriginalRanking_", "modelA", "_OriginalRanking_"}; + int[] expectedInterleaved = new int[]{7, 1, 3, 8, 4}; + + String[] tests = new String[16]; + tests[0] = "/response/numFound/==5"; + for (int i = 1; i <= 5; i++) { + tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\""; + tests[i + 5] = "/response/docs/[" + (i - 1) + "]/interleavingPick==" + expectedInterleavingPicks[(i - 1)]; + if (expectedFeatureVectors[(i - 1)] != null) { + tests[i + 10] = "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)]; + } + } + assertJQ("/query" + query.toQueryString(), tests); + + int[] nullFeatureVectorIndexes = new int[]{1, 2, 4}; + for (int index : nullFeatureVectorIndexes) { + TeamDraftInterleaving.setRANDOM(new Random(10101010)); + String[] nullFeatureVectorTests = new String[1]; + try { + nullFeatureVectorTests[0] = "/response/docs/[" + index + "]/features=="; + assertJQ("/query" + query.toQueryString(), nullFeatureVectorTests); + } catch (Exception e) { + assertEquals("Path not found: /response/docs/[" + index + "]/features", e.getMessage()); + continue; + } + + } + + } + +} diff --git a/solr/solr-ref-guide/src/learning-to-rank.adoc b/solr/solr-ref-guide/src/learning-to-rank.adoc index 7365fb07537..481f3b76902 100644 --- a/solr/solr-ref-guide/src/learning-to-rank.adoc +++ b/solr/solr-ref-guide/src/learning-to-rank.adoc @@ -38,6 +38,13 @@ A ranking model computes the scores used to rerank documents. 
Irrespective of an * features that represent the document being scored * features that represent the query for which the document is being scored +==== Interleaving + +Interleaving is an approach to Online Search Quality evaluation that allows comparing two models by interleaving their results in the final ranked list returned to the user. + +* currently only the Team Draft Interleaving algorithm is supported (and its implementation assumes all results are from the same shard) + + ==== Feature A feature is a value, a number, that represents some quantity or quality of the document being scored or of the query for which documents are being scored. For example documents often have a 'recency' quality and 'number of past purchases' might be a quantity that is passed to Solr as part of the search query. @@ -247,6 +254,81 @@ The output XML will include feature values as a comma-separated list, resembling }} ---- +=== Running a Rerank Query Interleaving Two Models + +To rerank the results of a query, interleaving two models (myModelA, myModelB), add the `rq` parameter to your search, passing the two models as input, for example: + +[source,text] +http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score + +To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example: + +[source,text] +http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=myModelA model=myModelB reRankDocs=100}&fl=id,score,[interleaving] + +The output will include the model picked for each search result, resembling the output shown here: + +[source,json] +---- +{ + "responseHeader":{ + "status":0, + "QTime":0, + "params":{ + "q":"test", + "fl":"id,score,[interleaving]", + "rq":"{!ltr model=myModelA model=myModelB reRankDocs=100}"}}, + "response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[ + { + "id":"GB18030TEST", + "score":1.0005897, +
"[interleaving]":"myModelB"}, + { + "id":"UTF8TEST", + "score":0.79656565, + "[interleaving]":"myModelA"}] + }} +---- + +=== Running a Rerank Query Interleaving a Model with the Original Ranking +When approaching Search Quality Evaluation with interleaving, it may be useful to compare a model with the original ranking. +To rerank the results of a query, interleaving a model with the original ranking, add the `rq` parameter to your search, passing the special inbuilt `_OriginalRanking_` model identifier as one model and your comparison model as the other model, for example: + + +[source,text] +http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score + +The addition of the `rq` parameter will not change the output XML of the search. + +To obtain the model that interleaving picked for a search result, computed during reranking, add `[interleaving]` to the `fl` parameter, for example: + +[source,text] +http://localhost:8983/solr/techproducts/query?q=test&rq={!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}&fl=id,score,[interleaving] + +The output will include the model picked for each search result, resembling the output shown here: + +[source,json] +---- +{ + "responseHeader":{ + "status":0, + "QTime":0, + "params":{ + "q":"test", + "fl":"id,score,[interleaving]", + "rq":"{!ltr model=_OriginalRanking_ model=myModel reRankDocs=100}"}}, + "response":{"numFound":2,"start":0,"maxScore":1.0005897,"docs":[ + { + "id":"GB18030TEST", + "score":1.0005897, + "[interleaving]":"_OriginalRanking_"}, + { + "id":"UTF8TEST", + "score":0.79656565, + "[interleaving]":"myModel"}] + }} +---- + === External Feature Information The {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/ValueFeature.html[ValueFeature] and {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/SolrFeature.html[SolrFeature] classes support the use of external feature information, `efi` for short.
@@ -418,6 +500,13 @@ Learning-To-Rank is a contrib module and therefore its plugins must be configure ---- +* Declaration of the `[interleaving]` transformer. ++ +[source,xml] +---- + +---- + === Advanced Options ==== LTRThreadModule @@ -446,11 +535,12 @@ How does Solr Learning-To-Rank work under the hood?:: Please refer to the `ltr` {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/package-summary.html[javadocs] for an implementation overview. How could I write additional models and/or features?:: -Contributions for further models, features and normalizers are welcome. Related links: +Contributions for further models, features, normalizers and interleaving algorithms are welcome. Related links: + * {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/model/LTRScoringModel.html[LTRScoringModel javadocs] * {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/feature/Feature.html[Feature javadocs] * {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/norm/Normalizer.html[Normalizer javadocs] +* {solr-javadocs}/contrib/ltr/org/apache/solr/ltr/interleaving/Interleaving.html[Interleaving javadocs] * https://cwiki.apache.org/confluence/display/solr/HowToContribute * https://cwiki.apache.org/confluence/display/LUCENE/HowToContribute @@ -779,3 +869,7 @@ The feature store and the model store are both <