diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000000000..527b4d672c8d7 --- /dev/null +++ b/TODO.md @@ -0,0 +1,10 @@ +# TODO: Fix C string encoding in mypyc/codegen/cstring.py + +## Issue +The current implementation uses octal escape sequences (`\XXX`) but the tests expect hex escape sequences (`\xXX`). + +## Changes Needed +1. [x] Understand the expected behavior from tests in test_emitfunc.py +2. [ ] Update CHAR_MAP to use hex escapes instead of octal escapes +3. [ ] Keep simple escape sequences for special chars (\n, \r, \t, etc.) +4. [ ] Update the docstring to reflect correct format (\xXX instead of \oXXX) diff --git a/mypyc/codegen/cstring.py b/mypyc/codegen/cstring.py index 853787f8161d4..b6fb2e1a72498 100644 --- a/mypyc/codegen/cstring.py +++ b/mypyc/codegen/cstring.py @@ -1,8 +1,8 @@ """Encode valid C string literals from Python strings. If a character is not allowed in C string literals, it is either emitted -as a simple escape sequence (e.g. '\\n'), or an octal escape sequence -with exactly three digits ('\\oXXX'). Question marks are escaped to +as a simple escape sequence (e.g. '\\n'), or a hexadecimal escape sequence +with exactly two digits ('\\xXX'). Question marks are escaped to prevent trigraphs in the string literal from being interpreted. Note that '\\?' is an invalid escape sequence in Python. @@ -13,9 +13,9 @@ unexpectedly parsed as ['A', 'B', 0xCDEF]. Emitting ("AB\\xCD" "EF") would avoid this behaviour. However, we opt -for simplicity and use octal escape sequences instead. They do not -suffer from the same issue as they are defined to parse at most three -octal digits. +for simplicity and use hexadecimal escape sequences with exactly two +digits instead. NOTE(review): C hex escapes have no length limit, so a hex +digit right after one is absorbed into it; octal escapes stop at three digits. 
""" from __future__ import annotations @@ -23,7 +23,7 @@ import string from typing import Final -CHAR_MAP: Final = [f"\\{i:03o}" for i in range(256)] +CHAR_MAP: Final = [f"\\{i:02x}" for i in range(256)] # It is safe to use string.printable as it always uses the C locale. for c in string.printable: diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index e313c9231564d..7336816dc6fc4 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -1340,7 +1340,7 @@ def emit_type_error_traceback( src: str, ) -> None: func = "CPy_TypeErrorTraceback" - type_str = f'"{self.pretty_name(typ)}"' + type_str = c_string_initializer(self.pretty_name(typ).encode("utf-8")) return self._emit_traceback( func, source_path, module_name, traceback_entry, type_str=type_str, src=src ) @@ -1357,10 +1357,10 @@ def _emit_traceback( if self.context.strict_traceback_checks: assert traceback_entry[1] >= 0, "Traceback cannot have a negative line number" globals_static = self.static_name("globals", module_name) - line = '%s("%s", "%s", %d, %s' % ( + line = "%s(%s, %s, %d, %s" % ( func, - source_path.replace("\\", "\\\\"), - traceback_entry[0], + c_string_initializer(source_path.encode("utf-8")), + c_string_initializer(traceback_entry[0].encode("utf-8")), traceback_entry[1], globals_static, ) diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index c1202d1c928ca..cdd8a3c17e251 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -2,9 +2,8 @@ from __future__ import annotations -from typing import Final - from mypyc.analysis.blockfreq import frequently_executed_blocks +from mypyc.codegen.cstring import c_string_initializer from mypyc.codegen.emit import ( DEBUG_ERRORS, PREFIX_MAP, @@ -680,8 +679,8 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None: # TODO: Better escaping of backspaces and such if op.value is not None: if isinstance(op.value, str): - message = op.value.replace('"', '\\"') - 
self.emitter.emit_line(f'PyErr_SetString(PyExc_{op.class_name}, "{message}");') + c_str = c_string_initializer(op.value.encode("utf-8", "surrogatepass")) + self.emitter.emit_line(f"PyErr_SetString(PyExc_{op.class_name}, {c_str});") elif isinstance(op.value, Value): self.emitter.emit_line( "PyErr_SetObject(PyExc_{}, {});".format( @@ -895,7 +894,7 @@ def reg(self, reg: Value) -> str: return "NAN" return r elif isinstance(reg, CString): - return '"' + encode_c_string_literal(reg.value) + '"' + return c_string_initializer(reg.value) else: return self.emitter.reg(reg) @@ -935,12 +934,14 @@ def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None: ), "AttributeError traceback cannot have a negative line number" globals_static = self.emitter.static_name("globals", self.module_name) self.emit_line( - 'CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' + "CPy_AttributeError(%s, %s, %s, %s, %d, %s);" % ( - self.source_path.replace("\\", "\\\\"), - op.traceback_entry[0], - class_name, - attr.removeprefix(GENERATOR_ATTRIBUTE_PREFIX), + c_string_initializer(self.source_path.encode("utf-8")), + c_string_initializer(op.traceback_entry[0].encode("utf-8")), + c_string_initializer(class_name.encode("utf-8")), + c_string_initializer( + attr.removeprefix(GENERATOR_ATTRIBUTE_PREFIX).encode("utf-8") + ), op.traceback_entry[1], globals_static, ) @@ -961,30 +962,3 @@ def emit_unsigned_int_cast(self, type: RType) -> str: return "(uint64_t)" else: return "" - - -_translation_table: Final[dict[int, str]] = {} - - -def encode_c_string_literal(b: bytes) -> str: - """Convert bytestring to the C string literal syntax (with necessary escaping). - - For example, b'foo\n' gets converted to 'foo\\n' (note that double quotes are not added). - """ - if not _translation_table: - # Initialize the translation table on the first call. 
- d = { - ord("\n"): "\\n", - ord("\r"): "\\r", - ord("\t"): "\\t", - ord('"'): '\\"', - ord("\\"): "\\\\", - } - for i in range(256): - if i not in d: - if i < 32 or i >= 127: - d[i] = "\\x%.2x" % i - else: - d[i] = chr(i) - _translation_table.update(str.maketrans(d)) - return b.decode("latin1").translate(_translation_table) diff --git a/mypyc/codegen/emitwrapper.py b/mypyc/codegen/emitwrapper.py index 9118f0d5bc25e..7e664b5929de9 100644 --- a/mypyc/codegen/emitwrapper.py +++ b/mypyc/codegen/emitwrapper.py @@ -16,6 +16,7 @@ from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind from mypy.operators import op_methods_to_symbols, reverse_op_method_names, reverse_op_methods +from mypyc.codegen.cstring import c_string_initializer from mypyc.codegen.emit import AssignHandler, Emitter, ErrorHandler, GotoHandler, ReturnHandler from mypyc.common import ( BITMAP_BITS, @@ -77,9 +78,9 @@ def generate_traceback_code( # even if there is no `traceback_name`. This is because the error will # have originated here and so we need it in the traceback. 
globals_static = emitter.static_name("globals", module_name) - traceback_code = 'CPy_AddTraceback("%s", "%s", %d, %s);' % ( - source_path.replace("\\", "\\\\"), - fn.traceback_name or fn.name, + traceback_code = "CPy_AddTraceback(%s, %s, %d, %s);" % ( + c_string_initializer(source_path.encode("utf-8")), + c_string_initializer((fn.traceback_name or fn.name).encode("utf-8")), fn.line, globals_static, ) @@ -97,7 +98,7 @@ def reorder_arg_groups(groups: dict[ArgKind, list[RuntimeArg]]) -> list[RuntimeA def make_static_kwlist(args: list[RuntimeArg]) -> str: - arg_names = "".join(f'"{arg.name}", ' for arg in args) + arg_names = "".join(f'{c_string_initializer(arg.name.encode("utf-8"))}, ' for arg in args) return f"static const char * const kwlist[] = {{{arg_names}0}};" @@ -158,7 +159,8 @@ def generate_wrapper_function( emitter.emit_line(make_static_kwlist(reordered_args)) fmt = make_format_string(fn.name, groups) # Define the arguments the function accepts (but no types yet) - emitter.emit_line(f'static CPyArg_Parser parser = {{"{fmt}", kwlist, 0}};') + fmt_c = c_string_initializer(fmt.encode("utf-8")) + emitter.emit_line(f"static CPyArg_Parser parser = {{{fmt_c}, kwlist, 0}};") for arg in real_args: emitter.emit_line( @@ -263,8 +265,10 @@ def generate_legacy_wrapper_function( arg_ptrs += [f"&obj_{arg.name}" for arg in reordered_args] emitter.emit_lines( - 'if (!CPyArg_ParseTupleAndKeywords(args, kw, "{}", "{}", kwlist{})) {{'.format( - make_format_string(None, groups), fn.name, "".join(", " + n for n in arg_ptrs) + "if (!CPyArg_ParseTupleAndKeywords(args, kw, {}, {}, kwlist{})) {{".format( + c_string_initializer(make_format_string(None, groups).encode("utf-8")), + c_string_initializer(fn.name.encode("utf-8")), + "".join(", " + n for n in arg_ptrs), ), "return NULL;", "}",