diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Variables.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Variables.java index 16130476c38..4905011520a 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/Variables.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/Variables.java @@ -138,15 +138,15 @@ public final class Variables { public void decrementScope() { int remove = scopes.pop(); - + while (remove > 0) { Variable variable = variables.pop(); - + // TODO: is this working? the code reads backwards... if (variable.read) { throw variable.location.createError(new IllegalArgumentException("Variable [" + variable.name + "] never used.")); } - + --remove; } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java index 3e315bbcb24..a309fccdd27 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/antlr/Walker.java @@ -263,11 +263,9 @@ public final class Walker extends PainlessParserBaseVisitor { if (ctx.trailer() != null) { SBlock block = (SBlock)visit(ctx.trailer()); - return new SFor(location(ctx), - settings.getMaxLoopCounter(), initializer, expression, afterthought, block); + return new SFor(location(ctx), settings.getMaxLoopCounter(), initializer, expression, afterthought, block); } else if (ctx.empty() != null) { - return new SFor(location(ctx), - settings.getMaxLoopCounter(), initializer, expression, afterthought, null); + return new SFor(location(ctx), settings.getMaxLoopCounter(), initializer, expression, afterthought, null); } else { throw location(ctx).createError(new IllegalStateException("Illegal tree structure.")); } @@ -275,12 +273,16 @@ public final class Walker extends PainlessParserBaseVisitor { @Override public Object visitEach(EachContext ctx) { + if (settings.getMaxLoopCounter() > 0) { + reserved.usesLoop(); + } + String type = ctx.decltype().getText(); String name = ctx.ID().getText(); AExpression expression = (AExpression)visit(ctx.expression()); SBlock block = (SBlock)visit(ctx.trailer()); - return new SEach(location(ctx), type, name, expression, block); + return new SEach(location(ctx), settings.getMaxLoopCounter(), type, name, expression, block); } @Override diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LBrace.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LBrace.java index 4a13e03a490..b0816540c5e 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LBrace.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/LBrace.java @@ -47,7 +47,7 @@ public final class LBrace extends ALink { throw createError(new IllegalArgumentException("Illegal array access made without target.")); } - final Sort sort = before.sort; + Sort sort = before.sort; if (sort == Sort.ARRAY) { index.expected = Definition.INT_TYPE; diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java index aacde9b48b1..8d4052326fc 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java @@ -19,19 +19,40 @@ package org.elasticsearch.painless.node; +import org.elasticsearch.painless.AnalyzerCaster; +import org.elasticsearch.painless.Definition; +import org.elasticsearch.painless.Definition.Cast; +import org.elasticsearch.painless.Definition.Method; +import org.elasticsearch.painless.Definition.MethodKey; +import org.elasticsearch.painless.Definition.Sort; +import org.elasticsearch.painless.Definition.Type; import org.elasticsearch.painless.Location; import org.elasticsearch.painless.MethodWriter; import org.elasticsearch.painless.Variables; +import org.elasticsearch.painless.Variables.Variable; +import org.objectweb.asm.Label; +import org.objectweb.asm.Opcodes; public class SEach extends AStatement { + + final int maxLoopCounter; final String type; final String name; AExpression expression; AStatement block; - public SEach(final Location location, final String type, final String name, final AExpression expression, final SBlock block) { + Variable variable = null; + Variable iterator = null; + Method method = null; + Method hasNext = null; + Method next = null; + Cast cast = null; + + public SEach(final Location location, final int maxLoopCounter, + final String type, final String name, final AExpression expression, final SBlock block) { super(location); + this.maxLoopCounter = maxLoopCounter; this.type = type; this.name = name; this.expression = expression; @@ -41,11 +62,107 @@ public class SEach extends AStatement { @Override AStatement analyze(Variables variables) { - return null; + expression.analyze(variables); + expression.expected = expression.actual; + expression = expression.cast(variables); + + Sort sort = expression.actual.sort; + + if (sort == Sort.ARRAY) { + throw location.createError(new UnsupportedOperationException("Cannot execute for each against array type.")); + } else if (sort == Sort.DEF) { + throw location.createError(new UnsupportedOperationException("Cannot execute for each against def type.")); + } else if (Iterable.class.isAssignableFrom(expression.actual.clazz)) { + final Type type; + + try { + type = Definition.getType(this.type); + } catch (IllegalArgumentException exception) { + throw createError(new IllegalArgumentException("Not a type [" + this.type + "].")); + } + + variables.incrementScope(); + + Type itr = Definition.getType("Iterator"); + + variable = variables.addVariable(location, type, name, false, false); + iterator = variables.addVariable(location, itr, "#itr" + location.getOffset(), true, false); + + method = expression.actual.struct.methods.get(new MethodKey("iterator", 0)); + + if (method == null) { + throw location.createError(new IllegalArgumentException( + "Unable to create iterator for the type [" + expression.actual.name + "].")); + } + + hasNext = itr.struct.methods.get(new MethodKey("hasNext", 0)); + + if (hasNext == null) { + throw location.createError(new IllegalArgumentException("Method [hasNext] does not exist for type [Iterator].")); + } else if (hasNext.rtn.sort != Sort.BOOL) { + throw location.createError(new IllegalArgumentException("Method [hasNext] does not return type [boolean].")); + } + + next = itr.struct.methods.get(new MethodKey("next", 0)); + + if (next == null) { + throw location.createError(new IllegalArgumentException("Method [next] does not exist for type [Iterator].")); + } else if (next.rtn.sort != Sort.OBJECT) { + throw location.createError(new IllegalArgumentException("Method [next] does not return type [Object].")); + } + + cast = AnalyzerCaster.getLegalCast(location, Definition.getType("Object"), type, true, true); + + if (block == null) { + throw location.createError(new IllegalArgumentException("Extraneous for each loop.")); + } + + block = block.analyze(variables); + block.statementCount = Math.max(1, block.statementCount); + + if (block.loopEscape && !block.anyContinue) { + throw createError(new IllegalArgumentException("Extraneous for loop.")); + } + + statementCount = 1; + + if (maxLoopCounter > 0) { + loopCounterSlot = variables.getVariable(location, "#loop").slot; + } + + variables.decrementScope(); + + return this; + } else { + throw location.createError(new IllegalArgumentException("Illegal for each type [" + expression.actual.name + "].")); + } } @Override void write(MethodWriter writer) { + expression.write(writer); + if (java.lang.reflect.Modifier.isInterface(method.owner.clazz.getModifiers())) { + writer.invokeInterface(method.owner.type, method.method); + } else { + writer.invokeVirtual(method.owner.type, method.method); + } + + writer.visitVarInsn(iterator.type.type.getOpcode(Opcodes.ISTORE), iterator.slot); + + Label end = new Label(); + + writer.visitVarInsn(iterator.type.type.getOpcode(Opcodes.ILOAD), iterator.slot); + writer.invokeInterface(hasNext.owner.type, hasNext.method); + writer.ifZCmp(MethodWriter.EQ, end); + + writer.visitVarInsn(iterator.type.type.getOpcode(Opcodes.ILOAD), iterator.slot); + writer.invokeInterface(next.owner.type, next.method); + writer.writeCast(cast); + writer.visitVarInsn(variable.type.type.getOpcode(Opcodes.ISTORE), variable.slot); + + block.write(writer); + + writer.mark(end); } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFor.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFor.java index 75909dff81c..6085c1ce8d9 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFor.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SFor.java @@ -40,11 +40,11 @@ public final class SFor extends AStatement { ANode initializer, AExpression condition, AExpression afterthought, SBlock block) { super(location); + this.maxLoopCounter = maxLoopCounter; this.initializer = initializer; this.condition = condition; this.afterthought = afterthought; this.block = block; - this.maxLoopCounter = maxLoopCounter; } @Override diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java index 0d6a54b515b..1fffcf1ee29 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java @@ -125,6 +125,11 @@ public class BasicStatementTests extends ScriptTestCase { } } + public void testEachStatement() { + assertEquals(6, exec("List l = new ArrayList(); l.add(1); l.add(2); l.add(3); int total = 0;" + + " for (int x : l) total += x; return x")); + } + public void testDeclarationStatement() { assertEquals((byte)2, exec("byte a = 2; return a;")); assertEquals((short)2, exec("short a = 2; return a;"));