diff --git a/mypy/constant_fold.py b/mypy/constant_fold.py index e1b915df2298..6cdbabf37cbe 100644 --- a/mypy/constant_fold.py +++ b/mypy/constant_fold.py @@ -11,9 +11,11 @@ ComplexExpr, Expression, FloatExpr, + IndexExpr, IntExpr, NameExpr, OpExpr, + SliceExpr, StrExpr, UnaryExpr, Var, @@ -73,6 +75,40 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non value = constant_fold_expr(expr.expr, cur_mod_id) if value is not None: return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, IndexExpr): + base = constant_fold_expr(expr.base, cur_mod_id) + if base is not None: + index_expr = expr.index + if isinstance(index_expr, SliceExpr): + if index_expr.begin_index is None: + begin_index = None + else: + begin_index = constant_fold_expr(index_expr.begin_index, cur_mod_id) + if begin_index is None: + return None + if index_expr.end_index is None: + end_index = None + else: + end_index = constant_fold_expr(index_expr.end_index, cur_mod_id) + if end_index is None: + return None + if index_expr.stride is None: + stride = None + else: + stride = constant_fold_expr(index_expr.stride, cur_mod_id) + if stride is None: + return None + try: + return base[begin_index:end_index:stride] # type: ignore [index, misc] + except Exception: + return None + + index = constant_fold_expr(index_expr, cur_mod_id) + if index is not None: + try: + return base[index] # type: ignore [index] + except Exception: + return None return None diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 53274dd3f971..0a5587c50ac4 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -10,7 +10,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Final +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Final, TypeVar from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op from mypy.nodes import ( @@ -18,14 +19,17 @@ ComplexExpr, Expression, FloatExpr, + IndexExpr, IntExpr, MemberExpr, NameExpr, OpExpr, + SliceExpr, StrExpr, UnaryExpr, Var, ) +from mypyc.ir.ops import Value from mypyc.irbuild.util import bytes_from_str if TYPE_CHECKING: @@ -35,6 +39,8 @@ ConstantValue = int | float | complex | str | bytes CONST_TYPES: Final = (int, float, complex, str, bytes) +Expr = TypeVar("Expr", bound=Expression) + def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None: """Return the constant value of an expression for supported operations. @@ -74,6 +80,60 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | value = constant_fold_expr(builder, expr.expr) if value is not None and not isinstance(value, bytes): return constant_fold_unary_op(expr.op, value) + elif isinstance(expr, IndexExpr): + base = constant_fold_expr(builder, expr.base) + if base is not None: + assert isinstance(base, (Sequence, dict)), base + index_expr = expr.index + if isinstance(index_expr, SliceExpr): + if index_expr.begin_index is None: + begin_index = None + else: + begin_index = constant_fold_expr(builder, index_expr.begin_index) + if begin_index is None: + return None + if index_expr.end_index is None: + end_index = None + else: + end_index = constant_fold_expr(builder, index_expr.end_index) + if end_index is None: + return None + if index_expr.stride is None: + stride = None + else: + stride = constant_fold_expr(builder, index_expr.stride) + if stride is None: + return None + + # this branching just keeps mypy happy, non-functional + if isinstance(base, Sequence): + assert isinstance(begin_index, int) or begin_index is None + assert isinstance(end_index, int) or end_index is None + assert isinstance(stride, int) or stride is None + try: + return base[begin_index:end_index:stride] + except Exception: + return None + try: # type: ignore [unreachable] + return base[begin_index:end_index:stride] + except Exception: + return None + + index = constant_fold_expr(builder, index_expr) + + # this branching just keeps mypy happy, non-functional + if isinstance(base, Sequence): + + if isinstance(index, int): + try: + return base[index] + except Exception: + return None + else: + try: # type: ignore [unreachable] + return base[index] + except Exception: + return None return None @@ -95,3 +155,31 @@ def constant_fold_binary_op_extended( return left * right return None + + +def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None: + """Return the constant value of an expression if possible. + + Return None otherwise. + """ + value = constant_fold_expr(builder, expr) + if value is not None: + return builder.load_literal_value(value) + return None + + +def folding_candidate( + transform: Callable[[IRBuilder, Expr], Value], +) -> Callable[[IRBuilder, Expr], Value]: + """Mark a transform function as a candidate for constant folding. + + Candidate functions will attempt to short-circuit the transformation + by constant folding the expression and will only proceed to transform + the expression if folding is not possible. + """ + + def constant_fold_wrap(builder: IRBuilder, expr: Expr) -> Value: + folded = try_constant_fold(builder, expr) + return folded if folded is not None else transform(builder, expr) + + return constant_fold_wrap diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 04a55fb257f0..ed0e40847775 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -82,7 +82,7 @@ ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op -from mypyc.irbuild.constant_fold import constant_fold_expr +from mypyc.irbuild.constant_fold import constant_fold_expr, folding_candidate, try_constant_fold from mypyc.irbuild.for_helpers import ( comprehension_helper, raise_error_if_contains_unreachable_names, @@ -526,11 +526,8 @@ def translate_cast_expr(builder: IRBuilder, expr: CastExpr) -> Value: # Operators +@folding_candidate def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value: - folded = try_constant_fold(builder, expr) - if folded: - return folded - return builder.unary_op(builder.accept(expr.expr), expr.op, expr.line) @@ -581,6 +578,7 @@ def try_optimize_int_floor_divide(builder: IRBuilder, expr: OpExpr) -> OpExpr: return expr +@folding_candidate def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value: index = expr.index base_type = builder.node_type(expr.base) @@ -607,17 +605,6 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value: ) -def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None: - """Return the constant value of an expression if possible. - - Return None otherwise. - """ - value = constant_fold_expr(builder, expr) - if value is not None: - return builder.load_literal_value(value) - return None - - def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Value | None: """Generate specialized slice op for some index expressions. diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index cd953c84c541..eb84a74fdd3f 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -478,3 +478,66 @@ L0: r3 = (-1.5+2j) neg_2 = r3 return 1 + +[case testIndexExprConstantFolding] +from typing import Final + +long_string: Final = "long string" + +def pos_index() -> None: + a = long_string[5] +def neg_index() -> None: + a = long_string[-5] +def slice_index() -> None: + a = long_string[5:] +def full_slice() -> None: + a = long_string[:] +def prefix_slice() -> None: + a = long_string[:5] +def mid_slice() -> None: + a = long_string[3:5] +def negative_slice() -> None: + a = long_string[-6:-1] +[out] +def pos_index(): + r0, a :: str +L0: + r0 = 's' + a = r0 + return 1 +def neg_index(): + r0, a :: str +L0: + r0 = 't' + a = r0 + return 1 +def slice_index(): + r0, a :: str +L0: + r0 = 'string' + a = r0 + return 1 +def full_slice(): + r0, a :: str +L0: + r0 = 'long string' + a = r0 + return 1 +def prefix_slice(): + r0, a :: str +L0: + r0 = 'long ' + a = r0 + return 1 +def mid_slice(): + r0, a :: str +L0: + r0 = 'g ' + a = r0 + return 1 +def negative_slice(): + r0, a :: str +L0: + r0 = 'strin' + a = r0 + return 1 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 0ae67ed7f1c3..954f4659c690 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -151,7 +151,7 @@ def test_unicode() -> None: assert ne("\U0001f4a9foo", "\U0001f4a8foo" + str()) [case testStringOps] -from typing import List, Optional, Tuple +from typing import Final, List, Optional, Tuple from testutil import assertRaises def do_split(s: str, sep: Optional[str] = None, max_split: Optional[int] = None) -> List[str]: @@ -226,6 +226,19 @@ def contains(s: str, o: str) -> bool: def getitem(s: str, index: int) -> str: return s[index] +final_string: Final = "abc" +final_int: Final = 1 + +def getitem_folded() -> str: + return ( + final_string[final_int] + + final_string[-1] + + final_string[:] + + final_string[:2] + + final_string[1:3] + + final_string[-3:-1] + ) + def find(s: str, substr: str, start: Optional[int] = None, end: Optional[int] = None) -> int: if start is not None: if end is not None: @@ -263,6 +276,7 @@ def test_getitem() -> None: getitem(s, 4) with assertRaises(IndexError, "string index out of range"): getitem(s, -4) + assert getitem_folded() == "bcabcabbcab", getitem_folded() def test_find() -> None: s = "abcab"