Skip to content

Commit 1516285

Browse files
committed
Add Function Invocation Simplification
1 parent c7afdde commit 1516285

3 files changed

Lines changed: 81 additions & 23 deletions

File tree

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import liquidjava.rj_language.ast.BinaryExpression;
44
import liquidjava.rj_language.ast.Expression;
5+
import liquidjava.rj_language.ast.FunctionInvocation;
56
import liquidjava.rj_language.ast.UnaryExpression;
67
import liquidjava.rj_language.ast.Var;
78
import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode;
@@ -69,6 +70,12 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map<String,
6970
return new ValDerivationNode(var, null);
7071
}
7172

73+
if (exp instanceof FunctionInvocation) {
74+
Expression value = subs.get(exp.toString());
75+
if (value != null)
76+
return new ValDerivationNode(value.clone(), new VarDerivationNode(exp.toString()));
77+
}
78+
7279
// lift unary origin
7380
if (exp instanceof UnaryExpression unary) {
7481
ValDerivationNode operand = propagateRecursive(unary.getChildren().get(0), subs, varOrigins);

liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import liquidjava.rj_language.ast.BinaryExpression;
99
import liquidjava.rj_language.ast.Expression;
10+
import liquidjava.rj_language.ast.FunctionInvocation;
1011
import liquidjava.rj_language.ast.Var;
1112

1213
public class VariableResolver {
@@ -45,33 +46,50 @@ private static void resolveRecursive(Expression exp, Map<String, Expression> map
4546
if ("&&".equals(op)) {
4647
resolveRecursive(be.getFirstOperand(), map);
4748
resolveRecursive(be.getSecondOperand(), map);
48-
} else if ("==".equals(op)) {
49-
Expression left = be.getFirstOperand();
50-
Expression right = be.getSecondOperand();
51-
if (left instanceof Var var && right.isLiteral()) {
52-
map.put(var.getName(), right.clone());
53-
} else if (right instanceof Var var && left.isLiteral()) {
54-
map.put(var.getName(), left.clone());
55-
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
56-
// to substitute internal variable with user-facing variable
57-
if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) {
58-
map.put(leftVar.getName(), right.clone());
59-
} else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) {
60-
map.put(rightVar.getName(), left.clone());
61-
} else if (isInternal(leftVar) && isInternal(rightVar)) {
62-
// to substitute the lower-counter variable with the higher-counter one
63-
boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar);
64-
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
65-
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
66-
if (!isReturnVar(lowerVar) && !isFreshVar(higherVar))
67-
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
68-
}
69-
} else if (left instanceof Var var && !(right instanceof Var) && canSubstitute(var, right)) {
70-
map.put(var.getName(), right.clone());
49+
return;
50+
}
51+
if (!"==".equals(op))
52+
return;
53+
54+
Expression left = be.getFirstOperand();
55+
Expression right = be.getSecondOperand();
56+
String leftKey = substitutionKey(left);
57+
String rightKey = substitutionKey(right);
58+
59+
if (leftKey != null && right.isLiteral()) {
60+
map.put(leftKey, right.clone());
61+
} else if (rightKey != null && left.isLiteral()) {
62+
map.put(rightKey, left.clone());
63+
} else if (left instanceof Var leftVar && right instanceof Var rightVar) {
64+
// to substitute internal variable with user-facing variable
65+
if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) {
66+
map.put(leftVar.getName(), right.clone());
67+
} else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) {
68+
map.put(rightVar.getName(), left.clone());
69+
} else if (isInternal(leftVar) && isInternal(rightVar)) {
70+
// to substitute the lower-counter variable with the higher-counter one
71+
boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar);
72+
Var lowerVar = isLeftCounterLower ? leftVar : rightVar;
73+
Var higherVar = isLeftCounterLower ? rightVar : leftVar;
74+
if (!isReturnVar(lowerVar) && !isFreshVar(higherVar))
75+
map.putIfAbsent(lowerVar.getName(), higherVar.clone());
7176
}
77+
} else if (left instanceof Var var && !(right instanceof Var) && canSubstitute(var, right)) {
78+
map.put(var.getName(), right.clone());
79+
} else if (left instanceof FunctionInvocation && !(right instanceof Var)
80+
&& !right.toString().contains(leftKey)) {
81+
map.put(leftKey, right.clone());
7282
}
7383
}
7484

85+
private static String substitutionKey(Expression exp) {
86+
if (exp instanceof Var var)
87+
return var.getName();
88+
if (exp instanceof FunctionInvocation)
89+
return exp.toString();
90+
return null;
91+
}
92+
7593
/**
7694
* Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1)
7795
*
@@ -129,14 +147,22 @@ private static boolean hasUsage(Expression exp, String name) {
129147
if (left instanceof Var v && v.getName().equals(name)
130148
&& (right.isLiteral() || (!(right instanceof Var) && canSubstitute(v, right))))
131149
return false;
150+
if (left instanceof FunctionInvocation && left.toString().equals(name)
151+
&& (right.isLiteral() || (!(right instanceof Var) && !right.toString().contains(name))))
152+
return false;
132153
if (right instanceof Var v && v.getName().equals(name) && left.isLiteral())
133154
return false;
155+
if (right instanceof FunctionInvocation && right.toString().equals(name) && left.isLiteral())
156+
return false;
134157
}
135158

136159
// usage found
137160
if (exp instanceof Var var && var.getName().equals(name)) {
138161
return true;
139162
}
163+
if (exp instanceof FunctionInvocation && exp.toString().equals(name)) {
164+
return true;
165+
}
140166

141167
// recurse children
142168
if (exp.hasChildren()) {

liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import liquidjava.rj_language.ast.AliasInvocation;
1111
import liquidjava.rj_language.ast.BinaryExpression;
1212
import liquidjava.rj_language.ast.Expression;
13+
import liquidjava.rj_language.ast.FunctionInvocation;
1314
import liquidjava.rj_language.ast.Ite;
1415
import liquidjava.rj_language.ast.LiteralBoolean;
1516
import liquidjava.rj_language.ast.LiteralInt;
@@ -1133,4 +1134,28 @@ void testFoldsAdjacentIntegerConstantsInLeftAssociatedArithmetic() {
11331134
assertEquals("x + 3", ExpressionSimplifier.simplify(xPlus1Plus2).getValue().toString());
11341135
assertEquals("x", ExpressionSimplifier.simplify(xPlus1Minus1).getValue().toString());
11351136
}
1137+
1138+
@Test
1139+
void testFunctionInvocationEqualitiesPropagateTransitively() {
1140+
// Given: size(x3) == size(x2) - 1 && size(x2) == size(x1) + 1 && size(x1) == 0
1141+
// Expected: size(x3) == 0
1142+
Expression x1 = new Var("x1");
1143+
Expression x2 = new Var("x2");
1144+
Expression x3 = new Var("x3");
1145+
Expression sizeX1 = new FunctionInvocation("size", List.of(x1));
1146+
Expression sizeX2 = new FunctionInvocation("size", List.of(x2));
1147+
Expression sizeX3 = new FunctionInvocation("size", List.of(x3));
1148+
1149+
Expression sizeX3EqualsSizeX2Minus1 = new BinaryExpression(sizeX3, "==",
1150+
new BinaryExpression(sizeX2, "-", new LiteralInt(1)));
1151+
Expression sizeX2EqualsSizeX1Plus1 = new BinaryExpression(sizeX2, "==",
1152+
new BinaryExpression(sizeX1, "+", new LiteralInt(1)));
1153+
Expression sizeX1Equals0 = new BinaryExpression(sizeX1, "==", new LiteralInt(0));
1154+
Expression fullExpression = new BinaryExpression(sizeX3EqualsSizeX2Minus1, "&&",
1155+
new BinaryExpression(sizeX2EqualsSizeX1Plus1, "&&", sizeX1Equals0));
1156+
1157+
ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression);
1158+
1159+
assertEquals("size(x3) == 0", result.getValue().toString());
1160+
}
11361161
}

0 commit comments

Comments
 (0)