Skip to content

Commit 6d1f086

Browse files
authored
Simplify Enums & Static Final Fields (#222)
1 parent 275e9b7 commit 6d1f086

4 files changed

Lines changed: 68 additions & 7 deletions

File tree

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package liquidjava.rj_language.opt;
22

33
import liquidjava.rj_language.ast.BinaryExpression;
4+
import liquidjava.rj_language.ast.Enum;
45
import liquidjava.rj_language.ast.Expression;
56
import liquidjava.rj_language.ast.GroupExpression;
67
import liquidjava.rj_language.ast.Ite;
@@ -62,6 +63,16 @@ private static ValDerivationNode foldBinary(ValDerivationNode node) {
6263
Expression left = leftNode.getValue();
6364
Expression right = rightNode.getValue();
6465
String op = binExp.getOperator();
66+
67+
if (left instanceof Enum en && en.getResolvedLiteral() != null) {
68+
left = en.getResolvedLiteral().clone();
69+
leftNode = new ValDerivationNode(left, leftNode);
70+
}
71+
if (right instanceof Enum en && en.getResolvedLiteral() != null) {
72+
right = en.getResolvedLiteral().clone();
73+
rightNode = new ValDerivationNode(right, rightNode);
74+
}
75+
6576
binExp.setChild(0, left);
6677
binExp.setChild(1, right);
6778

@@ -146,6 +157,18 @@ else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) {
146157
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
147158
}
148159

160+
else if (left instanceof Enum leftEnum && right instanceof Enum rightEnum
161+
&& leftEnum.getTypeName().equals(rightEnum.getTypeName())) {
162+
boolean equal = leftEnum.getConstName().equals(rightEnum.getConstName());
163+
Expression res = switch (op) {
164+
case "==" -> new LiteralBoolean(equal);
165+
case "!=" -> new LiteralBoolean(!equal);
166+
default -> null;
167+
};
168+
if (res != null)
169+
return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op));
170+
}
171+
149172
ValDerivationNode adjacentConstants = foldAdjacentIntegerConstants(leftNode, rightNode, op);
150173
if (adjacentConstants != null)
151174
return adjacentConstants;

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package liquidjava.rj_language.opt;
22

33
import liquidjava.rj_language.ast.BinaryExpression;
4+
import liquidjava.rj_language.ast.Enum;
45
import liquidjava.rj_language.ast.Expression;
56
import liquidjava.rj_language.ast.FunctionInvocation;
67
import liquidjava.rj_language.ast.UnaryExpression;
@@ -28,7 +29,7 @@ public static ValDerivationNode propagate(Expression exp, ValDerivationNode prev
2829
Map<String, Expression> expressionSubstitutions = new HashMap<>(); // var == expression
2930
for (Map.Entry<String, Expression> entry : substitutions.entrySet()) {
3031
Expression value = entry.getValue();
31-
if (value.isLiteral() || value instanceof Var) {
32+
if (value.isLiteral() || value instanceof Var || value instanceof Enum) {
3233
directSubstitutions.put(entry.getKey(), value);
3334
} else {
3435
expressionSubstitutions.put(entry.getKey(), value);

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.util.Set;
77

88
import liquidjava.rj_language.ast.BinaryExpression;
9+
import liquidjava.rj_language.ast.Enum;
910
import liquidjava.rj_language.ast.Expression;
1011
import liquidjava.rj_language.ast.FunctionInvocation;
1112
import liquidjava.rj_language.ast.Var;
@@ -56,9 +57,9 @@ private static void resolveRecursive(Expression exp, Map<String, Expression> map
5657
String leftKey = substitutionKey(left);
5758
String rightKey = substitutionKey(right);
5859

59-
if (leftKey != null && right.isLiteral()) {
60+
if (leftKey != null && isConstant(right)) {
6061
map.put(leftKey, right.clone());
61-
} else if (rightKey != null && left.isLiteral()) {
62+
} else if (rightKey != null && isConstant(left)) {
6263
map.put(rightKey, left.clone());
6364
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
6465
// to substitute internal variable with user-facing variable
@@ -144,15 +145,15 @@ private static boolean hasUsage(Expression exp, String name, Expression value) {
144145
Expression left = binary.getFirstOperand();
145146
Expression right = binary.getSecondOperand();
146147
if (left instanceof Var v && v.getName().equals(name) && right.equals(value)
147-
&& (right.isLiteral() || (!(right instanceof Var) && canSubstitute(v, right))))
148+
&& (isConstant(right) || (!(right instanceof Var) && canSubstitute(v, right))))
148149
return false;
149150
if (left instanceof FunctionInvocation && left.toString().equals(name) && right.equals(value)
150-
&& (right.isLiteral() || (!(right instanceof Var) && !containsExpression(right, left))))
151+
&& (isConstant(right) || (!(right instanceof Var) && !containsExpression(right, left))))
151152
return false;
152-
if (right instanceof Var v && v.getName().equals(name) && left.equals(value) && left.isLiteral())
153+
if (right instanceof Var v && v.getName().equals(name) && left.equals(value) && isConstant(left))
153154
return false;
154155
if (right instanceof FunctionInvocation && right.toString().equals(name) && left.equals(value)
155-
&& left.isLiteral())
156+
&& isConstant(left))
156157
return false;
157158
}
158159

@@ -198,6 +199,10 @@ private static boolean canSubstitute(Var var, Expression value) {
198199
return !isReturnVar(var) && !isFreshVar(var) && !containsVariable(value, var.getName());
199200
}
200201

202+
private static boolean isConstant(Expression exp) {
203+
return exp.isLiteral() || exp instanceof Enum;
204+
}
205+
201206
private static boolean containsVariable(Expression exp, String name) {
202207
if (exp instanceof Var var)
203208
return var.getName().equals(name);

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,4 +626,36 @@ void testFunctionInvocationEqualitiesMixWithVariables() {
626626

627627
assertEquals("3", result.getValue().toString());
628628
}
629+
630+
@Test
631+
void testEnumConstantsPropagateIntoVariableEquality() {
632+
Expression expression = parse("current == mode && mode == Mode.Photo");
633+
ValDerivationNode result = ExpressionSimplifier.simplify(expression);
634+
635+
assertEquals("current == Mode.Photo", result.getValue().toString());
636+
}
637+
638+
@Test
639+
void testEnumConstantsPropagateTransitively() {
640+
Expression expression = parse("target == current && current == mode && mode == Mode.Photo");
641+
ValDerivationNode result = ExpressionSimplifier.simplify(expression);
642+
643+
assertEquals("target == Mode.Photo", result.getValue().toString());
644+
}
645+
646+
@Test
647+
void testEnumConstantsPropagateThroughFunctionInvocations() {
648+
Expression expression = parse("modeOf(x) == Mode.Photo && current == modeOf(x)");
649+
ValDerivationNode result = ExpressionSimplifier.simplify(expression);
650+
651+
assertEquals("current == Mode.Photo", result.getValue().toString());
652+
}
653+
654+
@Test
655+
void testEnumConstantsPropagateIntoTernaryCondition() {
656+
Expression expression = parse("mode == Mode.Photo && (mode == Mode.Video ? explicit(param) : start(param))");
657+
ValDerivationNode result = ExpressionSimplifier.simplify(expression);
658+
659+
assertEquals("start(param)", result.getValue().toString());
660+
}
629661
}

0 commit comments

Comments
 (0)