Harmonize implementations of "visit" for Exprs from ExprMacros. (#12230)

* Harmonize implementations of "visit" for Exprs from ExprMacros.

Many of them had bugs where they would not visit all of the original
arguments. I don't think this has user-visible consequences right now,
but it's possible it would in a future world where "visit" is used
for more stuff than it is today.

So, this patch all updates all implementations to a more consistent
style that emphasizes reapplying the macro to the shuttled args.

* Test fixes, test coverage, PR review comments.
This commit is contained in:
Gian Merlino 2022-02-04 08:08:54 -08:00 committed by GitHub
parent 290130b1fa
commit de82c611de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 138 additions and 89 deletions

View File

@ -126,6 +126,9 @@ public interface Expr extends Cacheable
* Programatically rewrite the {@link Expr} tree with a {@link Shuttle}. Each {@link Expr} is responsible for * Programatically rewrite the {@link Expr} tree with a {@link Shuttle}. Each {@link Expr} is responsible for
* ensuring the {@link Shuttle} can visit all of its {@link Expr} children, as well as updating its children * ensuring the {@link Shuttle} can visit all of its {@link Expr} children, as well as updating its children
* {@link Expr} with the results from the {@link Shuttle}, before finally visiting an updated form of itself. * {@link Expr} with the results from the {@link Shuttle}, before finally visiting an updated form of itself.
*
* When this Expr is the result of {@link ExprMacroTable.ExprMacro#apply}, all of the original arguments to the
* macro must be visited, including arguments that may have been "baked in" to this Expr.
*/ */
Expr visit(Shuttle shuttle); Expr visit(Shuttle shuttle);
@ -158,6 +161,7 @@ public interface Expr extends Cacheable
* Check if an expression can be 'vectorized', for a given set of inputs. If this method returns true, * Check if an expression can be 'vectorized', for a given set of inputs. If this method returns true,
* {@link #buildVectorized} is expected to produce a {@link ExprVectorProcessor} which can evaluate values in batches * {@link #buildVectorized} is expected to produce a {@link ExprVectorProcessor} which can evaluate values in batches
* to use with vectorized query engines. * to use with vectorized query engines.
*
* @param inspector * @param inspector
*/ */
default boolean canVectorize(InputBindingInspector inspector) default boolean canVectorize(InputBindingInspector inspector)
@ -168,6 +172,7 @@ public interface Expr extends Cacheable
/** /**
* Builds a 'vectorized' expression processor, that can operate on batches of input values for use in vectorized * Builds a 'vectorized' expression processor, that can operate on batches of input values for use in vectorized
* query engines. * query engines.
*
* @param inspector * @param inspector
*/ */
default <T> ExprVectorProcessor<T> buildVectorized(VectorInputBindingInspector inspector) default <T> ExprVectorProcessor<T> buildVectorized(VectorInputBindingInspector inspector)
@ -390,6 +395,17 @@ public interface Expr extends Cacheable
* Provide the {@link Shuttle} with an {@link Expr} to inspect and potentially rewrite. * Provide the {@link Shuttle} with an {@link Expr} to inspect and potentially rewrite.
*/ */
Expr visit(Expr expr); Expr visit(Expr expr);
default List<Expr> visitAll(List<Expr> exprs)
{
final List<Expr> newExprs = new ArrayList<>();
for (final Expr arg : exprs) {
newExprs.add(visit(arg));
}
return newExprs;
}
} }
/** /**

View File

@ -204,8 +204,7 @@ class FunctionExpr implements Expr
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(shuttle::visit).collect(Collectors.toList()); return shuttle.visit(new FunctionExpr(function, name, shuttle.visitAll(args)));
return shuttle.visit(new FunctionExpr(function, name, newArgs));
} }
@Override @Override
@ -334,8 +333,7 @@ class ApplyFunctionExpr implements Expr
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
LambdaExpr newLambda = (LambdaExpr) lambdaExpr.visit(shuttle); LambdaExpr newLambda = (LambdaExpr) lambdaExpr.visit(shuttle);
List<Expr> newArgs = argsExpr.stream().map(shuttle::visit).collect(Collectors.toList()); return shuttle.visit(new ApplyFunctionExpr(function, name, newLambda, shuttle.visitAll(argsExpr)));
return shuttle.visit(new ApplyFunctionExpr(function, name, newLambda, newArgs));
} }
@Override @Override

View File

@ -19,9 +19,15 @@
package org.apache.druid.math.expr; package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableList;
import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.EqualsVerifier;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class ExprTest public class ExprTest
{ {
@Test @Test
@ -196,6 +202,7 @@ public class ExprTest
{ {
EqualsVerifier.forClass(LambdaExpr.class).usingGetClass().verify(); EqualsVerifier.forClass(LambdaExpr.class).usingGetClass().verify();
} }
@Test @Test
public void testEqualsContractForNullLongExpr() public void testEqualsContractForNullLongExpr()
{ {
@ -213,4 +220,31 @@ public class ExprTest
.withPrefabValues(ExpressionType.class, ExpressionType.DOUBLE, ExpressionType.STRING) .withPrefabValues(ExpressionType.class, ExpressionType.DOUBLE, ExpressionType.STRING)
.verify(); .verify();
} }
@Test
public void testShuttleVisitAll()
{
final List<Expr> visitedExprs = new ArrayList<>();
final Expr.Shuttle shuttle = expr -> {
visitedExprs.add(expr);
return expr;
};
shuttle.visitAll(Collections.emptyList());
Assert.assertEquals("Visiting an empty list", Collections.emptyList(), visitedExprs);
final List<Expr> oneIdentifier = Collections.singletonList(new IdentifierExpr("ident"));
visitedExprs.clear();
shuttle.visitAll(oneIdentifier);
Assert.assertEquals("One identifier", oneIdentifier, visitedExprs);
final List<Expr> twoIdentifiers = ImmutableList.of(
new IdentifierExpr("ident1"),
new IdentifierExpr("ident2")
);
visitedExprs.clear();
shuttle.visitAll(twoIdentifiers);
Assert.assertEquals("Two identifiers", twoIdentifiers, visitedExprs);
}
} }

View File

@ -35,7 +35,6 @@ import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class BloomFilterExpressions public class BloomFilterExpressions
{ {
@ -88,7 +87,7 @@ public class BloomFilterExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
return shuttle.visit(this); return shuttle.visit(apply(shuttle.visitAll(args)));
} }
@Nullable @Nullable
@ -171,8 +170,7 @@ public class BloomFilterExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new BloomExpr(newArgs));
} }
@Nullable @Nullable
@ -259,8 +257,7 @@ public class BloomFilterExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new BloomExpr(filter, newArg));
} }
@Nullable @Nullable
@ -331,8 +328,7 @@ public class BloomFilterExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new DynamicBloomExpr(newArgs));
} }
@Nullable @Nullable

View File

@ -85,8 +85,7 @@ public class SleepExprMacro implements ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new SleepExpr(newArg));
} }
/** /**

View File

@ -58,6 +58,6 @@ public class CaseInsensitiveContainsExprMacro implements ExprMacroTable.ExprMacr
final Expr arg = args.get(0); final Expr arg = args.get(0);
final Expr searchStr = args.get(1); final Expr searchStr = args.get(1);
return new ContainsExpr(FN_NAME, arg, searchStr, false); return new ContainsExpr(FN_NAME, arg, searchStr, false, shuttle -> apply(shuttle.visitAll(args)));
} }
} }

View File

@ -39,20 +39,31 @@ class ContainsExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
{ {
private final Function<String, Boolean> searchFunction; private final Function<String, Boolean> searchFunction;
private final Expr searchStrExpr; private final Expr searchStrExpr;
private final Function<Shuttle, Expr> visitFunction;
ContainsExpr(String functioName, Expr arg, Expr searchStrExpr, boolean caseSensitive) ContainsExpr(
final String functionName,
final Expr arg,
final Expr searchStrExpr,
final boolean caseSensitive,
final Function<Shuttle, Expr> visitFunction
)
{ {
super(functioName, arg); this(functionName, arg, searchStrExpr, createFunction(searchStrExpr, caseSensitive), visitFunction);
this.searchStrExpr = validateSearchExpr(searchStrExpr, functioName);
// Creates the function eagerly to avoid branching in eval.
this.searchFunction = createFunction(searchStrExpr, caseSensitive);
} }
private ContainsExpr(String functioName, Expr arg, Expr searchStrExpr, Function<String, Boolean> searchFunction) private ContainsExpr(
final String functionName,
final Expr arg,
final Expr searchStrExpr,
final Function<String, Boolean> searchFunction,
final Function<Shuttle, Expr> visitFunction
)
{ {
super(functioName, arg); super(functionName, arg);
this.searchFunction = searchFunction; this.searchFunction = searchFunction;
this.searchStrExpr = validateSearchExpr(searchStrExpr, functioName); this.searchStrExpr = validateSearchExpr(searchStrExpr, functionName);
this.visitFunction = visitFunction;
} }
@Nonnull @Nonnull
@ -80,8 +91,7 @@ class ContainsExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
@Override @Override
public Expr visit(Expr.Shuttle shuttle) public Expr visit(Expr.Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return visitFunction.apply(shuttle);
return shuttle.visit(new ContainsExpr(name, newArg, searchStrExpr, searchFunction));
} }
@Override @Override
@ -90,15 +100,6 @@ class ContainsExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
return StringUtils.format("%s(%s, %s)", name, arg.stringify(), searchStrExpr.stringify()); return StringUtils.format("%s(%s, %s)", name, arg.stringify(), searchStrExpr.stringify());
} }
private Function<String, Boolean> createFunction(Expr searchStrExpr, boolean caseSensitive)
{
String searchStr = StringUtils.nullToEmptyNonDruidDataString((String) searchStrExpr.getLiteralValue());
if (caseSensitive) {
return s -> s.contains(searchStr);
}
return s -> org.apache.commons.lang.StringUtils.containsIgnoreCase(s, searchStr);
}
private Expr validateSearchExpr(Expr searchExpr, String functioName) private Expr validateSearchExpr(Expr searchExpr, String functioName)
{ {
if (!ExprUtils.isStringLiteral(searchExpr)) { if (!ExprUtils.isStringLiteral(searchExpr)) {
@ -106,4 +107,13 @@ class ContainsExpr extends ExprMacroTable.BaseScalarUnivariateMacroFunctionExpr
} }
return searchExpr; return searchExpr;
} }
private static Function<String, Boolean> createFunction(Expr searchStrExpr, boolean caseSensitive)
{
String searchStr = StringUtils.nullToEmptyNonDruidDataString((String) searchStrExpr.getLiteralValue());
if (caseSensitive) {
return s -> s.contains(searchStr);
}
return s -> org.apache.commons.lang.StringUtils.containsIgnoreCase(s, searchStr);
}
} }

View File

@ -57,6 +57,6 @@ public class ContainsExprMacro implements ExprMacroTable.ExprMacro
final Expr arg = args.get(0); final Expr arg = args.get(0);
final Expr searchStr = args.get(1); final Expr searchStr = args.get(1);
return new ContainsExpr(FN_NAME, arg, searchStr, true); return new ContainsExpr(FN_NAME, arg, searchStr, true, shuttle -> shuttle.visit(apply(shuttle.visitAll(args))));
} }
} }

View File

@ -36,7 +36,6 @@ import javax.annotation.Nullable;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors;
public class HyperUniqueExpressions public class HyperUniqueExpressions
{ {
@ -205,8 +204,7 @@ public class HyperUniqueExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new HllExpr(newArgs));
} }
@Nullable @Nullable
@ -262,8 +260,7 @@ public class HyperUniqueExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new HllExpr(newArg));
} }
@Nullable @Nullable
@ -316,8 +313,7 @@ public class HyperUniqueExpressions
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new HllExpr(newArg));
} }
@Nullable @Nullable

View File

@ -115,8 +115,7 @@ public class IPv4AddressMatchExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new IPv4AddressMatchExpr(newArg, subnetInfo));
} }
@Nullable @Nullable

View File

@ -91,8 +91,7 @@ public class IPv4AddressParseExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new IPv4AddressParseExpr(newArg));
} }
@Nullable @Nullable

View File

@ -90,8 +90,7 @@ public class IPv4AddressStringifyExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new IPv4AddressStringifyExpr(newArg));
} }
@Nullable @Nullable

View File

@ -88,8 +88,7 @@ public class LikeExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new LikeExtractExpr(newArg));
} }
@Nullable @Nullable

View File

@ -94,8 +94,7 @@ public class LookupExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new LookupExpr(newArg));
} }
@Nullable @Nullable

View File

@ -95,8 +95,7 @@ public class RegexpExtractExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new RegexpExtractExpr(newArg));
} }
@Nullable @Nullable

View File

@ -87,8 +87,7 @@ public class RegexpLikeExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new RegexpLikeExpr(newArg));
} }
@Nullable @Nullable

View File

@ -33,7 +33,6 @@ import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors;
public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro
{ {
@ -90,8 +89,7 @@ public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampCeilExpr(shuttle.visitAll(args)));
return shuttle.visit(new TimestampCeilExpr(newArgs));
} }
@Nullable @Nullable
@ -158,8 +156,7 @@ public class TimestampCeilExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampCeilDynamicExpr(shuttle.visitAll(args)));
return shuttle.visit(new TimestampCeilDynamicExpr(newArgs));
} }
@Nullable @Nullable

View File

@ -160,8 +160,7 @@ public class TimestampExtractExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new TimestampExtractExpr(newArg));
} }
@Nullable @Nullable

View File

@ -34,7 +34,6 @@ import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors;
public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro
{ {
@ -111,9 +110,7 @@ public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampFloorExpr(shuttle.visitAll(args)));
return shuttle.visit(new TimestampFloorExpr(newArgs));
} }
@Nullable @Nullable
@ -189,8 +186,7 @@ public class TimestampFloorExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampFloorDynamicExpr(shuttle.visitAll(args)));
return shuttle.visit(new TimestampFloorDynamicExpr(newArgs));
} }
@Nullable @Nullable

View File

@ -95,8 +95,7 @@ public class TimestampFormatExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new TimestampFormatExpr(newArg));
} }
@Nullable @Nullable

View File

@ -98,8 +98,7 @@ public class TimestampParseExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newArg = arg.visit(shuttle); return shuttle.visit(apply(shuttle.visitAll(args)));
return shuttle.visit(new TimestampParseExpr(newArg));
} }
@Nullable @Nullable

View File

@ -33,7 +33,6 @@ import org.joda.time.chrono.ISOChronology;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
{ {
@ -110,8 +109,7 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampShiftExpr(shuttle.visitAll(args)));
return shuttle.visit(new TimestampShiftExpr(newArgs));
} }
@Nullable @Nullable
@ -146,8 +144,7 @@ public class TimestampShiftExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList()); return shuttle.visit(new TimestampShiftDynamicExpr(shuttle.visitAll(args)));
return shuttle.visit(new TimestampShiftDynamicExpr(newArgs));
} }
@Nullable @Nullable

View File

@ -34,6 +34,7 @@ import javax.annotation.Nullable;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function;
public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
{ {
@ -93,16 +94,18 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
throw new IAE("Function[%s] must have 1 or 2 arguments", name()); throw new IAE("Function[%s] must have 1 or 2 arguments", name());
} }
final Function<Expr.Shuttle, Expr> visitFn = shuttle -> shuttle.visit(apply(shuttle.visitAll(args)));
if (args.size() == 1) { if (args.size() == 1) {
return new TrimStaticCharsExpr(mode, args.get(0), DEFAULT_CHARS, null); return new TrimStaticCharsExpr(mode, args.get(0), DEFAULT_CHARS, null, visitFn);
} else { } else {
final Expr charsArg = args.get(1); final Expr charsArg = args.get(1);
if (charsArg.isLiteral()) { if (charsArg.isLiteral()) {
final String charsString = charsArg.eval(InputBindings.nilBindings()).asString(); final String charsString = charsArg.eval(InputBindings.nilBindings()).asString();
final char[] chars = charsString == null ? EMPTY_CHARS : charsString.toCharArray(); final char[] chars = charsString == null ? EMPTY_CHARS : charsString.toCharArray();
return new TrimStaticCharsExpr(mode, args.get(0), chars, charsArg); return new TrimStaticCharsExpr(mode, args.get(0), chars, charsArg, visitFn);
} else { } else {
return new TrimDynamicCharsExpr(mode, args.get(0), args.get(1)); return new TrimDynamicCharsExpr(mode, args.get(0), args.get(1), visitFn);
} }
} }
} }
@ -113,13 +116,21 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
private final TrimMode mode; private final TrimMode mode;
private final char[] chars; private final char[] chars;
private final Expr charsExpr; private final Expr charsExpr;
private final Function<Shuttle, Expr> visitFn;
public TrimStaticCharsExpr(final TrimMode mode, final Expr stringExpr, final char[] chars, final Expr charsExpr) public TrimStaticCharsExpr(
final TrimMode mode,
final Expr stringExpr,
final char[] chars,
final Expr charsExpr,
final Function<Shuttle, Expr> visitFn
)
{ {
super(mode.getFnName(), stringExpr); super(mode.getFnName(), stringExpr);
this.mode = mode; this.mode = mode;
this.chars = chars; this.chars = chars;
this.charsExpr = charsExpr; this.charsExpr = charsExpr;
this.visitFn = visitFn;
} }
@Nonnull @Nonnull
@ -167,8 +178,7 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newStringExpr = arg.visit(shuttle); return visitFn.apply(shuttle);
return shuttle.visit(new TrimStaticCharsExpr(mode, newStringExpr, chars, charsExpr));
} }
@Nullable @Nullable
@ -200,6 +210,8 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
return false; return false;
} }
TrimStaticCharsExpr that = (TrimStaticCharsExpr) o; TrimStaticCharsExpr that = (TrimStaticCharsExpr) o;
// Doesn't use "visitFn", but that's OK, because visitFn is determined entirely by "mode".
return mode == that.mode && return mode == that.mode &&
Arrays.equals(chars, that.chars) && Arrays.equals(chars, that.chars) &&
Objects.equals(charsExpr, that.charsExpr); Objects.equals(charsExpr, that.charsExpr);
@ -220,12 +232,19 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
private final TrimMode mode; private final TrimMode mode;
private final Expr stringExpr; private final Expr stringExpr;
private final Expr charsExpr; private final Expr charsExpr;
private final Function<Shuttle, Expr> visitFn;
public TrimDynamicCharsExpr(final TrimMode mode, final Expr stringExpr, final Expr charsExpr) public TrimDynamicCharsExpr(
final TrimMode mode,
final Expr stringExpr,
final Expr charsExpr,
final Function<Shuttle, Expr> visitFn
)
{ {
this.mode = mode; this.mode = mode;
this.stringExpr = stringExpr; this.stringExpr = stringExpr;
this.charsExpr = charsExpr; this.charsExpr = charsExpr;
this.visitFn = visitFn;
} }
@Nonnull @Nonnull
@ -286,9 +305,7 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
@Override @Override
public Expr visit(Shuttle shuttle) public Expr visit(Shuttle shuttle)
{ {
Expr newStringExpr = stringExpr.visit(shuttle); return visitFn.apply(shuttle);
Expr newCharsExpr = charsExpr.visit(shuttle);
return shuttle.visit(new TrimDynamicCharsExpr(mode, newStringExpr, newCharsExpr));
} }
@Override @Override
@ -316,6 +333,8 @@ public abstract class TrimExprMacro implements ExprMacroTable.ExprMacro
return false; return false;
} }
TrimDynamicCharsExpr that = (TrimDynamicCharsExpr) o; TrimDynamicCharsExpr that = (TrimDynamicCharsExpr) o;
// Doesn't use "visitFn", but that's OK, because visitFn is determined entirely by "mode".
return mode == that.mode && return mode == that.mode &&
Objects.equals(stringExpr, that.stringExpr) && Objects.equals(stringExpr, that.stringExpr) &&
Objects.equals(charsExpr, that.charsExpr); Objects.equals(charsExpr, that.charsExpr);

View File

@ -28,7 +28,7 @@ public class TrimExprMacroTest
public void testEqualsContractForTrimStaticCharsExpr() public void testEqualsContractForTrimStaticCharsExpr()
{ {
EqualsVerifier.forClass(TrimExprMacro.TrimStaticCharsExpr.class) EqualsVerifier.forClass(TrimExprMacro.TrimStaticCharsExpr.class)
.withIgnoredFields("analyzeInputsSupplier") .withIgnoredFields("analyzeInputsSupplier", "visitFn")
.usingGetClass() .usingGetClass()
.verify(); .verify();
} }
@ -37,6 +37,7 @@ public class TrimExprMacroTest
public void testEqualsContractForTrimDynamicCharsExpr() public void testEqualsContractForTrimDynamicCharsExpr()
{ {
EqualsVerifier.forClass(TrimExprMacro.TrimDynamicCharsExpr.class) EqualsVerifier.forClass(TrimExprMacro.TrimDynamicCharsExpr.class)
.withIgnoredFields("visitFn")
.usingGetClass() .usingGetClass()
.verify(); .verify();
} }