LUCENE-5231: better interoperability of expressions with valuesource

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1525192 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Robert Muir 2013-09-21 03:37:43 +00:00
parent 477360d047
commit 2f169863cb
6 changed files with 50 additions and 78 deletions

View File

@ -16,10 +16,6 @@ package org.apache.lucene.expressions;
* limitations under the License. * limitations under the License.
*/ */
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.queries.function.ValueSource;
/** /**
@ -30,7 +26,7 @@ import org.apache.lucene.queries.function.ValueSource;
* *
* @lucene.experimental * @lucene.experimental
*/ */
public abstract class Bindings implements Iterable<String> { public abstract class Bindings {
/** Sole constructor. (For invocation by subclass /** Sole constructor. (For invocation by subclass
* constructors, typically implicit.) */ * constructors, typically implicit.) */
@ -41,47 +37,8 @@ public abstract class Bindings implements Iterable<String> {
*/ */
public abstract ValueSource getValueSource(String name); public abstract ValueSource getValueSource(String name);
/** Returns an <code>Iterator</code> over the variable names in this binding */ /** Returns a {@code ValueSource} over relevance scores */
@Override protected final ValueSource getScoreValueSource() {
public abstract Iterator<String> iterator(); return new ScoreValueSource();
/**
* Traverses the graph of bindings, checking there are no cycles or missing references
* @throws IllegalArgumentException if the bindings is inconsistent
*/
public final void validate() {
Set<String> marked = new HashSet<String>();
Set<String> chain = new HashSet<String>();
for (String name : this) {
if (!marked.contains(name)) {
chain.add(name);
validate(name, marked, chain);
chain.remove(name);
}
}
}
private void validate(String name, Set<String> marked, Set<String> chain) {
ValueSource vs = getValueSource(name);
if (vs == null) {
throw new IllegalArgumentException("Invalid reference '" + name + "'");
}
if (vs instanceof ExpressionValueSource) {
Expression expr = ((ExpressionValueSource)vs).expression;
for (String external : expr.variables) {
if (chain.contains(external)) {
throw new IllegalArgumentException("Recursion Error: Cycle detected originating in (" + external + ")");
}
if (!marked.contains(external)) {
chain.add(external);
validate(external, marked, chain);
chain.remove(external);
}
}
}
marked.add(name);
} }
} }

View File

@ -50,7 +50,6 @@ class ExpressionComparator extends FieldComparator<Double> {
Map<String,Object> context = new HashMap<String,Object>(); Map<String,Object> context = new HashMap<String,Object>();
assert scorer != null; assert scorer != null;
context.put("scorer", new ScoreFunctionValues(scorer)); context.put("scorer", new ScoreFunctionValues(scorer));
context.put("valuesCache", new HashMap<String, FunctionValues>());
scores = source.getValues(context, readerContext); scores = source.getValues(context, readerContext);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@ -17,6 +17,7 @@ package org.apache.lucene.expressions;
*/ */
import java.io.IOException; import java.io.IOException;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.lucene.index.AtomicReaderContext; import org.apache.lucene.index.AtomicReaderContext;
@ -29,32 +30,47 @@ import org.apache.lucene.search.SortField;
*/ */
@SuppressWarnings({"rawtypes", "unchecked"}) @SuppressWarnings({"rawtypes", "unchecked"})
final class ExpressionValueSource extends ValueSource { final class ExpressionValueSource extends ValueSource {
private final Bindings bindings; final ValueSource variables[];
final Expression expression; final Expression expression;
final boolean needsScores;
public ExpressionValueSource(Bindings bindings, Expression expression) { ExpressionValueSource(Bindings bindings, Expression expression) {
if (bindings == null) throw new NullPointerException(); if (bindings == null) throw new NullPointerException();
if (expression == null) throw new NullPointerException(); if (expression == null) throw new NullPointerException();
this.bindings = bindings;
this.expression = expression; this.expression = expression;
variables = new ValueSource[expression.variables.length];
boolean needsScores = false;
for (int i = 0; i < variables.length; i++) {
ValueSource source = bindings.getValueSource(expression.variables[i]);
if (source instanceof ScoreValueSource) {
needsScores = true;
} else if (source instanceof ExpressionValueSource) {
if (((ExpressionValueSource)source).needsScores()) {
needsScores = true;
}
} else if (source == null) {
throw new RuntimeException("Internal error. Variable (" + expression.variables[i] + ") does not exist.");
}
variables[i] = source;
}
this.needsScores = needsScores;
} }
/** <code>context</code> must contain a key <code>"valuesCache"</code> which is a <code>Map&lt;String,FunctionValues&gt;</code>. */
@Override @Override
public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException { public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException {
ValueSource source;
Map<String, FunctionValues> valuesCache = (Map<String, FunctionValues>)context.get("valuesCache"); Map<String, FunctionValues> valuesCache = (Map<String, FunctionValues>)context.get("valuesCache");
if (valuesCache == null) { if (valuesCache == null) {
throw new NullPointerException(); valuesCache = new HashMap<String, FunctionValues>();
context = new HashMap(context);
context.put("valuesCache", valuesCache);
} }
FunctionValues[] externalValues = new FunctionValues[expression.variables.length]; FunctionValues[] externalValues = new FunctionValues[expression.variables.length];
for (int i = 0; i < expression.variables.length; ++i) { for (int i = 0; i < variables.length; ++i) {
String externalName = expression.variables[i]; String externalName = expression.variables[i];
FunctionValues values = valuesCache.get(externalName); FunctionValues values = valuesCache.get(externalName);
if (values == null) { if (values == null) {
source = bindings.getValueSource(externalName); values = variables[i].getValues(context, readerContext);
values = source.getValues(context, readerContext);
if (values == null) { if (values == null) {
throw new RuntimeException("Internal error. External (" + externalName + ") does not exist."); throw new RuntimeException("Internal error. External (" + externalName + ") does not exist.");
} }
@ -87,17 +103,6 @@ final class ExpressionValueSource extends ValueSource {
} }
boolean needsScores() { boolean needsScores() {
for (int i = 0; i < expression.variables.length; i++) { return needsScores;
String externalName = expression.variables[i];
ValueSource source = bindings.getValueSource(externalName);
if (source instanceof ScoreValueSource) {
return true;
} else if (source instanceof ExpressionValueSource) {
if (((ExpressionValueSource)source).needsScores()) {
return true;
}
}
}
return false;
} }
} }

View File

@ -38,7 +38,7 @@ class ScoreValueSource extends ValueSource {
public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException { public FunctionValues getValues(Map context, AtomicReaderContext readerContext) throws IOException {
FunctionValues v = (FunctionValues) context.get("scorer"); FunctionValues v = (FunctionValues) context.get("scorer");
if (v == null) { if (v == null) {
throw new NullPointerException(); throw new IllegalStateException("Expressions referencing the score can only be used for sorting");
} }
return v; return v;
} }

View File

@ -18,7 +18,6 @@ package org.apache.lucene.expressions;
*/ */
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.Map; import java.util.Map;
import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.queries.function.ValueSource;
@ -96,14 +95,26 @@ public final class SimpleBindings extends Bindings {
case DOUBLE: case DOUBLE:
return new DoubleFieldSource(field.getField(), (DoubleParser) field.getParser()); return new DoubleFieldSource(field.getField(), (DoubleParser) field.getParser());
case SCORE: case SCORE:
return new ScoreValueSource(); return getScoreValueSource();
default: default:
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
} }
@Override /**
public Iterator<String> iterator() { * Traverses the graph of bindings, checking there are no cycles or missing references
return map.keySet().iterator(); * @throws IllegalArgumentException if the bindings is inconsistent
*/
public void validate() {
for (Object o : map.values()) {
if (o instanceof Expression) {
Expression expr = (Expression) o;
try {
expr.getValueSource(this);
} catch (StackOverflowError e) {
throw new IllegalArgumentException("Recursion Error: Cycle detected originating in (" + expr.sourceText + ")");
}
}
}
} }
} }

View File

@ -42,8 +42,8 @@ public class TestExpressionSortField extends LuceneTestCase {
bindings.add(new SortField("popularity", SortField.Type.INT)); bindings.add(new SortField("popularity", SortField.Type.INT));
SimpleBindings otherBindings = new SimpleBindings(); SimpleBindings otherBindings = new SimpleBindings();
bindings.add(new SortField("_score", SortField.Type.LONG)); otherBindings.add(new SortField("_score", SortField.Type.LONG));
bindings.add(new SortField("popularity", SortField.Type.INT)); otherBindings.add(new SortField("popularity", SortField.Type.INT));
SortField sf1 = expr.getSortField(bindings, true); SortField sf1 = expr.getSortField(bindings, true);