prefer totalTermFrequency over docFreq in PhraseSuggester

This commit is contained in:
Simon Willnauer 2013-03-05 10:46:25 +01:00
parent 315744be55
commit 876b5a3dcd
9 changed files with 80 additions and 70 deletions

View File

@ -29,7 +29,7 @@ public abstract class CandidateGenerator {
public abstract boolean isKnownWord(BytesRef term) throws IOException;
public abstract int frequency(BytesRef term) throws IOException;
public abstract long frequency(BytesRef term) throws IOException;
public CandidateSet drawCandidates(BytesRef term, int numCandidates) throws IOException {
CandidateSet set = new CandidateSet(Candidate.EMPTY, createCandidate(term));
@ -39,7 +39,7 @@ public abstract class CandidateGenerator {
public Candidate createCandidate(BytesRef term) throws IOException {
return createCandidate(term, frequency(term), 1.0);
}
public abstract Candidate createCandidate(BytesRef term, int frequency, double channelScore) throws IOException;
public abstract Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException;
public abstract CandidateSet drawCandidates(CandidateSet set, int numCandidates) throws IOException;

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.spell.DirectSpellChecker;
import org.apache.lucene.search.spell.SuggestMode;
import org.apache.lucene.search.spell.SuggestWord;
@ -43,13 +44,17 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
private final DirectSpellChecker spellchecker;
private final String field;
private final SuggestMode suggestMode;
private final TermsEnum termsEnum;
private final IndexReader reader;
private final int docCount;
private final long dictSize;
private final double logBase = 5;
private final int frequencyPlateau;
private final long frequencyPlateau;
private final Analyzer preFilter;
private final Analyzer postFilter;
private final double nonErrorLikelihood;
private final boolean useTotalTermFrequency;
private final CharsRef spare = new CharsRef();
private final BytesRef byteSpare = new BytesRef();
public DirectCandidateGenerator(DirectSpellChecker spellchecker, String field, SuggestMode suggestMode, IndexReader reader, double nonErrorLikelihood) throws IOException {
this(spellchecker, field, suggestMode, reader, nonErrorLikelihood, null, null);
@ -65,13 +70,15 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
if (terms == null) {
throw new ElasticSearchIllegalArgumentException("generator field [" + field + "] doesn't exist");
}
final int docCount = terms.getDocCount();
this.docCount = docCount == -1 ? reader.maxDoc() : docCount;
final long dictSize = terms.getSumTotalTermFreq();
this.useTotalTermFrequency = dictSize != -1;
this.dictSize = dictSize == -1 ? reader.maxDoc() : dictSize;
this.preFilter = preFilter;
this.postFilter = postFilter;
this.nonErrorLikelihood = nonErrorLikelihood;
float thresholdFrequency = spellchecker.getThresholdFrequency();
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency: (int)(docCount * thresholdFrequency);
this.frequencyPlateau = thresholdFrequency >= 1.0f ? (int) thresholdFrequency: (int)(dictSize * thresholdFrequency);
termsEnum = terms.iterator(null);
}
/* (non-Javadoc)
@ -86,8 +93,17 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
* @see org.elasticsearch.search.suggest.phrase.CandidateGenerator#frequency(org.apache.lucene.util.BytesRef)
*/
@Override
public int frequency(BytesRef term) throws IOException {
return reader.docFreq(new Term(field, term));
public long frequency(BytesRef term) throws IOException {
term = preFilter(term, spare, byteSpare);
return internalFrequency(term);
}
public long internalFrequency(BytesRef term) throws IOException {
if (termsEnum.seekExact(term, true)) {
return useTotalTermFrequency ? termsEnum.totalTermFreq() : termsEnum.docFreq();
}
return 0;
}
public String getField() {
@ -99,18 +115,16 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
*/
@Override
public CandidateSet drawCandidates(CandidateSet set, int numCandidates) throws IOException {
CharsRef spare = new CharsRef();
BytesRef byteSpare = new BytesRef();
Candidate original = set.originalTerm;
BytesRef term = preFilter(original.term, spare, byteSpare);
final int frequency = original.frequency;
spellchecker.setThresholdFrequency(thresholdFrequency(frequency, docCount));
final long frequency = original.frequency;
spellchecker.setThresholdFrequency(thresholdFrequency(frequency, dictSize));
SuggestWord[] suggestSimilar = spellchecker.suggestSimilar(new Term(field, term), numCandidates, reader, this.suggestMode);
List<Candidate> candidates = new ArrayList<Candidate>(suggestSimilar.length);
for (int i = 0; i < suggestSimilar.length; i++) {
SuggestWord suggestWord = suggestSimilar[i];
BytesRef candidate = new BytesRef(suggestWord.string);
postFilter(new Candidate(candidate, suggestWord.freq, suggestWord.score, score(suggestWord.freq, suggestWord.score, docCount)), spare, byteSpare, candidates);
postFilter(new Candidate(candidate, internalFrequency(candidate), suggestWord.score, score(suggestWord.freq, suggestWord.score, dictSize)), spare, byteSpare, candidates);
}
set.addCandidates(candidates);
return set;
@ -140,24 +154,26 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
@Override
public void nextToken() throws IOException {
this.fillBytesRef(result);
if (posIncAttr.getPositionIncrement() > 0 && result.bytesEquals(candidate.term)) {
candidates.add(new Candidate(BytesRef.deepCopyOf(result), candidate.frequency, candidate.stringDistance, score(candidate.frequency, candidate.stringDistance, docCount)));
BytesRef term = BytesRef.deepCopyOf(result);
long freq = frequency(term);
candidates.add(new Candidate(BytesRef.deepCopyOf(term), freq, candidate.stringDistance, score(candidate.frequency, candidate.stringDistance, dictSize)));
} else {
int freq = frequency(result);
candidates.add(new Candidate(BytesRef.deepCopyOf(result), freq, nonErrorLikelihood, score(candidate.frequency, candidate.stringDistance, docCount)));
candidates.add(new Candidate(BytesRef.deepCopyOf(result), candidate.frequency, nonErrorLikelihood, score(candidate.frequency, candidate.stringDistance, dictSize)));
}
}
}, spare);
}
}
private double score(int frequency, double errorScore, int docCount) {
return errorScore * (((double)frequency + 1) / ((double)docCount +1));
private double score(long frequency, double errorScore, long dictionarySize) {
return errorScore * (((double)frequency + 1) / ((double)dictionarySize +1));
}
protected int thresholdFrequency(int termFrequency, int docCount) {
protected long thresholdFrequency(long termFrequency, long dictionarySize) {
if (termFrequency > 0) {
return (int) Math.round(termFrequency * (Math.log10(termFrequency - frequencyPlateau) * (1.0 / Math.log10(logBase))) + 1);
return (long) Math.round(termFrequency * (Math.log10(termFrequency - frequencyPlateau) * (1.0 / Math.log10(logBase))) + 1);
}
return 0;
@ -193,10 +209,10 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
public static final Candidate[] EMPTY = new Candidate[0];
public final BytesRef term;
public final double stringDistance;
public final int frequency;
public final long frequency;
public final double score;
public Candidate(BytesRef term, int frequency, double stringDistance, double score) {
public Candidate(BytesRef term, long frequency, double stringDistance, double score) {
this.frequency = frequency;
this.term = term;
this.stringDistance = stringDistance;
@ -235,8 +251,8 @@ public final class DirectCandidateGenerator extends CandidateGenerator {
}
@Override
public Candidate createCandidate(BytesRef term, int frequency, double channelScore) throws IOException {
return new Candidate(term, frequency, channelScore, score(frequency, channelScore, docCount));
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return new Candidate(term, frequency, channelScore, score(frequency, channelScore, dictSize));
}
}

View File

@ -42,23 +42,18 @@ public final class LaplaceScorer extends WordScorer {
this.alpha = alpha;
}
public double score(Candidate word, Candidate previousWord) throws IOException{
SuggestUtils.join(separator, spare, previousWord.term, word.term);
return (alpha + frequency(spare)) / (alpha + previousWord.frequency);
}
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);
return (alpha + frequency(spare)) / (alpha + w_1.frequency);
return (alpha + frequency(spare)) / (alpha + w_1.frequency + vocabluarySize);
}
@Override
protected double scoreTrigram(Candidate word, Candidate w_1, Candidate w_2) throws IOException {
SuggestUtils.join(separator, spare, w_2.term, w_1.term, word.term);
int trigramCount = frequency(spare);
long trigramCount = frequency(spare);
SuggestUtils.join(separator, spare, w_1.term, word.term);
return (alpha + trigramCount) / (alpha + frequency(spare));
return (alpha + trigramCount) / (alpha + frequency(spare) + vocabluarySize);
}

View File

@ -44,7 +44,7 @@ public final class LinearInterpoatingScorer extends WordScorer {
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);
final int count = frequency(spare);
final long count = frequency(spare);
if (count < 1) {
return unigramLambda * scoreUnigram(word);
}
@ -54,7 +54,7 @@ public final class LinearInterpoatingScorer extends WordScorer {
@Override
protected double scoreTrigram(Candidate w, Candidate w_1, Candidate w_2) throws IOException {
SuggestUtils.join(separator, spare, w.term, w_1.term, w_2.term);
final int count = frequency(spare);
final long count = frequency(spare);
if (count < 1) {
return scoreBigram(w, w_1);
}

View File

@ -40,7 +40,7 @@ public final class MultiCandidateGeneratorWrapper extends CandidateGenerator {
}
@Override
public int frequency(BytesRef term) throws IOException {
public long frequency(BytesRef term) throws IOException {
return candidateGenerator[0].frequency(term);
}
@ -70,7 +70,7 @@ public final class MultiCandidateGeneratorWrapper extends CandidateGenerator {
return set;
}
@Override
public Candidate createCandidate(BytesRef term, int frequency, double channelScore) throws IOException {
public Candidate createCandidate(BytesRef term, long frequency, double channelScore) throws IOException {
return candidateGenerator[0].createCandidate(term, frequency, channelScore);
}

View File

@ -81,7 +81,7 @@ public final class NoisyChannelSpellChecker {
anyUnigram = true;
if (posIncAttr.getPositionIncrement() == 0 && typeAttribute.type() == SynonymFilter.TYPE_SYNONYM) {
assert currentSet != null;
int freq = 0;
long freq = 0;
if ((freq = generator.frequency(term)) > 0) {
currentSet.addOneCandidate(generator.createCandidate(BytesRef.deepCopyOf(term), freq, realWordLikelihood));
}

View File

@ -44,7 +44,7 @@ public class StupidBackoffScorer extends WordScorer {
@Override
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {
SuggestUtils.join(separator, spare, w_1.term, word.term);
final int count = frequency(spare);
final long count = frequency(spare);
if (count < 1) {
return discount * scoreUnigram(word);
}
@ -54,17 +54,17 @@ public class StupidBackoffScorer extends WordScorer {
@Override
protected double scoreTrigram(Candidate w, Candidate w_1, Candidate w_2) throws IOException {
SuggestUtils.join(separator, spare, w_2.term, w_1.term, w.term);
final int trigramCount = frequency(spare);
final long trigramCount = frequency(spare);
if (trigramCount < 1) {
SuggestUtils.join(separator, spare, w_1.term, w.term);
final int count = frequency(spare);
final long count = frequency(spare);
if (count < 1) {
return discount * scoreUnigram(w);
}
return discount * (count / (w_1.frequency + 0.00000000001d));
}
SuggestUtils.join(separator, spare, w_1.term, w.term);
final int bigramCount = frequency(spare);
final long bigramCount = frequency(spare);
return trigramCount / (bigramCount + 0.00000000001d);
}

View File

@ -25,7 +25,6 @@ import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ElasticSearchException;
import org.elasticsearch.ElasticSearchIllegalArgumentException;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.Candidate;
import org.elasticsearch.search.suggest.phrase.DirectCandidateGenerator.CandidateSet;
@ -35,11 +34,13 @@ public abstract class WordScorer {
protected final IndexReader reader;
protected final String field;
protected final Terms terms;
protected final int totalDocuments;
protected final long vocabluarySize;
protected double realWordLikelyhood;
protected final BytesRef spare = new BytesRef();
protected final BytesRef separator;
protected final TermsEnum termsEnum;
private final long numTerms;
private final boolean useTotalTermFreq;
public WordScorer(IndexReader reader, String field, double realWordLikelyHood, BytesRef separator) throws IOException {
this.field = field;
@ -47,17 +48,19 @@ public abstract class WordScorer {
if (terms == null) {
throw new ElasticSearchIllegalArgumentException("Field: [" + field + "] does not exist");
}
final int docCount = terms.getDocCount();
this.totalDocuments = docCount == -1 ? reader.maxDoc() : docCount;
final long vocSize = terms.getSumTotalTermFreq();
this.vocabluarySize = vocSize == -1 ? reader.maxDoc() : vocSize;
this.useTotalTermFreq = vocSize != -1;
this.numTerms = terms.size();
this.termsEnum = terms.iterator(null);
this.reader = reader;
this.realWordLikelyhood = realWordLikelyHood;
this.separator = separator;
}
public int frequency(BytesRef term) throws IOException {
public long frequency(BytesRef term) throws IOException {
if (termsEnum.seekExact(term, true)) {
return termsEnum.docFreq();
return useTotalTermFreq ? termsEnum.totalTermFreq() : termsEnum.docFreq();
}
return 0;
}
@ -80,7 +83,7 @@ public abstract class WordScorer {
}
protected double scoreUnigram(Candidate word) throws IOException {
return (1.0 + word.frequency) / (1.0 + totalDocuments);
return (1.0 + frequency(word.term)) / (vocabluarySize + numTerms);
}
protected double scoreBigram(Candidate word, Candidate w_1) throws IOException {

View File

@ -123,15 +123,15 @@ public class NoisyChannelSpellCheckerTests {
assertThat(corrections.length, equalTo(4));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
corrections = suggester.getCorrections(wrapper, new BytesRef("Xor the Got-Jewel"), generator, 5, 0.5f, 4, ir, "body", wordScorer, 1, 2);
assertThat(corrections.length, equalTo(4));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
// test synonyms
@ -219,11 +219,11 @@ public class NoisyChannelSpellCheckerTests {
NoisyChannelSpellChecker suggester = new NoisyChannelSpellChecker();
DirectSpellChecker spellchecker = new DirectSpellChecker();
spellchecker.setMinQueryLength(1);
DirectCandidateGenerator forward = new DirectCandidateGenerator(spellchecker, "body", SuggestMode.SUGGEST_MORE_POPULAR, ir, 0.95);
DirectCandidateGenerator reverse = new DirectCandidateGenerator(spellchecker, "body_reverse", SuggestMode.SUGGEST_MORE_POPULAR, ir, 0.95, wrapper, wrapper);
DirectCandidateGenerator forward = new DirectCandidateGenerator(spellchecker, "body", SuggestMode.SUGGEST_ALWAYS, ir, 0.95);
DirectCandidateGenerator reverse = new DirectCandidateGenerator(spellchecker, "body_reverse", SuggestMode.SUGGEST_ALWAYS, ir, 0.95, wrapper, wrapper);
CandidateGenerator generator = new MultiCandidateGeneratorWrapper(forward, reverse);
Correction[] corrections = suggester.getCorrections(wrapper, new BytesRef("american cae"), generator, 5, 1, 1, ir, "body", wordScorer, 1, 2);
Correction[] corrections = suggester.getCorrections(wrapper, new BytesRef("american cae"), generator, 10, 1, 1, ir, "body", wordScorer, 1, 2);
assertThat(corrections.length, equalTo(1));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("american ace"));
@ -241,9 +241,9 @@ public class NoisyChannelSpellCheckerTests {
corrections = suggester.getCorrections(wrapper, new BytesRef("Zorr the Got-Jewel"), generator, 5, 0.5f, 4, ir, "body", wordScorer, 0, 2);
assertThat(corrections.length, equalTo(4));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("zorr the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("gorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("zorr the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("gorr the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("tarr the god jewel"));
@ -316,9 +316,9 @@ public class NoisyChannelSpellCheckerTests {
corrections = suggester.getCorrections(wrapper, new BytesRef("Xor the Got-Jewel"), generator, 5, 0.5f, 4, ir, "body", wordScorer, 0, 3);
assertThat(corrections.length, equalTo(4));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the gog jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
@ -326,9 +326,9 @@ public class NoisyChannelSpellCheckerTests {
corrections = suggester.getCorrections(wrapper, new BytesRef("Xor the Got-Jewel"), generator, 5, 0.5f, 4, ir, "body", wordScorer, 1, 3);
assertThat(corrections.length, equalTo(4));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the gog jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[3].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the got jewel"));
corrections = suggester.getCorrections(wrapper, new BytesRef("Xor the Got-Jewel"), generator, 5, 0.5f, 1, ir, "body", wordScorer, 100, 3);
@ -362,20 +362,16 @@ public class NoisyChannelSpellCheckerTests {
wordScorer = new LinearInterpoatingScorer(ir, "body_ngram", 0.95d, new BytesRef(" "), 0.5, 0.4, 0.1);
corrections = suggester.getCorrections(analyzer, new BytesRef("captian usa"), generator, 10, 2, 4, ir, "body", wordScorer, 1, 3);
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("captain america"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("captain american"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("captain ursa"));
generator = new DirectCandidateGenerator(spellchecker, "body", SuggestMode.SUGGEST_MORE_POPULAR, ir, 0.95, null, analyzer);
corrections = suggester.getCorrections(analyzer, new BytesRef("captian usw"), generator, 10, 2, 4, ir, "body", wordScorer, 1, 3);
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("captain america"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("captain american"));
assertThat(corrections[2].join(new BytesRef(" ")).utf8ToString(), equalTo("captain usw"));
wordScorer = new StupidBackoffScorer(ir, "body_ngram", 0.85d, new BytesRef(" "), 0.4);
corrections = suggester.getCorrections(wrapper, new BytesRef("Xor the Got-Jewel"), generator, 5, 0.5f, 2, ir, "body", wordScorer, 0, 3);
assertThat(corrections.length, equalTo(2));
assertThat(corrections[0].join(new BytesRef(" ")).utf8ToString(), equalTo("xorr the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xorn the god jewel"));
assertThat(corrections[1].join(new BytesRef(" ")).utf8ToString(), equalTo("xor the god jewel"));
}
}