add RawTFSimilarity class (#13749)

This commit is contained in:
Christine Poerschke 2024-09-17 13:11:25 +01:00 committed by GitHub
parent a4c79c8d30
commit a817426511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 152 additions and 46 deletions

View File

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.similarities;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
/** Similarity that returns the raw TF as score. */
public class RawTFSimilarity extends Similarity {
/** Default constructor: parameter-free */
public RawTFSimilarity() {
super();
}
/** Primary constructor. */
public RawTFSimilarity(boolean discountOverlaps) {
super(discountOverlaps);
}
@Override
public SimScorer scorer(
float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return boost * freq;
}
};
}
}

View File

@ -29,14 +29,13 @@ import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.RawTFSimilarity;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
@ -75,7 +74,7 @@ public class TestBooleanQueryVisitSubscorers extends LuceneTestCase {
searcher = newSearcher(reader, true, false); searcher = newSearcher(reader, true, false);
searcher.setSimilarity(new ClassicSimilarity()); searcher.setSimilarity(new ClassicSimilarity());
scorerSearcher = new ScorerIndexSearcher(reader); scorerSearcher = new ScorerIndexSearcher(reader);
scorerSearcher.setSimilarity(new CountingSimilarity()); scorerSearcher.setSimilarity(new RawTFSimilarity());
} }
@Override @Override
@ -345,24 +344,4 @@ public class TestBooleanQueryVisitSubscorers extends LuceneTestCase {
return builder; return builder;
} }
} }
// Similarity that just returns the frequency as the score
private static class CountingSimilarity extends Similarity {
@Override
public long computeNorm(FieldInvertState state) {
return 1;
}
@Override
public SimScorer scorer(
float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return freq;
}
};
}
}
} }

View File

@ -29,12 +29,11 @@ import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.RawTFSimilarity;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
@ -67,7 +66,7 @@ public class TestConjunctions extends LuceneTestCase {
reader = writer.getReader(); reader = writer.getReader();
writer.close(); writer.close();
searcher = newSearcher(reader); searcher = newSearcher(reader);
searcher.setSimilarity(new TFSimilarity()); searcher.setSimilarity(new RawTFSimilarity());
} }
static Document doc(String v1, String v2) { static Document doc(String v1, String v2) {
@ -93,26 +92,6 @@ public class TestConjunctions extends LuceneTestCase {
super.tearDown(); super.tearDown();
} }
// Similarity that returns the TF as score
private static class TFSimilarity extends Similarity {
@Override
public long computeNorm(FieldInvertState state) {
return 1; // we dont care
}
@Override
public SimScorer scorer(
float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return freq;
}
};
}
}
public void testScorerGetChildren() throws Exception { public void testScorerGetChildren() throws Exception {
Directory dir = newDirectory(); Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());

View File

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.similarities;
import java.io.IOException;
import java.util.Random;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.search.similarities.BaseSimilarityTestCase;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.IOUtils;
public class TestRawTFSimilarity extends BaseSimilarityTestCase {
private Directory directory;
private IndexReader indexReader;
private IndexSearcher indexSearcher;
@Override
protected Similarity getSimilarity(Random random) {
return new RawTFSimilarity();
}
@Override
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
try (IndexWriter indexWriter = new IndexWriter(directory, newIndexWriterConfig())) {
final Document document1 = new Document();
final Document document2 = new Document();
final Document document3 = new Document();
document1.add(LuceneTestCase.newTextField("test", "one", Field.Store.YES));
document2.add(LuceneTestCase.newTextField("test", "two two", Field.Store.YES));
document3.add(LuceneTestCase.newTextField("test", "three three three", Field.Store.YES));
indexWriter.addDocument(document1);
indexWriter.addDocument(document2);
indexWriter.addDocument(document3);
indexWriter.commit();
}
indexReader = DirectoryReader.open(directory);
indexSearcher = newSearcher(indexReader);
indexSearcher.setSimilarity(new RawTFSimilarity());
}
@Override
public void tearDown() throws Exception {
IOUtils.close(indexReader, directory);
super.tearDown();
}
public void testOne() throws IOException {
implTest("one", 1f);
}
public void testTwo() throws IOException {
implTest("two", 2f);
}
public void testThree() throws IOException {
implTest("three", 3f);
}
private void implTest(String text, float expectedScore) throws IOException {
Query query = new TermQuery(new Term("test", text));
TopDocs topDocs = indexSearcher.search(query, 1);
assertEquals(1, topDocs.totalHits.value());
assertEquals(1, topDocs.scoreDocs.length);
assertEquals(expectedScore, topDocs.scoreDocs[0].score, 0.0);
}
public void testBoostQuery() throws IOException {
Query query = new TermQuery(new Term("test", "three"));
float boost = 14f;
TopDocs topDocs = indexSearcher.search(new BoostQuery(query, boost), 1);
assertEquals(1, topDocs.totalHits.value());
assertEquals(1, topDocs.scoreDocs.length);
assertEquals(42f, topDocs.scoreDocs[0].score, 0.0);
}
}