LUCENE-6821 - TermQuery's constructors should clone the incoming term

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1709576 13f79535-47bb-0310-9956-ffa450edef68
Tommaso Teofili 2015-10-20 11:57:37 +00:00
parent d2bee4788a
commit 9534904c25
10 changed files with 64 additions and 60 deletions
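For context, a minimal sketch of the reuse hazard this commit removes (the reader and the "category" field name are hypothetical, not part of the commit): a TermsEnum reuses the BytesRef it returns from next(), so a Term built from it used to keep pointing at the enum's internal buffer unless the caller remembered to deep-copy. After this commit, Term's constructor clones the bytes itself.

    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    import org.apache.lucene.index.IndexReader;
    import org.apache.lucene.index.MultiFields;
    import org.apache.lucene.index.Term;
    import org.apache.lucene.index.TermsEnum;
    import org.apache.lucene.util.BytesRef;

    static List<Term> collectClassTerms(IndexReader reader) throws IOException {
      TermsEnum termsEnum = MultiFields.getTerms(reader, "category").iterator();
      List<Term> classes = new ArrayList<>();
      BytesRef next;
      while ((next = termsEnum.next()) != null) {
        // Before this commit, a caller had to write BytesRef.deepCopyOf(next) here,
        // because TermsEnum reuses the returned BytesRef; forgetting the copy left
        // every collected Term sharing the enum's internal buffer. With this commit
        // the Term constructor clones the bytes, so this is safe as written.
        classes.add(new Term("category", next));
      }
      return classes;
    }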


@@ -152,16 +152,14 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     while ((next = classesEnum.next()) != null) {
       if (next.length > 0) {
         // We are passing the term to IndexSearcher so we need to make sure it will not change over time
-        next = BytesRef.deepCopyOf(next);
-        double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedText, next, docsWithClassSize);
-        assignedClasses.add(new ClassificationResult<>(next, clVal));
+        Term term = new Term(this.classFieldName, next);
+        double clVal = calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
+        assignedClasses.add(new ClassificationResult<>(term.bytes(), clVal));
       }
     }
     // normalization; the values transforms to a 0-1 range
-    ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
-    return assignedClassesNorm;
+    return normClassificationResults(assignedClasses);
   }

   /**
@@ -208,18 +206,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     return result.toArray(new String[result.size()]);
   }

-  private double calculateLogLikelihood(String[] tokenizedText, BytesRef c, int docsWithClass) throws IOException {
+  private double calculateLogLikelihood(String[] tokenizedText, Term term, int docsWithClass) throws IOException {
     // for each word
     double result = 0d;
     for (String word : tokenizedText) {
       // search with text:word AND class:c
-      int hits = getWordFreqForClass(word,c);
+      int hits = getWordFreqForClass(word, term);
       // num : count the no of times the word appears in documents of class c (+1)
       double num = hits + 1; // +1 is added because of add 1 smoothing
       // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
-      double den = getTextTermFreqForClass(c) + docsWithClass;
+      double den = getTextTermFreqForClass(term) + docsWithClass;
       // P(w|c) = num/den
       double wordProbability = num / den;
@@ -232,18 +230,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
   /**
    * Returns the average number of unique terms times the number of docs belonging to the input class
-   * @param c the class
+   * @param term the term representing the class
    * @return the average number of unique terms
    * @throws IOException if a low level I/O problem happens
    */
-  private double getTextTermFreqForClass(BytesRef c) throws IOException {
+  private double getTextTermFreqForClass(Term term) throws IOException {
     double avgNumberOfUniqueTerms = 0;
     for (String textFieldName : textFieldNames) {
       Terms terms = MultiFields.getTerms(leafReader, textFieldName);
       long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
       avgNumberOfUniqueTerms += numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
     }
-    int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
+    int docsWithC = leafReader.docFreq(term);
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
   }
@@ -251,18 +249,18 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    * Returns the number of documents of the input class ( from the whole index or from a subset)
    * that contains the word ( in a specific field or in all the fields if no one selected)
    * @param word the token produced by the analyzer
-   * @param c the class
+   * @param term the term representing the class
    * @return the number of documents of the input class
    * @throws IOException if a low level I/O problem happens
    */
-  private int getWordFreqForClass(String word, BytesRef c) throws IOException {
+  private int getWordFreqForClass(String word, Term term) throws IOException {
     BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
     BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
     for (String textFieldName : textFieldNames) {
       subQuery.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
     }
     booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
-    booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+    booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
     if (query != null) {
       booleanQuery.add(query, BooleanClause.Occur.MUST);
     }
@@ -271,12 +269,12 @@ public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
     return totalHitCountCollector.getTotalHits();
   }

-  private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
-    return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
+  private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
+    return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
   }

-  private int docCount(BytesRef countedClass) throws IOException {
-    return leafReader.docFreq(new Term(classFieldName, countedClass));
+  private int docCount(Term term) throws IOException {
+    return leafReader.docFreq(term);
   }

   /**


@@ -122,18 +122,18 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
     int docsWithClassSize = countDocsWithClass();
     while ((c = classesEnum.next()) != null) {
       double classScore = 0;
+      Term term = new Term(this.classFieldName, c);
       for (String fieldName : textFieldNames) {
         List<String[]> tokensArrays = fieldName2tokensArray.get(fieldName);
         double fieldScore = 0;
         for (String[] fieldTokensArray : tokensArrays) {
-          fieldScore += calculateLogPrior(c, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, c, docsWithClassSize) * fieldName2boost.get(fieldName);
+          fieldScore += calculateLogPrior(term, docsWithClassSize) + calculateLogLikelihood(fieldTokensArray, fieldName, term, docsWithClassSize) * fieldName2boost.get(fieldName);
         }
         classScore += fieldScore;
       }
-      assignedClasses.add(new ClassificationResult<>(BytesRef.deepCopyOf(c), classScore));
+      assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
     }
-    ArrayList<ClassificationResult<BytesRef>> assignedClassesNorm = normClassificationResults(assignedClasses);
-    return assignedClassesNorm;
+    return normClassificationResults(assignedClasses);
   }

   /**
@@ -211,23 +211,23 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
   /**
    * @param tokenizedText the tokenized content of a field
    * @param fieldName the input field name
-   * @param c the class to calculate the score of
+   * @param term the {@link Term} referring to the class to calculate the score of
    * @param docsWithClass the total number of docs that have a class
    * @return a normalized score for the class
    * @throws IOException If there is a low-level I/O error
    */
-  private double calculateLogLikelihood(String[] tokenizedText, String fieldName, BytesRef c, int docsWithClass) throws IOException {
+  private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException {
     // for each word
     double result = 0d;
     for (String word : tokenizedText) {
       // search with text:word AND class:c
-      int hits = getWordFreqForClass(word, fieldName, c);
+      int hits = getWordFreqForClass(word, fieldName, term);
       // num : count the no of times the word appears in documents of class c (+1)
       double num = hits + 1; // +1 is added because of add 1 smoothing
       // den : for the whole dictionary, count the no of times a word appears in documents of class c (+|V|)
-      double den = getTextTermFreqForClass(c, fieldName) + docsWithClass;
+      double den = getTextTermFreqForClass(term, fieldName) + docsWithClass;
       // P(w|c) = num/den
       double wordProbability = num / den;
@@ -242,16 +242,16 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
   /**
    * Returns the average number of unique terms times the number of docs belonging to the input class
    *
-   * @param c the class
+   * @param term the class term
    * @return the average number of unique terms
    * @throws java.io.IOException If there is a low-level I/O error
    */
-  private double getTextTermFreqForClass(BytesRef c, String fieldName) throws IOException {
+  private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
     double avgNumberOfUniqueTerms;
     Terms terms = MultiFields.getTerms(leafReader, fieldName);
     long numPostings = terms.getSumDocFreq(); // number of term/doc pairs
     avgNumberOfUniqueTerms = numPostings / (double) terms.getDocCount(); // avg # of unique terms per doc
-    int docsWithC = leafReader.docFreq(new Term(classFieldName, c));
+    int docsWithC = leafReader.docFreq(term);
     return avgNumberOfUniqueTerms * docsWithC; // avg # of unique terms in text fields per doc * # docs with c
   }
@@ -261,16 +261,16 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
    *
    * @param word the token produced by the analyzer
    * @param fieldName the field the word is coming from
-   * @param c the class
+   * @param term the class term
    * @return number of documents of the input class
    * @throws java.io.IOException If there is a low-level I/O error
    */
-  private int getWordFreqForClass(String word, String fieldName, BytesRef c) throws IOException {
+  private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException {
     BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
     BooleanQuery.Builder subQuery = new BooleanQuery.Builder();
     subQuery.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD));
     booleanQuery.add(new BooleanClause(subQuery.build(), BooleanClause.Occur.MUST));
-    booleanQuery.add(new BooleanClause(new TermQuery(new Term(classFieldName, c)), BooleanClause.Occur.MUST));
+    booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
     if (query != null) {
       booleanQuery.add(query, BooleanClause.Occur.MUST);
     }
@@ -279,11 +279,11 @@ public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifi
     return totalHitCountCollector.getTotalHits();
   }

-  private double calculateLogPrior(BytesRef currentClass, int docsWithClassSize) throws IOException {
-    return Math.log((double) docCount(currentClass)) - Math.log(docsWithClassSize);
+  private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
+    return Math.log((double) docCount(term)) - Math.log(docsWithClassSize);
   }

-  private int docCount(BytesRef countedClass) throws IOException {
-    return leafReader.docFreq(new Term(classFieldName, countedClass));
+  private int docCount(Term term) throws IOException {
+    return leafReader.docFreq(term);
   }
 }


@@ -24,6 +24,7 @@ import java.nio.charset.CodingErrorAction;
 import java.nio.charset.StandardCharsets;

 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.BytesRefBuilder;

 /**
   A Term represents a word from text.  This is the unit of search.  It is
@@ -39,16 +40,22 @@ public final class Term implements Comparable<Term> {
   /** Constructs a Term with the given field and bytes.
    * <p>Note that a null field or null bytes value results in undefined
    * behavior for most Lucene APIs that accept a Term parameter.
    *
-   * <p>WARNING: the provided BytesRef is not copied, but used directly.
-   * Therefore the bytes should not be modified after construction, for
-   * example, you should clone a copy by {@link BytesRef#deepCopyOf}
-   * rather than pass reused bytes from a TermsEnum.
+   * <p>The provided BytesRef is copied when it is non null.
    */
   public Term(String fld, BytesRef bytes) {
     field = fld;
-    this.bytes = bytes;
+    this.bytes = bytes == null ? null : BytesRef.deepCopyOf(bytes);
+  }
+
+  /** Constructs a Term with the given field and the bytes from a builder.
+   * <p>Note that a null field value results in undefined
+   * behavior for most Lucene APIs that accept a Term parameter.
+   */
+  public Term(String fld, BytesRefBuilder bytesBuilder) {
+    field = fld;
+    this.bytes = bytesBuilder.toBytesRef();
   }

   /** Constructs a Term with the given field and text.
@@ -61,7 +68,7 @@ public final class Term implements Comparable<Term> {
   /** Constructs a Term with the given field and empty text.
    * This serves two purposes: 1) reuse of a Term with the same field.
    * 2) pattern for a query.
    *
    * @param fld field's name
    */
   public Term(String fld) {
@@ -75,10 +82,10 @@ public final class Term implements Comparable<Term> {
   /** Returns the text of this term.  In the case of words, this is simply the
     text of the word.  In the case of dates and other types, this is an
     encoding of the object as a string.  */
   public final String text() {
     return toString(bytes);
   }

   /** Returns human-readable form of the term text. If the term is not unicode,
    * the raw bytes will be printed instead. */
   public static final String toString(BytesRef termText) {
@@ -93,7 +100,7 @@ public final class Term implements Comparable<Term> {
     }
   }

-  /** Returns the bytes of this term. */
+  /** Returns the bytes of this term, these should not be modified. */
   public final BytesRef bytes() { return bytes; }

   @Override
@@ -141,8 +148,8 @@ public final class Term implements Comparable<Term> {
     }
   }

   /**
    * Resets the field and text of a Term.
    * <p>WARNING: the provided BytesRef is not copied, but used directly.
    * Therefore the bytes should not be modified after construction, for
    * example, you should clone a copy rather than pass reused bytes from
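A small usage sketch of the new copying semantics (the "field" name and "foo"/"bar" values are illustrative, not part of the commit): both constructors snapshot the incoming bytes, so mutating the source afterwards does not affect the Term.

    import org.apache.lucene.index.Term;
    import org.apache.lucene.util.BytesRefBuilder;

    BytesRefBuilder builder = new BytesRefBuilder();
    builder.copyChars("foo");
    Term viaBuilder = new Term("field", builder);               // snapshot taken via toBytesRef()
    Term viaBytesRef = new Term("field", builder.toBytesRef()); // constructor deep-copies
    builder.copyChars("bar"); // mutating the builder afterwards...
    // ...leaves both terms still holding the bytes of "foo".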


@@ -92,7 +92,7 @@ public final class BlendedTermQuery extends Query {
       terms = ArrayUtil.grow(terms, numTerms + 1);
       boosts = ArrayUtil.grow(boosts, numTerms + 1);
       contexts = ArrayUtil.grow(contexts, numTerms + 1);
-      terms[numTerms] = new Term(term.field(), BytesRef.deepCopyOf(term.bytes()));
+      terms[numTerms] = term;
       boosts[numTerms] = boost;
       contexts[numTerms] = context;
       numTerms += 1;


@@ -101,7 +101,6 @@ public class PhraseQuery extends Query {
    *
    */
   public Builder add(Term term, int position) {
-    term = new Term(term.field(), BytesRef.deepCopyOf(term.bytes())); // be defensive
     if (position < 0) {
       throw new IllegalArgumentException("Positions must be >= 0, got " + position);
     }
@@ -186,7 +185,7 @@ public class PhraseQuery extends Query {
   private static Term[] toTerms(String field, BytesRef... termBytes) {
     Term[] terms = new Term[termBytes.length];
     for (int i = 0; i < terms.length; ++i) {
-      terms[i] = new Term(field, BytesRef.deepCopyOf(termBytes[i]));
+      terms[i] = new Term(field, termBytes[i]);
     }
     return terms;
   }


@@ -274,7 +274,7 @@ public class QueryBuilder {
       throw new AssertionError();
     }
-    return newTermQuery(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+    return newTermQuery(new Term(field, termAtt.getBytesRef()));
   }

   /**
@@ -286,7 +286,7 @@ public class QueryBuilder {
     stream.reset();
     List<Term> terms = new ArrayList<>();
     while (stream.incrementToken()) {
-      terms.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+      terms.add(new Term(field, termAtt.getBytesRef()));
     }

     return newSynonymQuery(terms.toArray(new Term[terms.size()]));
@@ -319,7 +319,7 @@ public class QueryBuilder {
         add(q, currentQuery, operator);
         currentQuery.clear();
       }
-      currentQuery.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+      currentQuery.add(new Term(field, termAtt.getBytesRef()));
     }
     add(q, currentQuery, operator);
@@ -376,7 +376,7 @@ public class QueryBuilder {
         multiTerms.clear();
       }
       position += positionIncrement;
-      multiTerms.add(new Term(field, BytesRef.deepCopyOf(termAtt.getBytesRef())));
+      multiTerms.add(new Term(field, termAtt.getBytesRef()));
     }

     if (enablePositionIncrements) {


@@ -745,7 +745,7 @@ public abstract class FieldType extends FieldProperties {
       // match-only
       return getRangeQuery(parser, field, externalVal, externalVal, true, true);
     } else {
-      return new TermQuery(new Term(field.getName(), br.toBytesRef()));
+      return new TermQuery(new Term(field.getName(), br));
     }
   }


@@ -1194,7 +1194,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable,SolrIn
     TermQuery key = null;

     if (useCache) {
-      key = new TermQuery(new Term(deState.fieldName, BytesRef.deepCopyOf(deState.termsEnum.term())));
+      key = new TermQuery(new Term(deState.fieldName, deState.termsEnum.term()));
       DocSet result = filterCache.get(key);
       if (result != null) return result;
     }


@@ -653,7 +653,7 @@ abstract class FacetFieldProcessorFCBase extends FacetFieldProcessor {
       bucket.add("val", val);
-      TermQuery filter = needFilter ? new TermQuery(new Term(sf.getName(), BytesRef.deepCopyOf(br))) : null;
+      TermQuery filter = needFilter ? new TermQuery(new Term(sf.getName(), br)) : null;
       fillBucket(bucket, countAcc.getCount(slotNum), slotNum, null, filter);
       bucketList.add(bucket);


@@ -133,7 +133,7 @@ public class SimpleMLTQParser extends QParser {
     BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
     bytesRefBuilder.grow(NumericUtils.BUF_SIZE_INT);
     NumericUtils.intToPrefixCoded(Integer.parseInt(uniqueValue), 0, bytesRefBuilder);
-    return new Term(field, bytesRefBuilder.toBytesRef());
+    return new Term(field, bytesRefBuilder);