LUCENE-6821 - TermQuery's constructors should clone the incoming term

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709576 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Tommaso Teofili 2015-10-20 11:57:37 +00:00
parent d2bee4788a
commit 9534904c25
10 changed files with 64 additions and 60 deletions

View File

@ -152,16 +152,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
// We are passing the term to IndexSearcher so we need to make sure it will not change over time
next = BytesRef.deepCopyOf(next);
double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(next, clVal));
Term term = new Term(this.classFieldName, next);
double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
}
}
// normalization: the values are transformed to a 0-1 range
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
return assignedClassesNorm;
return normClassificationResults(assignedClasses);
}
/**
@ -208,18 +206,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return result.toArray(new String[result.size()]);
}
private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
private double calculateLogLikelihood(String[] tokenizedText, Term term, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedText) {
// search with text:word AND class:c
int hits = getWordFreqForClass(word,c);
int hits = getWordFreqForClass(word, term);
// num : count the number of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the number of times a word appears in documents of class c (+|V|)
double den = getTextTermFreqForClass(c) + docsWithClass;
double den = getTextTermFreqForClass(term) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
@ -232,18 +230,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
/**
* Returns the average number of unique terms times the number of docs belonging to the input class
* @param c the class
* @param term the term representing the class
* @return the average number of unique terms
* @throws IOException if a low level I/O problem happens
*/
private double getTextTermFreqForClass(BytesRef c) throws IOException {
private double getTextTermFreqForClass(Term term) throws IOException {
double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) {
Terms terms = MultiFields.getTerms(leafReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
}
int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
int docsWithC = leafReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
@ -251,18 +249,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* Returns the number of documents of the input class (from the whole index or from a subset)
* that contain the word (in a specific field, or in all the fields if none is selected)
* @param word the token produced by the analyzer
* @param c the class
* @param term the term representing the class
* @return the number of documents of the input class
* @throws IOException if a low level I/O problem happens
*/
private int getWordFreqForClass(String word, BytesRef c) throws IOException {
private int getWordFreqForClass(String word, Term term) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
for (String textFieldName : textFieldNames) {
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
}
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);
}
@ -271,12 +269,12 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return totalHitCountCollector.getTotalHits();
}
private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
}
private int docCount(BytesRef countedClass) throws IOException {
return leafReader.docFreq(new Term(classFieldName, countedClass));
private int docCount(Term term) throws IOException {
return leafReader.docFreq(term);
}
/**

View File

@ -122,18 +122,18 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
int docsWithClassSize = countDocsWithClass();
while ((c = classesEnum.next()) != null) {
double classScore = 0;
Term term = new Term(this.classFieldName, c);
for (String fieldName : textFieldNames) {
List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
double fieldScore = 0;
for (String[] fieldTokensArray : tokensArrays) {
fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName);
fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName);
}
classScore += fieldScore;
}
assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
}
ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
return assignedClassesNorm;
return normClassificationResults(assignedClasses);
}
/**
@ -211,23 +211,23 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
/**
* @param tokenizedText the tokenized content of a field
* @param fieldName the input field name
* @param c the class to calculate the score of
* @param term the {@link Term} referring to the class to calculate the score of
* @param docsWithClass the total number of docs that have a class
* @return a normalized score for the class
* @throws IOException If there is a low-level I/O error
*/
private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedText) {
// search with text:word AND class:c
int hits = getWordFreqForClass(word, fieldName, c);
int hits = getWordFreqForClass(word, fieldName, term);
// num : count the number of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the number of times a word appears in documents of class c (+|V|)
double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
double den = getTextTermFreqForClass(term, fieldName) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
@ -242,16 +242,16 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
/**
* Returns the average number of unique terms times the number of docs belonging to the input class
*
* @param c the class
* @param term the class term
* @return the average number of unique terms
* @throws java.io.IOException If there is a low-level I/O error
*/
private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
double avgNumberOfUniqueTerms;
Terms terms = MultiFields.getTerms(leafReader, fieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
int docsWithC = leafReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
@ -261,16 +261,16 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
*
* @param word the token produced by the analyzer
* @param fieldName the field the word is coming from
* @param c the class
* @param term the class term
* @return number of documents of the input class
* @throws java.io.IOException If there is a low-level I/O error
*/
private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);
}
@ -279,11 +279,11 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
return totalHitCountCollector.getTotalHits();
}
private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
}
private int docCount(BytesRef countedClass) throws IOException {
return leafReader.docFreq(new Term(classFieldName, countedClass));
private int docCount(Term term) throws IOException {
return leafReader.docFreq(term);
}
}

View File

@ -24,6 +24,7 @@ import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
/**
A Term represents a word from text. This is the unit of search. It is
@ -41,14 +42,20 @@ public final class Term implements Comparable<Term> {
* <p>Note that a null field or null bytes value results in undefined
* behavior for most Lucene APIs that accept a Term parameter.
*
* <p>WARNING: the provided BytesRef is not copied, but used directly.
* Therefore the bytes should not be modified after construction, for
* example, you should clone a copy by {@link BytesRef#deepCopyOf}
* rather than pass reused bytes from a TermsEnum.
* <p>The provided BytesRef is copied when it is non null.
*/
public Term(String fld, BytesRef bytes) {
field = fld;
this.bytes = bytes;
this.bytes = bytes == null ? null : BytesRef.deepCopyOf(bytes);
}
/** Constructs a Term with the given field and the bytes from a builder.
* <p>Note that a null field value results in undefined
* behavior for most Lucene APIs that accept a Term parameter.
*/
public Term(String fld, BytesRefBuilder bytesBuilder) {
field = fld;
this.bytes = bytesBuilder.toBytesRef();
}
/** Constructs a Term with the given field and text.
@ -93,7 +100,7 @@ public final class Term implements Comparable<Term> {
}
}
/** Returns the bytes of this term. */
/** Returns the bytes of this term; these should not be modified. */
public final BytesRef bytes() { return bytes; }
@Override

View File

@ -92,7 +92,7 @@ public final class BlendedTermQuery extends Query {
terms = ArrayUtil.grow(terms, numTerms + 1);
boosts = ArrayUtil.grow(boosts, numTerms + 1);
contexts = ArrayUtil.grow(contexts, numTerms + 1);
terms[numTerms] = new Term(term.field(), BytesRef.deepCopyOf(term.bytes()));
terms[numTerms] = term;
boosts[numTerms] = boost;
contexts[numTerms] = context;
numTerms += 1;

View File

@ -101,7 +101,6 @@ public class PhraseQuery extends Query {
*
*/
public Builder add(Term term, int position) {
term = new Term(term.field(), BytesRef.deepCopyOf(term.bytes())); // be defensive
if (position < 0) {
throw new IllegalArgumentException("Positions must be >= 0, got " + position);
}
@ -186,7 +185,7 @@ public class PhraseQuery extends Query {
private static Term[] toTerms(String field, BytesRef... termBytes) {
Term[] terms = new Term[termBytes.length];
for (int i = 0; i < terms.length; ++i) {
terms[i] = new Term(field, BytesRef.deepCopyOf(termBytes[i]));
terms[i] = new Term(field, termBytes[i]);
}
return terms;
}

View File

@ -274,7 +274,7 @@ public class QueryBuilder {
throw new AssertionError();
}
return newTermQuery(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
return newTermQuery(new Term(field, termAtt.getBytesRef()));
}
/**
@ -286,7 +286,7 @@ public class QueryBuilder {
stream.reset();
List<Term> terms = new ArrayList<>();
while (stream.incrementToken()) {
terms.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
terms.add(new Term(field, termAtt.getBytesRef()));
}
return newSynonymQuery(terms.toArray(new Term[terms.size()]));
@ -319,7 +319,7 @@ public class QueryBuilder {
add(q, currentQuery, operator);
currentQuery.clear();
}
currentQuery.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
currentQuery.add(new Term(field, termAtt.getBytesRef()));
}
add(q, currentQuery, operator);
@ -376,7 +376,7 @@ public class QueryBuilder {
multiTerms.clear();
}
position += positionIncrement;
multiTerms.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
multiTerms.add(new Term(field, termAtt.getBytesRef()));
}
if (enablePositionIncrements) {

View File

@ -745,7 +745,7 @@ public abstract class FieldType extends FieldProperties {
// match-only
return getRangeQuery(parser, field, externalVal, externalVal, true, true);
} else {
return new TermQuery(new Term(field.getName(), br.toBytesRef()));
return new TermQuery(new Term(field.getName(), br));
}
}

View File

@ -1194,7 +1194,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable,SolrIn
TermQuery key = null;
if (useCache) {
key = new TermQuery(new Term(deState.fieldName, BytesRef.deepCopyOf(deState.termsEnum.term())));
key = new TermQuery(new Term(deState.fieldName, deState.termsEnum.term()));
DocSet result = filterCache.get(key);
if (result != null) return result;
}

View File

@ -653,7 +653,7 @@ abstract class FacetFieldProcessorFCBase extends FacetFieldProcessor {
bucket.add("val", val);
TermQuery filter = needFilter ? new TermQuery(new Term(sf.getName(), BytesRef.deepCopyOf(br))) : null;
TermQuery filter = needFilter ? new TermQuery(new Term(sf.getName(), br)) : null;
fillBucket(bucket, countAcc.getCount(slotNum), slotNum, null, filter);
bucketList.add(bucket);

View File

@ -133,7 +133,7 @@ public class SimpleMLTQParser extends QParser {
BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
bytesRefBuilder.grow(NumericUtils.BUF_SIZE_INT);
NumericUtils.intToPrefixCoded(Integer.parseInt(uniqueValue), 0, bytesRefBuilder);
return new Term(field, bytesRefBuilder.toBytesRef());
return new Term(field, bytesRefBuilder);
}