diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py
index 34837a73adbd..cf648001b824 100644
--- a/mypyc/codegen/emit.py
+++ b/mypyc/codegen/emit.py
@@ -6,7 +6,7 @@
 import sys
 import textwrap
 from collections.abc import Callable
-from typing import Final
+from typing import TYPE_CHECKING, Final
 
 from mypyc.codegen.cstring import c_string_initializer
 from mypyc.codegen.literals import Literals
@@ -59,6 +59,9 @@
 from mypyc.namegen import NameGenerator, exported_name
 from mypyc.sametype import is_same_type
 
+if TYPE_CHECKING:
+    from _typeshed import SupportsWrite
+
 # Whether to insert debug asserts for all error handling, to quickly
 # catch errors propagating without exceptions set.
 DEBUG_ERRORS: Final = False
@@ -210,7 +213,8 @@ def object_annotation(self, obj: object, line: str) -> str:
 
         If it contains illegal characters, an empty string is returned."""
         line_width = self._indent + len(line)
-        formatted = pprint.pformat(obj, compact=True, width=max(90 - line_width, 20))
+        formatted = pformat_deterministic(obj, max(90 - line_width, 20))
+
         if any(x in formatted for x in ("/*", "*/", "\0")):
             return ""
 
@@ -1271,3 +1275,62 @@ def native_function_doc_initializer(func: FuncIR) -> str:
         return "NULL"
     docstring = f"{text_sig}\n--\n\n"
     return c_string_initializer(docstring.encode("ascii", errors="backslashreplace"))
+
+
+def pformat_deterministic(obj: object, width: int) -> str:
+    """Pretty-print `obj` with deterministic sorting for mypyc literal types."""
+    # Temporarily override pprint._safe_key to get deterministic ordering of containers.
+    default_safe_key = pprint._safe_key  # type: ignore [attr-defined]
+    pprint._safe_key = _mypyc_safe_key  # type: ignore [attr-defined]
+
+    try:
+        printer = _DeterministicPrettyPrinter(width=width, compact=True, sort_dicts=True)
+        return printer.pformat(obj)
+    finally:
+        # Always restore the original key to avoid affecting other pprint users.
+        pprint._safe_key = default_safe_key  # type: ignore [attr-defined]
+
+
+def _mypyc_safe_key(obj: object) -> str:
+    """A custom sort key implementation for pprint that makes the output deterministic
+    for all literal types supported by mypyc.
+
+    This is NOT safe for use as a sort key for other types, so we MUST replace the
+    original pprint._safe_key once we've pprinted our object.
+
+    Since this is a bit hacky, see for context https://github.com/python/mypy/pull/20012
+    """
+    return str(type(obj)) + pprint.pformat(obj, compact=True, sort_dicts=True)
+
+
+class _DeterministicPrettyPrinter(pprint.PrettyPrinter):
+    """PrettyPrinter that sorts set/frozenset elements deterministically."""
+
+    _dispatch = pprint.PrettyPrinter._dispatch.copy()
+
+    def _pprint_set(
+        self,
+        object: set[object] | frozenset[object],
+        stream: SupportsWrite[str],
+        indent: int,
+        allowance: int,
+        context: dict[int, int],
+        level: int,
+    ) -> None:
+        if not object:
+            stream.write(repr(object))
+            return
+        typ = type(object)
+        if typ is set:
+            stream.write("{")
+            endchar = "}"
+        else:
+            stream.write("frozenset({")
+            endchar = "})"
+            indent += len("frozenset(")
+        items = sorted(object, key=_mypyc_safe_key)
+        self._format_items(items, stream, indent, allowance + len(endchar), context, level)
+        stream.write(endchar)
+
+    _dispatch[set.__repr__] = _pprint_set
+    _dispatch[frozenset.__repr__] = _pprint_set
diff --git a/mypyc/test/test_emit.py b/mypyc/test/test_emit.py
index 1baed3964299..f52da1cd8757 100644
--- a/mypyc/test/test_emit.py
+++ b/mypyc/test/test_emit.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
+import pprint
 import unittest
 
-from mypyc.codegen.emit import Emitter, EmitterContext
+from mypyc.codegen.emit import Emitter, EmitterContext, pformat_deterministic
 from mypyc.common import HAVE_IMMORTAL
 from mypyc.ir.class_ir import ClassIR
 from mypyc.ir.ops import BasicBlock, Register, Value
@@ -21,6 +22,34 @@
 from mypyc.namegen import NameGenerator
 
 
+class TestPformatDeterministic(unittest.TestCase):
+    def test_frozenset_elements_sorted(self) -> None:
+        fs_small = frozenset({("a", 1)})
+        fs_large = frozenset({("a", 1), ("b", 2)})
+        literal_a = frozenset({fs_large, fs_small})
+        literal_b = frozenset({fs_small, fs_large})
+        expected = "frozenset({frozenset({('b', 2), ('a', 1)}), frozenset({('a', 1)})})"
+
+        assert pformat_deterministic(literal_a, 80) == expected
+        assert pformat_deterministic(literal_b, 80) == expected
+
+    def test_nested_supported_literals(self) -> None:
+        nested_frozen = frozenset({("m", 0), ("n", 1)})
+        item_a = ("outer", 1, nested_frozen)
+        item_b = ("outer", 2, frozenset({("x", 3)}))
+        literal_a = frozenset({item_a, item_b})
+        literal_b = frozenset({item_b, item_a})
+        expected = "frozenset({('outer', 2, frozenset({('x', 3)})), ('outer', 1, frozenset({('n', 1), ('m', 0)}))})"
+
+        assert pformat_deterministic(literal_a, 120) == expected
+        assert pformat_deterministic(literal_b, 120) == expected
+
+    def test_restores_default_safe_key(self) -> None:
+        original_safe_key = pprint._safe_key
+        pformat_deterministic({"key": "value"}, 80)
+        assert pprint._safe_key is original_safe_key
+
+
 class TestEmitter(unittest.TestCase):
     def setUp(self) -> None:
         self.n = Register(int_rprimitive, "n")