Migrate more classes to records. (#13772)

This commit is contained in:
Adrien Grand 2024-09-12 21:38:28 +02:00 committed by GitHub
parent cb48c7121a
commit 2f7da75b7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 26 additions and 57 deletions

View File

@ -194,6 +194,7 @@ access the members using method calls instead of field accesses. Affected classe
- `IOContext`, `MergeInfo`, and `FlushInfo` (GITHUB#13205)
- `BooleanClause` (GITHUB#13261)
- `TotalHits` (GITHUB#13762)
- `TermAndVector` (GITHUB#13772)
- Many basic Lucene classes, including `CollectionStatistics`, `TermStatistics` and `LeafMetadata` (GITHUB#13328)
### Boolean flags on IOContext replaced with a new ReadAdvice enum.

View File

@ -56,25 +56,25 @@ public class Word2VecModel implements RandomAccessVectorValues.Floats {
}
public void addTermAndVector(TermAndVector modelEntry) {
modelEntry.normalizeVector();
modelEntry = modelEntry.normalizeVector();
this.termsAndVectors[loadedCount++] = modelEntry;
this.word2Vec.add(modelEntry.getTerm());
this.word2Vec.add(modelEntry.term());
}
@Override
public float[] vectorValue(int targetOrd) {
return termsAndVectors[targetOrd].getVector();
return termsAndVectors[targetOrd].vector();
}
public float[] vectorValue(BytesRef term) {
int termOrd = this.word2Vec.find(term);
if (termOrd < 0) return null;
TermAndVector entry = this.termsAndVectors[termOrd];
return (entry == null) ? null : entry.getVector();
return (entry == null) ? null : entry.vector();
}
public BytesRef termValue(int targetOrd) {
return termsAndVectors[targetOrd].getTerm();
return termsAndVectors[targetOrd].term();
}
@Override

View File

@ -120,8 +120,8 @@ public class TestWord2VecSynonymProvider extends LuceneTestCase {
@Test
public void normalizedVector_shouldReturnModule1() {
TermAndVector synonymTerm = new TermAndVector(new BytesRef("a"), new float[] {10, 10});
synonymTerm.normalizeVector();
float[] vector = synonymTerm.getVector();
synonymTerm = synonymTerm.normalizeVector();
float[] vector = synonymTerm.vector();
float len = 0;
for (int i = 0; i < vector.length; i++) {
len += vector[i] * vector[i];

View File

@ -78,10 +78,7 @@ public abstract class PerFieldDocValuesFormat extends DocValuesFormat {
return new FieldsWriter(state);
}
static class ConsumerAndSuffix implements Closeable {
DocValuesConsumer consumer;
int suffix;
record ConsumerAndSuffix(DocValuesConsumer consumer, int suffix) implements Closeable {
@Override
public void close() throws IOException {
consumer.close();
@ -222,10 +219,10 @@ public abstract class PerFieldDocValuesFormat extends DocValuesFormat {
final String segmentSuffix =
getFullSegmentSuffix(
segmentWriteState.segmentSuffix, getSuffix(formatName, Integer.toString(suffix)));
consumer = new ConsumerAndSuffix();
consumer.consumer =
format.fieldsConsumer(new SegmentWriteState(segmentWriteState, segmentSuffix));
consumer.suffix = suffix;
consumer =
new ConsumerAndSuffix(
format.fieldsConsumer(new SegmentWriteState(segmentWriteState, segmentSuffix)),
suffix);
formats.put(format, consumer);
} else {
// we've already seen this format, so just grab its suffix

View File

@ -79,22 +79,13 @@ public abstract class PerFieldPostingsFormat extends PostingsFormat {
super(PER_FIELD_NAME);
}
/** Group of fields written by one PostingsFormat */
static class FieldsGroup {
final List<String> fields;
final int suffix;
/**
* Custom SegmentWriteState for this group of fields, with the segmentSuffix uniqueified for
* this PostingsFormat
*/
final SegmentWriteState state;
private FieldsGroup(List<String> fields, int suffix, SegmentWriteState state) {
this.fields = fields;
this.suffix = suffix;
this.state = state;
}
/**
* Group of fields written by one PostingsFormat
*
* @param state Custom SegmentWriteState for this group of fields, with the segmentSuffix
* uniqueified for this PostingsFormat
*/
record FieldsGroup(List<String> fields, int suffix, SegmentWriteState state) {
static class Builder {
final Set<String> fields;

View File

@ -24,37 +24,17 @@ import java.util.Locale;
*
* @lucene.experimental
*/
public class TermAndVector {
private final BytesRef term;
private final float[] vector;
public TermAndVector(BytesRef term, float[] vector) {
this.term = term;
this.vector = vector;
}
public BytesRef getTerm() {
return this.term;
}
public float[] getVector() {
return this.vector;
}
public record TermAndVector(BytesRef term, float[] vector) {
public int size() {
return vector.length;
}
public void normalizeVector() {
float vectorLength = 0;
for (int i = 0; i < vector.length; i++) {
vectorLength += vector[i] * vector[i];
}
vectorLength = (float) Math.sqrt(vectorLength);
for (int i = 0; i < vector.length; i++) {
vector[i] /= vectorLength;
}
/** Return a {@link TermAndVector} whose vector is normalized according to the L2 norm. */
public TermAndVector normalizeVector() {
float[] vector = this.vector.clone();
VectorUtil.l2normalize(vector);
return new TermAndVector(term, vector);
}
@Override