diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java index d6512246896..4eb9bea280f 100644 --- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java +++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/feature/TestOriginalScoreFeature.java @@ -53,51 +53,10 @@ public class TestOriginalScoreFeature extends TestRerankBase { @Test public void testOriginalScore() throws Exception { loadFeature("score", OriginalScoreFeature.class.getCanonicalName(), "{}"); - loadModel("originalScore", LinearModel.class.getCanonicalName(), new String[] {"score"}, "{\"weights\":{\"score\":1.0}}"); - final SolrQuery query = new SolrQuery(); - query.setQuery("title:w1"); - query.add("fl", "*, score"); - query.add("rows", "4"); - query.add("wt", "json"); - - // Normal term match - assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); - - final String res = restTestHarness.query("/query" + query.toQueryString()); - final Map jsonParse = (Map) ObjectBuilder - .fromJSON(res); - final String doc0Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse - .get("response")).get("docs")).get(0)).get("score")).toString(); - final String doc1Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse - .get("response")).get("docs")).get(1)).get("score")).toString(); - final String doc2Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse - .get("response")).get("docs")).get(2)).get("score")).toString(); - final String doc3Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse - .get("response")).get("docs")).get(3)).get("score")).toString(); - - query.add("fl", "[fv]"); - query.add("rq", "{!ltr model=originalScore reRankDocs=4}"); - - assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==" - + doc0Score); - assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==" - + doc1Score); - assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==" - + doc2Score); - assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==" - + doc3Score); + implTestOriginalScoreResponseDocsCheck("originalScore", "score", null, null); } @Test @@ -111,12 +70,29 @@ public class TestOriginalScoreFeature extends TestRerankBase { new String[] {"origScore"}, "store2", "{\"weights\":{\"origScore\":1.0}}"); + implTestOriginalScoreResponseDocsCheck("origScore", "origScore", "c2", "2.0"); + } + + public static void implTestOriginalScoreResponseDocsCheck(String modelName, + String origScoreFeatureName, + String nonScoringFeatureName, String nonScoringFeatureValue) throws Exception { + final SolrQuery query = new SolrQuery(); query.setQuery("title:w1"); - query.add("fl", "*, score, fv:[fv]"); + query.add("fl", "*, score"); query.add("rows", "4"); query.add("wt", "json"); - query.add("rq", "{!ltr model=origScore reRankDocs=4}"); + + final int doc0Id = 1; + final int doc1Id = 8; + final int doc2Id = 6; + final int doc3Id = 7; + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='"+doc0Id+"'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='"+doc1Id+"'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='"+doc2Id+"'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='"+doc3Id+"'"); final String res = restTestHarness.query("/query" + query.toQueryString()); final Map jsonParse = (Map) ObjectBuilder @@ -130,20 +106,43 @@ public class TestOriginalScoreFeature extends TestRerankBase { final String doc3Score = ((Double) ((Map) ((ArrayList) ((Map) jsonParse .get("response")).get("docs")).get(3)).get("score")).toString(); - assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'"); - assertJQ("/query" + query.toQueryString(), - "/response/docs/[0]/fv=='" + FeatureLoggerTestUtils.toFeatureVector("origScore", doc0Score, "c2", "2.0")+"'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='8'"); + final boolean debugQuery = false; - assertJQ("/query" + query.toQueryString(), - "/response/docs/[1]/fv=='" + FeatureLoggerTestUtils.toFeatureVector("origScore", doc1Score, "c2", "2.0")+"'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='6'"); - assertJQ("/query" + query.toQueryString(), - "/response/docs/[2]/fv=='" + FeatureLoggerTestUtils.toFeatureVector("origScore", doc2Score, "c2", "2.0")+"'"); - assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'"); - assertJQ("/query" + query.toQueryString(), - "/response/docs/[3]/fv=='" + FeatureLoggerTestUtils.toFeatureVector("origScore", doc3Score, "c2", "2.0")+"'"); + query.remove("fl"); + query.add("fl", "*, score, fv:[fv]"); + query.add("rq", "{!ltr model="+modelName+" reRankDocs=4}"); + + assertJQ("/query" + query.toQueryString(), "/response/numFound/==4"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='"+doc0Id+"'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='"+doc1Id+"'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='"+doc2Id+"'"); + assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='"+doc3Id+"'"); + + implTestOriginalScoreResponseDocsCheck(modelName, query, 0, doc0Id, origScoreFeatureName, doc0Score, + nonScoringFeatureName, nonScoringFeatureValue, debugQuery); + implTestOriginalScoreResponseDocsCheck(modelName, query, 1, doc1Id, origScoreFeatureName, doc1Score, + nonScoringFeatureName, nonScoringFeatureValue, debugQuery); + implTestOriginalScoreResponseDocsCheck(modelName, query, 2, doc2Id, origScoreFeatureName, doc2Score, + nonScoringFeatureName, nonScoringFeatureValue, debugQuery); + implTestOriginalScoreResponseDocsCheck(modelName, query, 3, doc3Id, origScoreFeatureName, doc3Score, + nonScoringFeatureName, nonScoringFeatureValue, debugQuery); + } + + private static void implTestOriginalScoreResponseDocsCheck(String modelName, + SolrQuery query, int docIdx, int docId, + String origScoreFeatureName, String origScoreFeatureValue, + String nonScoringFeatureName, String nonScoringFeatureValue, + boolean debugQuery) throws Exception { + + final String fv; + if (nonScoringFeatureName == null) { + fv = FeatureLoggerTestUtils.toFeatureVector(origScoreFeatureName, origScoreFeatureValue); + } else { + fv = FeatureLoggerTestUtils.toFeatureVector(origScoreFeatureName, origScoreFeatureValue, nonScoringFeatureName, nonScoringFeatureValue); + } + + assertJQ("/query" + query.toQueryString(), "/response/docs/["+docIdx+"]/fv=='"+fv+"'"); + // TODO: use debugQuery } }