diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/Flatten.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/Flatten.java index 5cbe146e2..a9b7a7edd 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/Flatten.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imtranslation/Flatten.java @@ -178,7 +178,7 @@ public ImLExpr getExpr() { public static class MultiResult { - final List stmts; + public final List stmts; final List exprs; public MultiResult(List stmts, List exprs) { @@ -488,6 +488,49 @@ public static Result flatten(ImCompiletimeExpr e, ImTranslator translator, ImFun } + private static boolean isEffectivelyPure(ImExpr expr) { + if (expr.attrPurity() instanceof Pure) { + return true; + } + return expr instanceof ImConst; + } + + private static boolean isUnchangedVarAccess(ImExpr expr, List laterResults, int startIndex) { + if (!(expr instanceof ImVarAccess)) { + return false; + } + + ImVar target = ((ImVarAccess) expr).getVar(); + for (int i = startIndex; i < laterResults.size(); i++) { + for (ImStmt stmt : laterResults.get(i).stmts) { + if (writesVar(stmt, target)) { + return false; + } + } + } + + return true; + } + + private static boolean writesVar(ImStmt stmt, ImVar var) { + class WriteDetector extends ImStmt.DefaultVisitor { + boolean writes = false; + + @Override + public void visit(ImSet e) { + super.visit(e); + if (e.getLeft() instanceof ImVarAccess + && ((ImVarAccess) e.getLeft()).getVar() == var) { + writes = true; + } + } + } + + WriteDetector detector = new WriteDetector(); + stmt.accept(detector); + return detector.writes; + } + private static MultiResult flattenExprs(ImTranslator t, ImFunction f, ImExpr... exprs) { return flattenExprs(t, f, Arrays.asList(exprs)); } @@ -510,7 +553,8 @@ private static MultiResult flattenExprs(ImTranslator t, ImFunction f, List= withStmts) { newExprs.add(r.expr); } else { @@ -541,7 +585,8 @@ private static MultiResultL flattenExprsL(ImTranslator t, ImFunction f, List= withStmts) { newExprs.add(r.getExpr()); } else { diff --git a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/FlattenTests.java b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/FlattenTests.java new file mode 100644 index 000000000..58e553655 --- /dev/null +++ b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/FlattenTests.java @@ -0,0 +1,115 @@ +package tests.wurstscript.tests; + +import de.peeeq.wurstscript.RunArgs; +import de.peeeq.wurstscript.ast.Ast; +import de.peeeq.wurstscript.jassIm.*; +import de.peeeq.wurstscript.translation.imtranslation.CallType; +import de.peeeq.wurstscript.translation.imtranslation.Flatten; +import de.peeeq.wurstscript.translation.imtranslation.Flatten.MultiResult; +import de.peeeq.wurstscript.translation.imtranslation.ImTranslator; +import org.testng.annotations.Test; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collections; + +import static org.testng.Assert.*; + +public class FlattenTests { + + @Test + public void avoid_extra_temp_for_var_access() throws Exception { + ImTranslator translator = new ImTranslator(Ast.WurstModel(), true, new RunArgs()); + de.peeeq.wurstscript.ast.Element trace = Ast.NoExpr(); + ImVar sourceVar = JassIm.ImVar(trace, JassIm.ImSimpleType("int"), "source", false); + ImVar targetVar = JassIm.ImVar(trace, JassIm.ImSimpleType("int"), "target", false); + + ImExpr simpleAccess = JassIm.ImVarAccess(sourceVar); + ImExpr expressionWithStatements = JassIm.ImStatementExpr( + JassIm.ImStmts(JassIm.ImSet(trace, JassIm.ImVarAccess(targetVar), JassIm.ImIntVal(1))), + JassIm.ImVarAccess(targetVar)); + + ImFunction function = JassIm.ImFunction(trace, "test", JassIm.ImTypeVars(), JassIm.ImVars(), JassIm.ImVoid(), + JassIm.ImVars(sourceVar, targetVar), JassIm.ImStmts(), Collections.emptyList()); + + Method flattenExprs = Flatten.class.getDeclaredMethod("flattenExprs", ImTranslator.class, ImFunction.class, java.util.List.class); + flattenExprs.setAccessible(true); + + MultiResult result = (MultiResult) flattenExprs.invoke(null, translator, function, Arrays.asList(simpleAccess, expressionWithStatements)); + + assertEquals(function.getLocals().size(), 2, "no extra locals should be introduced"); + assertSame(simpleAccess, result.expr(0), "the original access should be reused"); + } + + @Test + public void add_temp_for_impure_expr_with_followup_statements() throws Exception { + ImTranslator translator = new ImTranslator(Ast.WurstModel(), true, new RunArgs()); + de.peeeq.wurstscript.ast.Element trace = Ast.NoExpr(); + ImVar sourceVar = JassIm.ImVar(trace, JassIm.ImSimpleType("int"), "source", false); + ImVar targetVar = JassIm.ImVar(trace, JassIm.ImSimpleType("int"), "target", false); + + ImFunction impureCallee = JassIm.ImFunction(trace, "impure", JassIm.ImTypeVars(), JassIm.ImVars(), + JassIm.ImSimpleType("int"), JassIm.ImVars(), JassIm.ImStmts(), Collections.emptyList()); + ImExpr impureCall = JassIm.ImFunctionCall(trace, impureCallee, JassIm.ImTypeArguments(), JassIm.ImExprs(), + false, CallType.NORMAL); + + ImExpr expressionWithStatements = JassIm.ImStatementExpr( + JassIm.ImStmts(JassIm.ImSet(trace, JassIm.ImVarAccess(targetVar), JassIm.ImIntVal(1))), + JassIm.ImVarAccess(targetVar)); + + ImFunction function = JassIm.ImFunction(trace, "test", JassIm.ImTypeVars(), JassIm.ImVars(), JassIm.ImVoid(), + JassIm.ImVars(sourceVar, targetVar), JassIm.ImStmts(), Collections.emptyList()); + + Method flattenExprs = Flatten.class.getDeclaredMethod("flattenExprs", ImTranslator.class, ImFunction.class, java.util.List.class); + flattenExprs.setAccessible(true); + + MultiResult result = (MultiResult) flattenExprs.invoke(null, translator, function, Arrays.asList(impureCall, expressionWithStatements)); + + assertEquals(function.getLocals().size(), 3, "a temporary should be added for the impure expression"); + assertEquals(result.stmts.size(), 2, "flattening should produce a temp assignment followed by the statement expr body"); + + ImExpr flattenedImpure = result.expr(0); + assertTrue(flattenedImpure instanceof ImVarAccess, "impure call should be replaced with temp access"); + ImVar tempVar = ((ImVarAccess) flattenedImpure).getVar(); + assertNotSame(tempVar, sourceVar); + assertNotSame(tempVar, targetVar); + + ImStmt tempAssignment = result.stmts.get(0); + assertTrue(tempAssignment instanceof ImSet, "first statement should assign the impure result to the temp var"); + ImLExpr left = ((ImSet) tempAssignment).getLeft(); + assertTrue(left instanceof ImVarAccess); + assertSame(((ImVarAccess) left).getVar(), tempVar); + } + + @Test + public void add_temp_when_followup_writes_accessed_var() throws Exception { + ImTranslator translator = new ImTranslator(Ast.WurstModel(), true, new RunArgs()); + de.peeeq.wurstscript.ast.Element trace = Ast.NoExpr(); + ImVar shared = JassIm.ImVar(trace, JassIm.ImSimpleType("int"), "shared", false); + + ImExpr initialRead = JassIm.ImVarAccess(shared); + ImExpr statementWithWrite = JassIm.ImStatementExpr( + JassIm.ImStmts(JassIm.ImSet(trace, JassIm.ImVarAccess(shared), JassIm.ImIntVal(4))), + JassIm.ImIntVal(2)); + + ImFunction function = JassIm.ImFunction(trace, "test", JassIm.ImTypeVars(), JassIm.ImVars(), JassIm.ImVoid(), + JassIm.ImVars(shared), JassIm.ImStmts(), Collections.emptyList()); + + Method flattenExprs = Flatten.class.getDeclaredMethod("flattenExprs", ImTranslator.class, ImFunction.class, java.util.List.class); + flattenExprs.setAccessible(true); + + MultiResult result = (MultiResult) flattenExprs.invoke(null, translator, function, Arrays.asList(initialRead, statementWithWrite)); + + assertEquals(function.getLocals().size(), 2, "flattening should introduce a temp for the first argument"); + assertEquals(result.stmts.size(), 2, "flattening should store the first arg before executing later statements"); + + ImExpr firstExpr = result.expr(0); + assertTrue(firstExpr instanceof ImVarAccess, "first expression should be rewritten to a temp access"); + ImVar tempVar = ((ImVarAccess) firstExpr).getVar(); + assertNotSame(tempVar, shared); + + ImStmt firstStmt = result.stmts.get(0); + assertTrue(firstStmt instanceof ImSet, "first statement should capture the original value"); + assertSame(((ImVarAccess) ((ImSet) firstStmt).getLeft()).getVar(), tempVar); + } +} diff --git a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java index 37d661bb8..175d281f4 100644 --- a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java +++ b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java @@ -427,7 +427,6 @@ public void test_tempVarRemover() throws IOException { } @Test - @Ignore // This test was for a rewrite that caused an infinite loop in the optimizer. public void test_mult2rewrite() throws IOException { test().lines( "package test", @@ -442,7 +441,7 @@ public void test_mult2rewrite() throws IOException { "endpackage"); String output = Files.toString(new File("./test-output/OptimizerTests_test_mult2rewrite_inlopt.j"), Charsets.UTF_8); - assertTrue(!output.contains("blub_a") && !(output.contains("blub_b") && !output.contains("blub_c"))); + assertTrue(output.contains("blub_a") && !(output.contains("blub_b") && !output.contains("blub_c"))); } @Test