Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# TODO: Fix C string encoding in mypyc/codegen/cstring.py

## Issue
The current implementation uses octal escape sequences (`\XXX`) but the tests expect hex escape sequences (`\xXX`).

## Changes Needed
1. [x] Understand the expected behavior from tests in test_emitfunc.py
2. [ ] Update CHAR_MAP to use hex escapes instead of octal escapes
3. [ ] Keep simple escape sequences for special chars (\n, \r, \t, etc.)
4. [ ] Update the docstring to reflect correct format (\xXX instead of \oXXX)
12 changes: 6 additions & 6 deletions mypyc/codegen/cstring.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Encode valid C string literals from Python strings.

If a character is not allowed in C string literals, it is either emitted
as a simple escape sequence (e.g. '\\n'), or an octal escape sequence
with exactly three digits ('\\oXXX'). Question marks are escaped to
as a simple escape sequence (e.g. '\\n'), or a hexadecimal escape sequence
with exactly two digits ('\\xXX'). Question marks are escaped to
prevent trigraphs in the string literal from being interpreted. Note
that '\\?' is an invalid escape sequence in Python.

Expand All @@ -13,17 +13,17 @@
unexpectedly parsed as ['A', 'B', 0xCDEF].

Emitting ("AB\\xCD" "EF") would avoid this behaviour. However, we opt
for simplicity and use octal escape sequences instead. They do not
suffer from the same issue as they are defined to parse at most three
octal digits.
for simplicity and use hexadecimal escape sequences with exactly two
digits instead. They do not suffer from the same issue as they are
defined to parse exactly two hexadecimal digits.
"""

from __future__ import annotations

import string
from typing import Final

CHAR_MAP: Final = [f"\\{i:03o}" for i in range(256)]
CHAR_MAP: Final = [f"\\{i:02x}" for i in range(256)]

# It is safe to use string.printable as it always uses the C locale.
for c in string.printable:
Expand Down
8 changes: 4 additions & 4 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,7 +1340,7 @@ def emit_type_error_traceback(
src: str,
) -> None:
func = "CPy_TypeErrorTraceback"
type_str = f'"{self.pretty_name(typ)}"'
type_str = c_string_initializer(self.pretty_name(typ).encode("utf-8"))
return self._emit_traceback(
func, source_path, module_name, traceback_entry, type_str=type_str, src=src
)
Expand All @@ -1357,10 +1357,10 @@ def _emit_traceback(
if self.context.strict_traceback_checks:
assert traceback_entry[1] >= 0, "Traceback cannot have a negative line number"
globals_static = self.static_name("globals", module_name)
line = '%s("%s", "%s", %d, %s' % (
line = "%s(%s, %s, %d, %s" % (
func,
source_path.replace("\\", "\\\\"),
traceback_entry[0],
c_string_initializer(source_path.encode("utf-8")),
c_string_initializer(traceback_entry[0].encode("utf-8")),
traceback_entry[1],
globals_static,
)
Expand Down
48 changes: 11 additions & 37 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from __future__ import annotations

from typing import Final

from mypyc.analysis.blockfreq import frequently_executed_blocks
from mypyc.codegen.cstring import c_string_initializer
from mypyc.codegen.emit import (
DEBUG_ERRORS,
PREFIX_MAP,
Expand Down Expand Up @@ -680,8 +679,8 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
# TODO: Better escaping of backspaces and such
if op.value is not None:
if isinstance(op.value, str):
message = op.value.replace('"', '\\"')
self.emitter.emit_line(f'PyErr_SetString(PyExc_{op.class_name}, "{message}");')
c_str = c_string_initializer(op.value.encode("utf-8", "surrogatepass"))
self.emitter.emit_line(f"PyErr_SetString(PyExc_{op.class_name}, {c_str});")
elif isinstance(op.value, Value):
self.emitter.emit_line(
"PyErr_SetObject(PyExc_{}, {});".format(
Expand Down Expand Up @@ -895,7 +894,7 @@ def reg(self, reg: Value) -> str:
return "NAN"
return r
elif isinstance(reg, CString):
return '"' + encode_c_string_literal(reg.value) + '"'
return c_string_initializer(reg.value)
else:
return self.emitter.reg(reg)

Expand Down Expand Up @@ -935,12 +934,14 @@ def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
), "AttributeError traceback cannot have a negative line number"
globals_static = self.emitter.static_name("globals", self.module_name)
self.emit_line(
'CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);'
"CPy_AttributeError(%s, %s, %s, %s, %d, %s);"
% (
self.source_path.replace("\\", "\\\\"),
op.traceback_entry[0],
class_name,
attr.removeprefix(GENERATOR_ATTRIBUTE_PREFIX),
c_string_initializer(self.source_path.encode("utf-8")),
c_string_initializer(op.traceback_entry[0].encode("utf-8")),
c_string_initializer(class_name.encode("utf-8")),
c_string_initializer(
attr.removeprefix(GENERATOR_ATTRIBUTE_PREFIX).encode("utf-8")
),
op.traceback_entry[1],
globals_static,
)
Expand All @@ -961,30 +962,3 @@ def emit_unsigned_int_cast(self, type: RType) -> str:
return "(uint64_t)"
else:
return ""


_translation_table: Final[dict[int, str]] = {}


def encode_c_string_literal(b: bytes) -> str:
"""Convert bytestring to the C string literal syntax (with necessary escaping).

For example, b'foo\n' gets converted to 'foo\\n' (note that double quotes are not added).
"""
if not _translation_table:
# Initialize the translation table on the first call.
d = {
ord("\n"): "\\n",
ord("\r"): "\\r",
ord("\t"): "\\t",
ord('"'): '\\"',
ord("\\"): "\\\\",
}
for i in range(256):
if i not in d:
if i < 32 or i >= 127:
d[i] = "\\x%.2x" % i
else:
d[i] = chr(i)
_translation_table.update(str.maketrans(d))
return b.decode("latin1").translate(_translation_table)
18 changes: 11 additions & 7 deletions mypyc/codegen/emitwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
from mypy.operators import op_methods_to_symbols, reverse_op_method_names, reverse_op_methods
from mypyc.codegen.cstring import c_string_initializer
from mypyc.codegen.emit import AssignHandler, Emitter, ErrorHandler, GotoHandler, ReturnHandler
from mypyc.common import (
BITMAP_BITS,
Expand Down Expand Up @@ -77,9 +78,9 @@ def generate_traceback_code(
# even if there is no `traceback_name`. This is because the error will
# have originated here and so we need it in the traceback.
globals_static = emitter.static_name("globals", module_name)
traceback_code = 'CPy_AddTraceback("%s", "%s", %d, %s);' % (
source_path.replace("\\", "\\\\"),
fn.traceback_name or fn.name,
traceback_code = "CPy_AddTraceback(%s, %s, %d, %s);" % (
c_string_initializer(source_path.encode("utf-8")),
c_string_initializer((fn.traceback_name or fn.name).encode("utf-8")),
fn.line,
globals_static,
)
Expand All @@ -97,7 +98,7 @@ def reorder_arg_groups(groups: dict[ArgKind, list[RuntimeArg]]) -> list[RuntimeA


def make_static_kwlist(args: list[RuntimeArg]) -> str:
arg_names = "".join(f'"{arg.name}", ' for arg in args)
arg_names = "".join(f'{c_string_initializer(arg.name.encode("utf-8"))}, ' for arg in args)
return f"static const char * const kwlist[] = {{{arg_names}0}};"


Expand Down Expand Up @@ -158,7 +159,8 @@ def generate_wrapper_function(
emitter.emit_line(make_static_kwlist(reordered_args))
fmt = make_format_string(fn.name, groups)
# Define the arguments the function accepts (but no types yet)
emitter.emit_line(f'static CPyArg_Parser parser = {{"{fmt}", kwlist, 0}};')
fmt_c = c_string_initializer(fmt.encode("utf-8"))
emitter.emit_line(f"static CPyArg_Parser parser = {{{fmt_c}, kwlist, 0}};")

for arg in real_args:
emitter.emit_line(
Expand Down Expand Up @@ -263,8 +265,10 @@ def generate_legacy_wrapper_function(
arg_ptrs += [f"&obj_{arg.name}" for arg in reordered_args]

emitter.emit_lines(
'if (!CPyArg_ParseTupleAndKeywords(args, kw, "{}", "{}", kwlist{})) {{'.format(
make_format_string(None, groups), fn.name, "".join(", " + n for n in arg_ptrs)
"if (!CPyArg_ParseTupleAndKeywords(args, kw, {}, {}, kwlist{})) {{".format(
c_string_initializer(make_format_string(None, groups).encode("utf-8")),
c_string_initializer(fn.name.encode("utf-8")),
"".join(", " + n for n in arg_ptrs),
),
"return NULL;",
"}",
Expand Down
Loading