mirror of https://github.com/apache/lucene.git
LUCENE-9338: Clean up type safety in SimpleBindings (#1444)
Replaces SimpleBindings' Map<String, Object> with a map of Function<Bindings, DoubleValuesSource> to improve type safety, and reworks cycle detection and validation to avoid catching StackOverflowException
This commit is contained in:
parent
83018deef7
commit
ed3caab2d8
|
@ -221,6 +221,9 @@ Other
|
|||
* LUCENE-9191: Make LineFileDocs's random seeking more efficient, making tests using LineFileDocs faster (Robert Muir,
|
||||
Mike McCandless)
|
||||
|
||||
* LUCENE-9338: Refactors SimpleBindings to improve type safety and cycle detection (Alan Woodward,
|
||||
Adrien Grand)
|
||||
|
||||
======================= Lucene 8.5.1 =======================
|
||||
|
||||
Bug Fixes
|
||||
|
|
|
@ -31,9 +31,8 @@ import org.apache.lucene.search.IndexSearcher;
|
|||
/**
|
||||
* A {@link DoubleValuesSource} which evaluates a {@link Expression} given the context of an {@link Bindings}.
|
||||
*/
|
||||
@SuppressWarnings({"rawtypes", "unchecked"})
|
||||
final class ExpressionValueSource extends DoubleValuesSource {
|
||||
final DoubleValuesSource variables[];
|
||||
final DoubleValuesSource[] variables;
|
||||
final Expression expression;
|
||||
final boolean needsScores;
|
||||
|
||||
|
|
|
@ -18,7 +18,10 @@ package org.apache.lucene.expressions;
|
|||
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.apache.lucene.search.DoubleValuesSource;
|
||||
import org.apache.lucene.search.SortField;
|
||||
|
@ -44,7 +47,8 @@ import org.apache.lucene.search.SortField;
|
|||
* @lucene.experimental
|
||||
*/
|
||||
public final class SimpleBindings extends Bindings {
|
||||
final Map<String,Object> map = new HashMap<>();
|
||||
|
||||
final Map<String, Function<Bindings, DoubleValuesSource>> map = new HashMap<>();
|
||||
|
||||
/** Creates a new empty Bindings */
|
||||
public SimpleBindings() {}
|
||||
|
@ -56,13 +60,13 @@ public final class SimpleBindings extends Bindings {
|
|||
* FieldCache, the document's score, etc.
|
||||
*/
|
||||
public void add(SortField sortField) {
|
||||
map.put(sortField.getField(), sortField);
|
||||
map.put(sortField.getField(), bindings -> fromSortField(sortField));
|
||||
}
|
||||
|
||||
/**
|
||||
* Bind a {@link DoubleValuesSource} directly to the given name.
|
||||
*/
|
||||
public void add(String name, DoubleValuesSource source) { map.put(name, source); }
|
||||
public void add(String name, DoubleValuesSource source) { map.put(name, bindings -> source); }
|
||||
|
||||
/**
|
||||
* Adds an Expression to the bindings.
|
||||
|
@ -70,20 +74,10 @@ public final class SimpleBindings extends Bindings {
|
|||
* This can be used to reference expressions from other expressions.
|
||||
*/
|
||||
public void add(String name, Expression expression) {
|
||||
map.put(name, expression);
|
||||
map.put(name, expression::getDoubleValuesSource);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DoubleValuesSource getDoubleValuesSource(String name) {
|
||||
Object o = map.get(name);
|
||||
if (o == null) {
|
||||
throw new IllegalArgumentException("Invalid reference '" + name + "'");
|
||||
} else if (o instanceof Expression) {
|
||||
return ((Expression)o).getDoubleValuesSource(this);
|
||||
} else if (o instanceof DoubleValuesSource) {
|
||||
return ((DoubleValuesSource) o);
|
||||
}
|
||||
SortField field = (SortField) o;
|
||||
private DoubleValuesSource fromSortField(SortField field) {
|
||||
switch(field.getType()) {
|
||||
case INT:
|
||||
return DoubleValuesSource.fromIntField(field.getField());
|
||||
|
@ -100,20 +94,47 @@ public final class SimpleBindings extends Bindings {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DoubleValuesSource getDoubleValuesSource(String name) {
|
||||
if (map.containsKey(name) == false) {
|
||||
throw new IllegalArgumentException("Invalid reference '" + name + "'");
|
||||
}
|
||||
return map.get(name).apply(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Traverses the graph of bindings, checking there are no cycles or missing references
|
||||
* @throws IllegalArgumentException if the bindings is inconsistent
|
||||
*/
|
||||
public void validate() {
|
||||
for (Object o : map.values()) {
|
||||
if (o instanceof Expression) {
|
||||
Expression expr = (Expression) o;
|
||||
try {
|
||||
expr.getDoubleValuesSource(this);
|
||||
} catch (StackOverflowError e) {
|
||||
throw new IllegalArgumentException("Recursion Error: Cycle detected originating in (" + expr.sourceText + ")");
|
||||
}
|
||||
}
|
||||
for (Map.Entry<String, Function<Bindings, DoubleValuesSource>> origin : map.entrySet()) {
|
||||
origin.getValue().apply(new CycleDetectionBindings(origin.getKey()));
|
||||
}
|
||||
}
|
||||
|
||||
private class CycleDetectionBindings extends Bindings {
|
||||
|
||||
private final Set<String> seenFields = new LinkedHashSet<>();
|
||||
|
||||
CycleDetectionBindings(String current) {
|
||||
seenFields.add(current);
|
||||
}
|
||||
|
||||
CycleDetectionBindings(Set<String> parents, String current) {
|
||||
seenFields.addAll(parents);
|
||||
seenFields.add(current);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DoubleValuesSource getDoubleValuesSource(String name) {
|
||||
if (seenFields.contains(name)) {
|
||||
throw new IllegalArgumentException("Recursion error: Cycle detected " + seenFields + "->" + name);
|
||||
}
|
||||
if (map.containsKey(name) == false) {
|
||||
throw new IllegalArgumentException("Invalid reference '" + name + "'");
|
||||
}
|
||||
return map.get(name).apply(new CycleDetectionBindings(seenFields, name));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue