SOLR-6297: Fix for Distributed WordBreakSolrSpellChecker

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1622476 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
James Dyer 2014-09-04 13:57:05 +00:00
parent 96169ba6b2
commit 305d6829a7
8 changed files with 101 additions and 33 deletions

View File

@ -164,6 +164,9 @@ Bug Fixes
* SOLR-6024: Fix StatsComponent when using docValues="true" multiValued="true" * SOLR-6024: Fix StatsComponent when using docValues="true" multiValued="true"
(Vitaliy Zhovtyuk & Tomas Fernandez-Lobbe via hossman) (Vitaliy Zhovtyuk & Tomas Fernandez-Lobbe via hossman)
* SOLR-6297: Fix WordBreakSolrSpellChecker to not lose suggestions in shard/cloud
environments (James Dyer)
Other Changes Other Changes
--------------------- ---------------------

View File

@ -192,17 +192,18 @@ public class SpellCheckComponent extends SearchComponent implements SolrCoreAwar
boolean isCorrectlySpelled = hits > (maxResultsForSuggest==null ? 0 : maxResultsForSuggest); boolean isCorrectlySpelled = hits > (maxResultsForSuggest==null ? 0 : maxResultsForSuggest);
NamedList response = new SimpleOrderedMap(); NamedList response = new SimpleOrderedMap();
NamedList suggestions = toNamedList(shardRequest, spellingResult, q, extendedResults); NamedList suggestions = toNamedList(shardRequest, spellingResult, q, extendedResults);
response.add("suggestions", suggestions); response.add("suggestions", suggestions);
if (extendedResults) { if (extendedResults) {
response.add("correctlySpelled", isCorrectlySpelled); response.add("correctlySpelled", isCorrectlySpelled);
} }
if (collate) { if (collate) {
addCollationsToResponse(params, spellingResult, rb, q, response, spellChecker.isSuggestionsMayOverlap()); addCollationsToResponse(params, spellingResult, rb, q, response, spellChecker.isSuggestionsMayOverlap());
} }
if (shardRequest) {
addOriginalTermsToResponse(response, tokens);
}
rb.rsp.add("spellcheck", response); rb.rsp.add("spellcheck", response);
@ -261,6 +262,14 @@ public class SpellCheckComponent extends SearchComponent implements SolrCoreAwar
response.add("collations", collationList); response.add("collations", collationList);
} }
private void addOriginalTermsToResponse(NamedList response, Collection<Token> originalTerms) {
List<String> originalTermStr = new ArrayList<String>();
for(Token t : originalTerms) {
originalTermStr.add(t.toString());
}
response.add("originalTerms", originalTermStr);
}
/** /**
* For every param that is of the form "spellcheck.[dictionary name].XXXX=YYYY, add * For every param that is of the form "spellcheck.[dictionary name].XXXX=YYYY, add
* XXXX=YYYY as a param to the custom param list * XXXX=YYYY as a param to the custom param list
@ -392,8 +401,14 @@ public class SpellCheckComponent extends SearchComponent implements SolrCoreAwar
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void collectShardSuggestions(NamedList nl, SpellCheckMergeData mergeData) { private void collectShardSuggestions(NamedList nl, SpellCheckMergeData mergeData) {
System.out.println(nl);
SpellCheckResponse spellCheckResp = new SpellCheckResponse(nl); SpellCheckResponse spellCheckResp = new SpellCheckResponse(nl);
Iterable<Object> originalTermStrings = (Iterable<Object>) nl.get("originalTerms");
if(originalTermStrings!=null) {
mergeData.originalTerms = new HashSet<>();
for (Object originalTermObj : originalTermStrings) {
mergeData.originalTerms.add(originalTermObj.toString());
}
}
for (SpellCheckResponse.Suggestion suggestion : spellCheckResp.getSuggestions()) { for (SpellCheckResponse.Suggestion suggestion : spellCheckResp.getSuggestions()) {
mergeData.origVsSuggestion.put(suggestion.getToken(), suggestion); mergeData.origVsSuggestion.put(suggestion.getToken(), suggestion);
HashSet<String> suggested = mergeData.origVsSuggested.get(suggestion.getToken()); HashSet<String> suggested = mergeData.origVsSuggested.get(suggestion.getToken());
@ -615,8 +630,8 @@ public class SpellCheckComponent extends SearchComponent implements SolrCoreAwar
} }
if (hasFreqInfo) { if (hasFreqInfo) {
int tokenFrequency = spellingResult.getTokenFrequency(inputToken); Integer tokenFrequency = spellingResult.getTokenFrequency(inputToken);
if (tokenFrequency == 0) { if (tokenFrequency==null || tokenFrequency == 0) {
hasZeroFrequencyToken = true; hasZeroFrequencyToken = true;
} }
} }

View File

@ -17,10 +17,12 @@ package org.apache.solr.handler.component;
* limitations under the License. * limitations under the License.
*/ */
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Set;
import org.apache.lucene.search.spell.SuggestWord; import org.apache.lucene.search.spell.SuggestWord;
import org.apache.solr.client.solrj.response.SpellCheckResponse; import org.apache.solr.client.solrj.response.SpellCheckResponse;
@ -39,5 +41,14 @@ public class SpellCheckMergeData {
// alternative string -> corresponding SuggestWord object // alternative string -> corresponding SuggestWord object
public Map<String, SuggestWord> suggestedVsWord = new HashMap<>(); public Map<String, SuggestWord> suggestedVsWord = new HashMap<>();
public Map<String, SpellCheckCollation> collations = new HashMap<>(); public Map<String, SpellCheckCollation> collations = new HashMap<>();
//The original terms from the user's query.
public Set<String> originalTerms = null;
public int totalNumberShardResponses = 0; public int totalNumberShardResponses = 0;
public boolean isOriginalToQuery(String term) {
if(originalTerms==null) {
return true;
}
return originalTerms.contains(term);
}
} }

View File

@ -19,6 +19,7 @@ package org.apache.solr.spelling;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
@ -167,15 +168,18 @@ public class ConjunctionSolrSpellChecker extends SolrSpellChecker {
Map.Entry<String,Integer> corr = iter.next(); Map.Entry<String,Integer> corr = iter.next();
combinedResult.add(original, corr.getKey(), corr.getValue()); combinedResult.add(original, corr.getKey(), corr.getValue());
Integer tokenFrequency = combinedTokenFrequency.get(original); Integer tokenFrequency = combinedTokenFrequency.get(original);
if(tokenFrequency!=null) { combinedResult.addFrequency(original, tokenFrequency==null ? 0 : tokenFrequency);
combinedResult.addFrequency(original, tokenFrequency);
}
if(++numberAdded==numSug) { if(++numberAdded==numSug) {
break; break;
} }
} }
} }
if(!anyData) { if(!anyData) {
if(numberAdded==0) {
combinedResult.add(original, Collections.<String>emptyList());
Integer tokenFrequency = combinedTokenFrequency.get(original);
combinedResult.addFrequency(original, tokenFrequency==null ? 0 : tokenFrequency);
}
break; break;
} }
} }

View File

@ -100,9 +100,11 @@ public abstract class SolrSpellChecker {
for (Map.Entry<String, HashSet<String>> entry : mergeData.origVsSuggested.entrySet()) { for (Map.Entry<String, HashSet<String>> entry : mergeData.origVsSuggested.entrySet()) {
String original = entry.getKey(); String original = entry.getKey();
//Only use this suggestion if all shards reported it as misspelled. //Only use this suggestion if all shards reported it as misspelled,
//unless it was not a term original to the user's query
//(WordBreakSolrSpellChecker can add new terms to the response, and we want to keep these)
Integer numShards = mergeData.origVsShards.get(original); Integer numShards = mergeData.origVsShards.get(original);
if(numShards<mergeData.totalNumberShardResponses) { if(numShards<mergeData.totalNumberShardResponses && mergeData.isOriginalToQuery(original)) {
continue; continue;
} }

View File

@ -202,6 +202,7 @@ public class WordBreakSolrSpellChecker extends SolrSpellChecker {
List<Term> termArr = new ArrayList<>(options.tokens.size() + 2); List<Term> termArr = new ArrayList<>(options.tokens.size() + 2);
List<ResultEntry> breakSuggestionList = new ArrayList<>(); List<ResultEntry> breakSuggestionList = new ArrayList<>();
List<ResultEntry> noBreakSuggestionList = new ArrayList<>();
boolean lastOneProhibited = false; boolean lastOneProhibited = false;
boolean lastOneRequired = false; boolean lastOneRequired = false;
boolean lastOneprocedesNewBooleanOp = false; boolean lastOneprocedesNewBooleanOp = false;
@ -228,6 +229,9 @@ public class WordBreakSolrSpellChecker extends SolrSpellChecker {
if (breakWords) { if (breakWords) {
SuggestWord[][] breakSuggestions = wbsp.suggestWordBreaks(thisTerm, SuggestWord[][] breakSuggestions = wbsp.suggestWordBreaks(thisTerm,
numSuggestions, ir, options.suggestMode, sortMethod); numSuggestions, ir, options.suggestMode, sortMethod);
if(breakSuggestions.length==0) {
noBreakSuggestionList.add(new ResultEntry(tokenArr[i], null, 0));
}
for (SuggestWord[] breakSuggestion : breakSuggestions) { for (SuggestWord[] breakSuggestion : breakSuggestions) {
sb.delete(0, sb.length()); sb.delete(0, sb.length());
boolean firstOne = true; boolean firstOne = true;
@ -249,6 +253,8 @@ public class WordBreakSolrSpellChecker extends SolrSpellChecker {
} }
} }
} }
breakSuggestionList.addAll(noBreakSuggestionList);
List<ResultEntry> combineSuggestionList = Collections.emptyList(); List<ResultEntry> combineSuggestionList = Collections.emptyList();
CombineSuggestion[] combineSuggestions = wbsp.suggestWordCombinations( CombineSuggestion[] combineSuggestions = wbsp.suggestWordCombinations(
termArr.toArray(new Term[termArr.size()]), numSuggestions, ir, options.suggestMode); termArr.toArray(new Term[termArr.size()]), numSuggestions, ir, options.suggestMode);
@ -282,33 +288,24 @@ public class WordBreakSolrSpellChecker extends SolrSpellChecker {
int combineCount = 0; int combineCount = 0;
while (lastBreak != null || lastCombine != null) { while (lastBreak != null || lastCombine != null) {
if (lastBreak == null) { if (lastBreak == null) {
result.add(lastCombine.token, lastCombine.suggestion, lastCombine.freq); addToResult(result, lastCombine.token, getCombineFrequency(ir, lastCombine.token), lastCombine.suggestion, lastCombine.freq);
result.addFrequency(lastCombine.token, getCombineFrequency(ir, lastCombine.token));
lastCombine = null; lastCombine = null;
} else if (lastCombine == null) { } else if (lastCombine == null) {
result.add(lastBreak.token, lastBreak.suggestion, lastBreak.freq); addToResult(result, lastBreak.token, ir.docFreq(new Term(field, lastBreak.token.toString())), lastBreak.suggestion, lastBreak.freq);
result.addFrequency(lastBreak.token, ir.docFreq(new Term(field, lastBreak.token.toString())));
lastBreak = null; lastBreak = null;
} else if (lastBreak.freq < lastCombine.freq) { } else if (lastBreak.freq < lastCombine.freq) {
result.add(lastCombine.token, lastCombine.suggestion, lastCombine.freq); addToResult(result, lastCombine.token, getCombineFrequency(ir, lastCombine.token), lastCombine.suggestion, lastCombine.freq);
result.addFrequency(lastCombine.token, getCombineFrequency(ir, lastCombine.token));
lastCombine = null; lastCombine = null;
} else if (lastCombine.freq < lastBreak.freq) { } else if (lastCombine.freq < lastBreak.freq) {
result.add(lastBreak.token, lastBreak.suggestion, lastBreak.freq); addToResult(result, lastBreak.token, ir.docFreq(new Term(field, lastBreak.token.toString())), lastBreak.suggestion, lastBreak.freq);
result.addFrequency(lastBreak.token, ir.docFreq(new Term(field, lastBreak.token.toString())));
lastBreak = null; lastBreak = null;
} else if (breakCount >= combineCount) { } else if (breakCount >= combineCount) { //TODO: Should reverse >= to < ??S
result.add(lastCombine.token, lastCombine.suggestion, lastCombine.freq); addToResult(result, lastCombine.token, getCombineFrequency(ir, lastCombine.token), lastCombine.suggestion, lastCombine.freq);
result.addFrequency(lastCombine.token, getCombineFrequency(ir, lastCombine.token));
lastCombine = null; lastCombine = null;
} else { } else {
result.add(lastBreak.token, lastBreak.suggestion, lastBreak.freq); addToResult(result, lastBreak.token, ir.docFreq(new Term(field, lastBreak.token.toString())), lastBreak.suggestion, lastBreak.freq);
result.addFrequency(lastBreak.token, ir.docFreq(new Term(field, lastBreak.token.toString())));
lastBreak = null; lastBreak = null;
} }
if (result.getSuggestions().size() > numSuggestions) {
break;
}
if (lastBreak == null && breakIter.hasNext()) { if (lastBreak == null && breakIter.hasNext()) {
lastBreak = breakIter.next(); lastBreak = breakIter.next();
breakCount++; breakCount++;
@ -320,6 +317,15 @@ public class WordBreakSolrSpellChecker extends SolrSpellChecker {
} }
return result; return result;
} }
private void addToResult(SpellingResult result, Token token, int tokenFrequency, String suggestion, int suggestionFrequency) {
if(suggestion==null) {
result.add(token, Collections.<String>emptyList());
result.addFrequency(token, tokenFrequency);
} else {
result.add(token, suggestion, suggestionFrequency);
result.addFrequency(token, tokenFrequency);
}
}
private int getCombineFrequency(IndexReader ir, Token token) throws IOException { private int getCombineFrequency(IndexReader ir, Token token) throws IOException {
String[] words = spacePattern.split(token.toString()); String[] words = spacePattern.split(token.toString());

View File

@ -187,6 +187,12 @@ public class DistributedSpellCheckComponentTest extends BaseDistributedSearchTes
query(buildRequest("lowerfilt:(+quock +redfox +jum +ped)", query(buildRequest("lowerfilt:(+quock +redfox +jum +ped)",
false, reqHandlerWithWordbreak, random().nextBoolean(), extended, "true", count, "10", false, reqHandlerWithWordbreak, random().nextBoolean(), extended, "true", count, "10",
collate, "true", maxCollationTries, "0", maxCollations, "1", collateExtended, "true")); collate, "true", maxCollationTries, "0", maxCollations, "1", collateExtended, "true"));
query(buildRequest("lowerfilt:(+rodfix)",
false, reqHandlerWithWordbreak, random().nextBoolean(), extended, "true", count, "10",
collate, "true", maxCollationTries, "0", maxCollations, "1", collateExtended, "true"));
query(buildRequest("lowerfilt:(+son +ata)",
false, reqHandlerWithWordbreak, random().nextBoolean(), extended, "true", count, "10",
collate, "true", maxCollationTries, "0", maxCollations, "1", collateExtended, "true"));
} }
private Object[] buildRequest(String q, boolean useSpellcheckQ, String handlerName, boolean useGrouping, String... addlParams) { private Object[] buildRequest(String q, boolean useSpellcheckQ, String handlerName, boolean useGrouping, String... addlParams) {
List<Object> params = new ArrayList<>(); List<Object> params = new ArrayList<>();

View File

@ -76,7 +76,7 @@ public class WordBreakSolrSpellCheckerTest extends SolrTestCaseJ4 {
searcher.decref(); searcher.decref();
assertTrue(result != null && result.getSuggestions() != null); assertTrue(result != null && result.getSuggestions() != null);
assertTrue(result.getSuggestions().size()==6); assertTrue(result.getSuggestions().size()==9);
for(Map.Entry<Token, LinkedHashMap<String, Integer>> s : result.getSuggestions().entrySet()) { for(Map.Entry<Token, LinkedHashMap<String, Integer>> s : result.getSuggestions().entrySet()) {
Token orig = s.getKey(); Token orig = s.getKey();
@ -119,7 +119,28 @@ public class WordBreakSolrSpellCheckerTest extends SolrTestCaseJ4 {
assertTrue(orig.length()==4); assertTrue(orig.length()==4);
assertTrue(corr.length==1); assertTrue(corr.length==1);
assertTrue(corr[0].equals("pi ne")); assertTrue(corr[0].equals("pi ne"));
} else { } else if(orig.toString().equals("pine")) {
assertTrue(orig.startOffset()==10);
assertTrue(orig.endOffset()==14);
assertTrue(orig.length()==4);
assertTrue(corr.length==1);
assertTrue(corr[0].equals("pi ne"));
} else if(orig.toString().equals("apple")) {
assertTrue(orig.startOffset()==15);
assertTrue(orig.endOffset()==20);
assertTrue(orig.length()==5);
assertTrue(corr.length==0);
} else if(orig.toString().equals("good")) {
assertTrue(orig.startOffset()==21);
assertTrue(orig.endOffset()==25);
assertTrue(orig.length()==4);
assertTrue(corr.length==0);
} else if(orig.toString().equals("ness")) {
assertTrue(orig.startOffset()==26);
assertTrue(orig.endOffset()==30);
assertTrue(orig.length()==4);
assertTrue(corr.length==0);
}else {
fail("Unexpected original result: " + orig); fail("Unexpected original result: " + orig);
} }
} }