From 9a78f6955bd2d40306830ee211226f123c9c2d85 Mon Sep 17 00:00:00 2001 From: Jack Conradson Date: Mon, 6 Jun 2016 15:25:09 -0700 Subject: [PATCH] Added foreach for array types. --- .../painless/node/AStatement.java | 19 +++ .../painless/node/SArrayEach.java | 130 ++++++++++++++++++ .../elasticsearch/painless/node/SEach.java | 10 +- .../painless/BasicStatementTests.java | 9 ++ 4 files changed, 163 insertions(+), 5 deletions(-) create mode 100644 modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SArrayEach.java diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/AStatement.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/AStatement.java index 7b5d4a2b4dc..f6b1048028c 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/AStatement.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/AStatement.java @@ -124,4 +124,23 @@ public abstract class AStatement extends ANode { * Writes ASM based on the data collected during the analysis phase. */ abstract void write(MethodWriter writer); + + /** + * Used to copy statement data from one to another during analysis in the case of replacement. + */ + final AStatement copy(AStatement statement) { + lastSource = statement.lastSource; + beginLoop = statement.beginLoop; + inLoop = statement.inLoop; + lastLoop = statement.lastLoop; + methodEscape = statement.methodEscape; + loopEscape = statement.loopEscape; + allEscape = statement.allEscape; + anyContinue = statement.anyContinue; + anyBreak = statement.anyBreak; + loopCounterSlot = statement.loopCounterSlot; + statementCount = statement.statementCount; + + return this; + } } diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SArrayEach.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SArrayEach.java new file mode 100644 index 00000000000..1c3f90ed93e --- /dev/null +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SArrayEach.java @@ -0,0 +1,130 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.painless.node; + +import org.elasticsearch.painless.AnalyzerCaster; +import org.elasticsearch.painless.Definition; +import org.elasticsearch.painless.Definition.Cast; +import org.elasticsearch.painless.Definition.Type; +import org.elasticsearch.painless.Location; +import org.elasticsearch.painless.MethodWriter; +import org.elasticsearch.painless.Variables; +import org.elasticsearch.painless.Variables.Variable; +import org.objectweb.asm.Label; +import org.objectweb.asm.Opcodes; + +class SArrayEach extends AStatement { + final int maxLoopCounter; + final String type; + final String name; + AExpression expression; + AStatement block; + + Variable variable = null; + Variable array = null; + Variable index = null; + Type indexed = null; + Cast cast = null; + + SArrayEach(final Location location, final int maxLoopCounter, + final String type, final String name, final AExpression expression, final SBlock block) { + super(location); + + this.maxLoopCounter = maxLoopCounter; + this.type = type; + this.name = name; + this.expression = expression; + this.block = block; + } + + @Override + AStatement analyze(Variables variables) { + final Type type; + + try { + type = Definition.getType(this.type); + } catch (IllegalArgumentException exception) { + throw createError(new IllegalArgumentException("Not a type [" + this.type + "].")); + } + + variables.incrementScope(); + + variable = variables.addVariable(location, type, name, true, false); + array = variables.addVariable(location, expression.actual, "#array" + location.getOffset(), true, false); + index = variables.addVariable(location, Definition.INT_TYPE, "#index" + location.getOffset(), true, false); + indexed = Definition.getType(expression.actual.struct, expression.actual.dimensions - 1); + cast = AnalyzerCaster.getLegalCast(location, indexed, type, true, true); + + if (block == null) { + throw location.createError(new IllegalArgumentException("Extraneous for each loop.")); + } + + block.beginLoop = true; + block.inLoop = true; + block = block.analyze(variables); + block.statementCount = Math.max(1, block.statementCount); + + if (block.loopEscape && !block.anyContinue) { + throw createError(new IllegalArgumentException("Extraneous for loop.")); + } + + statementCount = 1; + + if (maxLoopCounter > 0) { + loopCounterSlot = variables.getVariable(location, "#loop").slot; + } + + variables.decrementScope(); + + return this; + } + + @Override + void write(MethodWriter writer) { + writer.writeStatementOffset(location); + + expression.write(writer); + writer.visitVarInsn(array.type.type.getOpcode(Opcodes.ISTORE), array.slot); + writer.push(-1); + writer.visitVarInsn(index.type.type.getOpcode(Opcodes.ISTORE), index.slot); + + Label begin = new Label(); + Label end = new Label(); + + writer.mark(begin); + + writer.visitIincInsn(index.slot, 1); + writer.visitVarInsn(index.type.type.getOpcode(Opcodes.ILOAD), index.slot); + writer.visitVarInsn(array.type.type.getOpcode(Opcodes.ILOAD), array.slot); + writer.arrayLength(); + writer.ifICmp(MethodWriter.GE, end); + + writer.visitVarInsn(array.type.type.getOpcode(Opcodes.ILOAD), array.slot); + writer.visitVarInsn(index.type.type.getOpcode(Opcodes.ILOAD), index.slot); + writer.arrayLoad(indexed.type); + writer.writeCast(cast); + writer.visitVarInsn(variable.type.type.getOpcode(Opcodes.ISTORE), variable.slot); + + block.write(writer); + + writer.goTo(begin); + writer.mark(end); + } +} diff --git a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java index d7942583387..5dd1ee0936d 100644 --- a/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java +++ b/modules/lang-painless/src/main/java/org/elasticsearch/painless/node/SEach.java @@ -33,9 +33,6 @@ import org.elasticsearch.painless.Variables.Variable; import org.objectweb.asm.Label; import org.objectweb.asm.Opcodes; -import java.util.HashMap; -import java.util.Map; - public class SEach extends AStatement { final int maxLoopCounter; @@ -62,7 +59,6 @@ public class SEach extends AStatement { this.block = block; } - @Override AStatement analyze(Variables variables) { expression.analyze(variables); @@ -72,7 +68,7 @@ public class SEach extends AStatement { Sort sort = expression.actual.sort; if (sort == Sort.ARRAY) { - throw location.createError(new UnsupportedOperationException("Cannot execute for each against array type.")); + return new SArrayEach(location, maxLoopCounter, type, name, expression, (SBlock)block).copy(this).analyze(variables); } else if (sort == Sort.DEF) { throw location.createError(new UnsupportedOperationException("Cannot execute for each against def type.")); } else if (Iterable.class.isAssignableFrom(expression.actual.clazz)) { @@ -120,6 +116,8 @@ public class SEach extends AStatement { throw location.createError(new IllegalArgumentException("Extraneous for each loop.")); } + block.beginLoop = true; + block.inLoop = true; block = block.analyze(variables); block.statementCount = Math.max(1, block.statementCount); @@ -143,6 +141,8 @@ public class SEach extends AStatement { @Override void write(MethodWriter writer) { + writer.writeStatementOffset(location); + expression.write(writer); if (java.lang.reflect.Modifier.isInterface(method.owner.clazz.getModifiers())) { diff --git a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java index 7a2846433a6..6ddb4067426 100644 --- a/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java +++ b/modules/lang-painless/src/test/java/org/elasticsearch/painless/BasicStatementTests.java @@ -135,6 +135,15 @@ public class BasicStatementTests extends ScriptTestCase { " for (Map.Entry e : m.entrySet()) { cat += e.getKey(); total += e.getValue(); } return cat + total")); } + public void testArrayForEachStatement() { + assertEquals(6, exec("int[] a = new int[3]; a[0] = 1; a[1] = 2; a[2] = 3; int total = 0;" + + " for (int x : a) total += x; return total")); + assertEquals("123", exec("String[] a = new String[3]; a[0] = '1'; a[1] = '2'; a[2] = '3'; def total = '';" + + " for (String x : a) total += x; return total")); + assertEquals(6, exec("int[][] i = new int[3][1]; i[0][0] = 1; i[1][0] = 2; i[2][0] = 3; int total = 0;" + + " for (int[] j : i) total += j[0]; return total")); + } + public void testDeclarationStatement() { assertEquals((byte)2, exec("byte a = 2; return a;")); assertEquals((short)2, exec("short a = 2; return a;"));