LUCENE-6821 - TermQuery's constructors should clone the incoming term

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709576 13f79535-47bb-0310-9956-ffa450edef68
Tommaso Teofili 2015-10-20 11:57:37 +00:00
parent d2bee4788a
commit 9534904c25
10 changed files with 64 additions and 60 deletions
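The substance of the change: Term's BytesRef constructor now deep-copies its argument, so callers no longer need to copy defensively before handing reused bytes (for example, from a TermsEnum or a token stream attribute) to a Term or a TermQuery. A minimal sketch of the hazard being removed; the field name and helper are illustrative, not part of the patch:

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;

// Hypothetical helper: collect one Term per distinct value of a field.
static List<Term> collectTerms(Terms terms, String field) throws IOException {
  List<Term> out = new ArrayList<>();
  TermsEnum termsEnum = terms.iterator();
  BytesRef next;
  while ((next = termsEnum.next()) != null) {
    // termsEnum.next() reuses the same BytesRef across calls, so before this
    // commit each caller had to write new Term(field, BytesRef.deepCopyOf(next));
    // otherwise every Term collected so far silently changed on the next step.
    out.add(new Term(field, next)); // safe now: the constructor copies the bytes
  }
  return out;
}

Most of the hunks below are the mechanical consequence: BytesRef.deepCopyOf calls at Term construction sites are dropped, and the classifiers switch their helper signatures from BytesRef to Term.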


@@ -152,16 +152,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
while ((next = classesEnum.next()) != null) {
if (next.length > 0) {
- // We are passing the term to IndexSearcher so we need to make sure it will not change over time
- next = BytesRef.deepCopyOf(next);
- double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
- assignedClasses.add(new ClassificationResult<>(next, clVal));
+ Term term = new Term(this.classFieldName, next);
+ double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
+ assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
}
}
// normalization; the values transforms to a 0-1 range
- ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
- return assignedClassesNorm;
+ return normClassificationResults(assignedClasses);
}
/**
@@ -208,18 +206,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return result.toArray(new String[result.size()]);
}
- private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
+ private double calculateLogLikelihood(String[] tokenizedText, Term term, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedText) {
// search with text:word AND class:c
- int hits = getWordFreqForClass(word,c);
+ int hits = getWordFreqForClass(word, term);
// num : count the no of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
- double den = getTextTermFreqForClass(c) + docsWithClass;
+ double den = getTextTermFreqForClass(term) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
@@ -232,18 +230,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
/**
* Returns the average number of unique terms times the number of docs belonging to the input class
- * @param c the class
+ * @param term the term representing the class
* @return the average number of unique terms
* @throws IOException if a low level I/O problem happens
*/
- private double getTextTermFreqForClass(BytesRef c) throws IOException {
+ private double getTextTermFreqForClass(Term term) throws IOException {
double avgNumberOfUniqueTerms = 0;
for (String textFieldName : textFieldNames) {
Terms terms = MultiFields.getTerms(leafReader, textFieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
}
- int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
+ int docsWithC = leafReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
@@ -251,18 +249,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
* Returns the number of documents of the input class ( from the whole index or from a subset)
* that contains the word ( in a specific field or in all the fields if no one selected)
* @param word the token produced by the analyzer
- * @param c the class
+ * @param term the term representing the class
* @return the number of documents of the input class
* @throws IOException if a low level I/O problem happens
*/
- private int getWordFreqForClass(String word, BytesRef c) throws IOException {
+ private int getWordFreqForClass(String word, Term term) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
for (String textFieldName : textFieldNames) {
subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
}
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
- booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+ booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);
}
@@ -271,12 +269,12 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
return totalHitCountCollector.getTotalHits();
}
- private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
- return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
+ private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
+ return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
}
- private int docCount(BytesRef countedClass) throws IOException {
- return leafReader.docFreq(new Term(classFieldName, countedClass));
+ private int docCount(Term term) throws IOException {
+ return leafReader.docFreq(term);
}
/**


@@ -122,18 +122,18 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
int docsWithClassSize = countDocsWithClass();
while ((c = classesEnum.next()) != null) {
double classScore = 0;
+ Term term = new Term(this.classFieldName, c);
for (String fieldName : textFieldNames) {
List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
double fieldScore = 0;
for (String[] fieldTokensArray : tokensArrays) {
- fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName);
+ fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName);
}
classScore += fieldScore;
}
- assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
+ assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
}
- ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
- return assignedClassesNorm;
+ return normClassificationResults(assignedClasses);
}
/**
@@ -211,23 +211,23 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
/**
* @param tokenizedText the tokenized content of a field
* @param fieldName the input field name
- * @param c the class to calculate the score of
+ * @param term the {@link Term} referring to the class to calculate the score of
* @param docsWithClass the total number of docs that have a class
* @return a normalized score for the class
* @throws IOException If there is a low-level I/O error
*/
- private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
+ private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException {
// for each word
double result = 0d;
for (String word : tokenizedText) {
// search with text:word AND class:c
- int hits = getWordFreqForClass(word, fieldName, c);
+ int hits = getWordFreqForClass(word, fieldName, term);
// num : count the no of times the word appears in documents of class c (+1)
double num = hits + 1; // +1 is added because of add 1 smoothing
// den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
- double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
+ double den = getTextTermFreqForClass(term, fieldName) + docsWithClass;
// P(w|c) = num/den
double wordProbability = num / den;
@@ -242,16 +242,16 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
/**
* Returns the average number of unique terms times the number of docs belonging to the input class
*
- * @param c the class
+ * @param term the class term
* @return the average number of unique terms
* @throws java.io.IOException If there is a low-level I/O error
*/
- private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
+ private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
double avgNumberOfUniqueTerms;
Terms terms = MultiFields.getTerms(leafReader, fieldName);
long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
- int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
+ int docsWithC = leafReader.docFreq(term);
return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
}
@@ -261,16 +261,16 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
*
* @param word the token produced by the analyzer
* @param fieldName the field the word is coming from
- * @param c the class
+ * @param term the class term
* @return number of documents of the input class
* @throws java.io.IOException If there is a low-level I/O error
*/
- private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
+ private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException {
BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
- booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+ booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
if (query != null) {
booleanQuery.add(query, BooleanClause.Occur.MUST);
}
@@ -279,11 +279,11 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
return totalHitCountCollector.getTotalHits();
}
- private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
- return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
+ private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
+ return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
}
- private int docCount(BytesRef countedClass) throws IOException {
- return leafReader.docFreq(new Term(classFieldName, countedClass));
+ private int docCount(Term term) throws IOException {
+ return leafReader.docFreq(term);
}
}


@@ -24,6 +24,7 @@ import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import org.apache.lucene.util.BytesRef;
+ import org.apache.lucene.util.BytesRefBuilder;
/**
A Term represents a word from text. This is the unit of search. It is
@@ -39,16 +40,22 @@ public final class Term implements Comparable<Term> {
/** Constructs a Term with the given field and bytes.
* <p>Note that a null field or null bytes value results in undefined
- * behavior for most Lucene APIs that accept a Term parameter.
+ * behavior for most Lucene APIs that accept a Term parameter.
*
- * <p>WARNING: the provided BytesRef is not copied, but used directly.
- * Therefore the bytes should not be modified after construction, for
- * example, you should clone a copy by {@link BytesRef#deepCopyOf}
- * rather than pass reused bytes from a TermsEnum.
+ * <p>The provided BytesRef is copied when it is non null.
*/
public Term(String fld, BytesRef bytes) {
field = fld;
- this.bytes = bytes;
+ this.bytes = bytes == null ? null : BytesRef.deepCopyOf(bytes);
}
+ /** Constructs a Term with the given field and the bytes from a builder.
+ * <p>Note that a null field value results in undefined
+ * behavior for most Lucene APIs that accept a Term parameter.
+ */
+ public Term(String fld, BytesRefBuilder bytesBuilder) {
+ field = fld;
+ this.bytes = bytesBuilder.toBytesRef();
+ }
/** Constructs a Term with the given field and text.
@@ -61,7 +68,7 @@ public final class Term implements Comparable<Term> {
/** Constructs a Term with the given field and empty text.
* This serves two purposes: 1) reuse of a Term with the same field.
* 2) pattern for a query.
- *
+ *
* @param fld field's name
*/
public Term(String fld) {
@@ -75,10 +82,10 @@ public final class Term implements Comparable<Term> {
/** Returns the text of this term. In the case of words, this is simply the
text of the word. In the case of dates and other types, this is an
encoding of the object as a string. */
- public final String text() {
+ public final String text() {
return toString(bytes);
}
/** Returns human-readable form of the term text. If the term is not unicode,
* the raw bytes will be printed instead. */
public static final String toString(BytesRef termText) {
@@ -93,7 +100,7 @@ public final class Term implements Comparable<Term> {
}
}
- /** Returns the bytes of this term. */
+ /** Returns the bytes of this term, these should not be modified. */
public final BytesRef bytes() { return bytes; }
@Override
@@ -141,8 +148,8 @@ public final class Term implements Comparable<Term> {
}
}
- /**
- * Resets the field and text of a Term.
+ /**
+ * Resets the field and text of a Term.
* <p>WARNING: the provided BytesRef is not copied, but used directly.
* Therefore the bytes should not be modified after construction, for
* example, you should clone a copy rather than pass reused bytes from
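The new constructor semantics above, in use; a minimal sketch (the class name and values are made up for illustration):

import org.apache.lucene.index.Term;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;

public class TermCopyDemo {
  public static void main(String[] args) {
    BytesRef reused = new BytesRef("sports");
    Term a = new Term("category", reused);   // bytes are deep-copied here
    reused.bytes[reused.offset] = 'x';       // mutating the caller's buffer...
    System.out.println(a.text());            // ...still prints "sports"

    // The added BytesRefBuilder overload: toBytesRef() already returns fresh
    // bytes, so Term can take them without a second copy.
    BytesRefBuilder builder = new BytesRefBuilder();
    builder.copyChars("politics");
    Term b = new Term("category", builder);
    System.out.println(b.text());            // prints "politics"
  }
}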


@@ -92,7 +92,7 @@ public final class BlendedTermQuery extends Query {
terms = ArrayUtil.grow(terms, numTerms + 1);
boosts = ArrayUtil.grow(boosts, numTerms + 1);
contexts = ArrayUtil.grow(contexts, numTerms + 1);
- terms[numTerms] = new Term(term.field(), BytesRef.deepCopyOf(term.bytes()));
+ terms[numTerms] = term;
boosts[numTerms] = boost;
contexts[numTerms] = context;
numTerms += 1;


@@ -101,7 +101,6 @@ public class PhraseQuery extends Query {
*
*/
public Builder add(Term term, int position) {
- term = new Term(term.field(), BytesRef.deepCopyOf(term.bytes())); // be defensive
if (position < 0) {
throw new IllegalArgumentException("Positions must be >= 0, got " + position);
}
@@ -186,7 +185,7 @@ public class PhraseQuery extends Query {
private static Term[] toTerms(String field, BytesRef... termBytes) {
Term[] terms = new Term[termBytes.length];
for (int i = 0; i < terms.length; ++i) {
- terms[i] = new Term(field, BytesRef.deepCopyOf(termBytes[i]));
+ terms[i] = new Term(field, termBytes[i]);
}
return terms;
}


@@ -274,7 +274,7 @@ public class QueryBuilder {
throw new AssertionError();
}
- return newTermQuery(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+ return newTermQuery(new Term(field, termAtt.getBytesRef()));
}
/**
@@ -286,7 +286,7 @@ public class QueryBuilder {
stream.reset();
List<Term> terms = new ArrayList<>();
while (stream.incrementToken()) {
- terms.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+ terms.add(new Term(field, termAtt.getBytesRef()));
}
return newSynonymQuery(terms.toArray(new Term[terms.size()]));
@@ -319,7 +319,7 @@ public class QueryBuilder {
add(q, currentQuery, operator);
currentQuery.clear();
}
- currentQuery.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+ currentQuery.add(new Term(field, termAtt.getBytesRef()));
}
add(q, currentQuery, operator);
@@ -376,7 +376,7 @@ public class QueryBuilder {
multiTerms.clear();
}
position += positionIncrement;
- multiTerms.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+ multiTerms.add(new Term(field, termAtt.getBytesRef()));
}
if (enablePositionIncrements) {
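For the QueryBuilder hunks above, the reason deepCopyOf was needed in the first place: TermToBytesRefAttribute.getBytesRef() hands back a per-stream buffer that incrementToken() overwrites. A minimal sketch of the now-safe pattern, mirroring the code above (the analyzer, field, and text are placeholders):

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.index.Term;

static List<Term> analyzeTerms(Analyzer analyzer, String field, String text) throws IOException {
  List<Term> terms = new ArrayList<>();
  try (TokenStream stream = analyzer.tokenStream(field, text)) {
    TermToBytesRefAttribute termAtt = stream.getAttribute(TermToBytesRefAttribute.class);
    stream.reset();
    while (stream.incrementToken()) {
      // the attribute's BytesRef is rewritten for each token; the Term
      // constructor now snapshots it, so no explicit deepCopyOf is needed
      terms.add(new Term(field, termAtt.getBytesRef()));
    }
    stream.end();
  }
  return terms;
}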


@@ -745,7 +745,7 @@ public abstract class FieldType extends FieldProperties {
// match-only
return getRangeQuery(parser, field, externalVal, externalVal, true, true);
} else {
- return new TermQuery(new Term(field.getName(), br.toBytesRef()));
+ return new TermQuery(new Term(field.getName(), br));
}
}


@@ -1194,7 +1194,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable,SolrIn
TermQuery key = null;
if (useCache) {
- key = new TermQuery(new Term(deState.fieldName, BytesRef.deepCopyOf(deState.termsEnum.term())));
+ key = new TermQuery(new Term(deState.fieldName, deState.termsEnum.term()));
DocSet result = filterCache.get(key);
if (result != null) return result;
}


@@ -653,7 +653,7 @@ abstract class FacetFieldProcessorFCBase extends FacetFieldProcessor {
bucket.add("val", val);
- TermQuery filter = needFilter ? new TermQuery(new Term(sf.getName(), BytesRef.deepCopyOf(br))) : null;
+ TermQuery filter = needFilter ? new TermQuery(new Term(sf.getName(), br)) : null;
fillBucket(bucket, countAcc.getCount(slotNum), slotNum, null, filter);
bucketList.add(bucket);


@@ -133,7 +133,7 @@ public class SimpleMLTQParser extends QParser {
BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
bytesRefBuilder.grow(NumericUtils.BUF_SIZE_INT);
NumericUtils.intToPrefixCoded(Integer.parseInt(uniqueValue), 0, bytesRefBuilder);
- return new Term(field, bytesRefBuilder.toBytesRef());
+ return new Term(field, bytesRefBuilder);
}