diff --git a/mypy/build.py b/mypy/build.py index c2e378290274..9b6e4d54d1b1 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -2397,6 +2397,8 @@ def parse_file(self, *, temporary: bool = False) -> None: self.source_hash = compute_hash(source) self.parse_inline_configuration(source) + self.check_for_invalid_options() + self.size_hint = len(source) if not cached: self.tree = manager.parse_file( @@ -2447,6 +2449,13 @@ def parse_inline_configuration(self, source: str) -> None: for lineno, error in config_errors: self.manager.errors.report(lineno, 0, error) + def check_for_invalid_options(self) -> None: + if self.options.mypyc and not self.options.strict_bytes: + self.manager.errors.set_file(self.xpath, self.id, options=self.options) + self.manager.errors.report( + 1, 0, "Option --strict-bytes cannot be disabled when using mypyc", blocker=True + ) + def semantic_analysis_pass1(self) -> None: """Perform pass 1 of semantic analysis, which happens immediately after parsing. diff --git a/mypy/main.py b/mypy/main.py index 864fe4b4febf..926e72515d95 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -1371,6 +1371,7 @@ def process_options( fscache: FileSystemCache | None = None, program: str = "mypy", header: str = HEADER, + mypyc: bool = False, ) -> tuple[list[BuildSource], Options]: """Parse command line arguments. @@ -1398,6 +1399,9 @@ def process_options( options = Options() strict_option_set = False + if mypyc: + # Mypyc has strict_bytes enabled by default + options.strict_bytes = True def set_strict_flags() -> None: nonlocal strict_option_set diff --git a/mypy/typeshed/stubs/librt/librt/strings.pyi b/mypy/typeshed/stubs/librt/librt/strings.pyi index 89f5b843caf8..241f6a6fba5b 100644 --- a/mypy/typeshed/stubs/librt/librt/strings.pyi +++ b/mypy/typeshed/stubs/librt/librt/strings.pyi @@ -5,7 +5,7 @@ from mypy_extensions import i64, u8 @final class BytesWriter: def append(self, /, x: int) -> None: ... - def write(self, /, b: bytes) -> None: ... + def write(self, /, b: bytes | bytearray) -> None: ... def getvalue(self) -> bytes: ... def truncate(self, /, size: i64) -> None: ... def __len__(self) -> i64: ... diff --git a/mypyc/build.py b/mypyc/build.py index 757f0a49737a..566612ce9d33 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -189,7 +189,7 @@ def get_mypy_config( fscache: FileSystemCache | None, ) -> tuple[list[BuildSource], list[BuildSource], Options]: """Construct mypy BuildSources and Options from file and options lists""" - all_sources, options = process_options(mypy_options, fscache=fscache) + all_sources, options = process_options(mypy_options, fscache=fscache, mypyc=True) if only_compile_paths is not None: paths_set = set(only_compile_paths) mypyc_sources = [s for s in all_sources if s.path in paths_set] diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index 34837a73adbd..236da808c038 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -657,7 +657,7 @@ def emit_cast( elif is_bytes_rprimitive(typ): if declare_dest: self.emit_line(f"PyObject *{dest};") - check = "(PyBytes_Check({}) || PyByteArray_Check({}))" + check = "(PyBytes_Check({}))" if likely: check = f"(likely{check})" self.emit_arg_check(src, dest, typ, check.format(src, src), optional) diff --git a/mypyc/test-data/commandline.test b/mypyc/test-data/commandline.test index 392ad3620790..36f7f508daf0 100644 --- a/mypyc/test-data/commandline.test +++ b/mypyc/test-data/commandline.test @@ -312,3 +312,13 @@ print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__) [out] B pkg2.mod2 + +[case testStrictBytesRequired] +# cmd: --no-strict-bytes a.py + +[file a.py] +def f(b: bytes) -> None: pass +f(bytearray()) + +[out] +a.py:1: error: Option --strict-bytes cannot be disabled when using mypyc diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index cce7af477d7a..c608a68c26e9 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -167,7 +167,7 @@ class bytes: def __init__(self) -> None: ... @overload def __init__(self, x: object) -> None: ... - def __add__(self, x: bytes) -> bytes: ... + def __add__(self, x: bytes | bytearray) -> bytes: ... def __mul__(self, x: int) -> bytes: ... def __rmul__(self, x: int) -> bytes: ... def __eq__(self, x: object) -> bool: ... @@ -178,8 +178,8 @@ def __getitem__(self, i: int) -> int: ... def __getitem__(self, i: slice) -> bytes: ... def join(self, x: Iterable[object]) -> bytes: ... def decode(self, encoding: str=..., errors: str=...) -> str: ... - def translate(self, t: bytes) -> bytes: ... - def startswith(self, t: bytes) -> bool: ... + def translate(self, t: bytes | bytearray) -> bytes: ... + def startswith(self, t: bytes | bytearray) -> bool: ... def __iter__(self) -> Iterator[int]: ... class bytearray: @@ -189,9 +189,12 @@ def __init__(self) -> None: pass def __init__(self, x: object) -> None: pass @overload def __init__(self, string: str, encoding: str, err: str = ...) -> None: pass - def __add__(self, s: bytes) -> bytearray: ... + def __add__(self, s: bytes | bytearray) -> bytearray: ... def __setitem__(self, i: int, o: int) -> None: ... + @overload def __getitem__(self, i: int) -> int: ... + @overload + def __getitem__(self, i: slice) -> bytearray: ... def decode(self, x: str = ..., y: str = ...) -> str: ... def startswith(self, t: bytes) -> bool: ... diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 5e7c546eb25a..0c77c4bfbb69 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -261,3 +261,16 @@ L0: r0 = CPyBytes_Startswith(a, b) r1 = truncate r0: i32 to builtins.bool return r1 + +[case testBytesVsBytearray] +def bytes_func(b: bytes) -> None: pass +def bytearray_func(ba: bytearray) -> None: pass + +def foo(b: bytes, ba: bytearray) -> None: + bytes_func(b) + bytearray_func(ba) + bytes_func(ba) + bytearray_func(b) +[out] +main:7: error: Argument 1 to "bytes_func" has incompatible type "bytearray"; expected "bytes" +main:8: error: Argument 1 to "bytearray_func" has incompatible type "bytes"; expected "bytearray" diff --git a/mypyc/test-data/run-base64.test b/mypyc/test-data/run-base64.test index bf8ea4590e5e..c37fca9a23d4 100644 --- a/mypyc/test-data/run-base64.test +++ b/mypyc/test-data/run-base64.test @@ -1,5 +1,5 @@ [case testAllBase64Features_librt_experimental] -from typing import Any +from typing import Any, cast import base64 import binascii import random @@ -14,7 +14,7 @@ def test_encode_basic() -> None: assert b64encode(b"x") == b"eA==" with assertRaises(TypeError): - b64encode(bytearray(b"x")) + b64encode(cast(Any, bytearray(b"x"))) def check_encode(b: bytes) -> None: assert b64encode(b) == getattr(base64, "b64encode")(b) @@ -56,7 +56,7 @@ def test_decode_basic() -> None: assert b64decode(b"eA==") == b"x" with assertRaises(TypeError): - b64decode(bytearray(b"eA==")) + b64decode(cast(Any, bytearray(b"eA=="))) for non_ascii in "\x80", "foo\u100bar", "foo\ua1234bar": with assertRaises(ValueError): diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 6e4b57152a4b..f5eb6abfe234 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -79,8 +79,8 @@ def test_concat() -> None: assert type(b1) == bytes assert type(b2) == bytes assert type(b3) == bytes - brr1: bytes = bytearray(3) - brr2: bytes = bytearray(range(5)) + brr1 = bytearray(3) + brr2 = bytearray(range(5)) b4 = b1 + brr1 assert b4 == b'123\x00\x00\x00' assert type(brr1) == bytearray @@ -94,9 +94,9 @@ def test_concat() -> None: b5 = brr2 + b2 assert b5 == bytearray(b'\x00\x01\x02\x03\x04456') assert type(b5) == bytearray - b5 = b2 + brr2 - assert b5 == b'456\x00\x01\x02\x03\x04' - assert type(b5) == bytes + b6 = b2 + brr2 + assert b6 == b'456\x00\x01\x02\x03\x04' + assert type(b6) == bytes def test_join() -> None: seq = (b'1', b'"', b'\xf0') @@ -217,9 +217,9 @@ def test_startswith() -> None: assert test.startswith(bytearray(b'some')) assert not test.startswith(bytearray(b'other')) - test = bytearray(b'some string') - assert test.startswith(b'some') - assert not test.startswith(b'other') + test2 = bytearray(b'some string') + assert test2.startswith(b'some') + assert not test2.startswith(b'other') [case testBytesSlicing] def test_bytes_slicing() -> None: @@ -257,34 +257,38 @@ def test_bytes_slicing() -> None: [case testBytearrayBasics] from typing import Any +from testutil import assertRaises + def test_basics() -> None: - brr1: bytes = bytearray(3) + brr1 = bytearray(3) assert brr1 == bytearray(b'\x00\x00\x00') assert brr1 == b'\x00\x00\x00' l = [10, 20, 30, 40] - brr2: bytes = bytearray(l) + brr2 = bytearray(l) assert brr2 == bytearray(b'\n\x14\x1e(') assert brr2 == b'\n\x14\x1e(' - brr3: bytes = bytearray(range(5)) + brr3 = bytearray(range(5)) assert brr3 == bytearray(b'\x00\x01\x02\x03\x04') assert brr3 == b'\x00\x01\x02\x03\x04' - brr4: bytes = bytearray('string', 'utf-8') + brr4 = bytearray('string', 'utf-8') assert brr4 == bytearray(b'string') assert brr4 == b'string' assert len(brr1) == 3 assert len(brr2) == 4 -def f(b: bytes) -> bool: - return True +def f(b: bytes) -> str: + return "xy" def test_bytearray_passed_into_bytes() -> None: - assert f(bytearray(3)) brr1: Any = bytearray() - assert f(brr1) + with assertRaises(TypeError, "bytes object expected; got bytearray"): + f(brr1) + with assertRaises(TypeError, "bytes object expected; got bytearray"): + b: bytes = brr1 [case testBytearraySlicing] def test_bytearray_slicing() -> None: - b: bytes = bytearray(b'abcdefg') + b = bytearray(b'abcdefg') zero = int() ten = 10 + zero two = 2 + zero @@ -318,7 +322,7 @@ def test_bytearray_slicing() -> None: from testutil import assertRaises def test_bytearray_indexing() -> None: - b: bytes = bytearray(b'\xae\x80\xfe\x15') + b = bytearray(b'\xae\x80\xfe\x15') assert b[0] == 174 assert b[1] == 128 assert b[2] == 254 @@ -347,10 +351,6 @@ def test_bytes_join() -> None: assert b' '.join([b'a', b'b']) == b'a b' assert b' '.join([]) == b'' - x: bytes = bytearray(b' ') - assert x.join([b'a', b'b']) == b'a b' - assert type(x.join([b'a', b'b'])) == bytearray - y: bytes = bytes_subclass() assert y.join([]) == b'spook' diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 0ae67ed7f1c3..3c5a1f1d31e1 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -892,7 +892,7 @@ def test_decode_error() -> None: pass def test_decode_bytearray() -> None: - b: bytes = bytearray(b'foo\x00bar') + b = bytearray(b'foo\x00bar') assert b.decode() == 'foo\x00bar' assert b.decode('utf-8') == 'foo\x00bar' assert b.decode('latin-1') == 'foo\x00bar' @@ -900,7 +900,7 @@ def test_decode_bytearray() -> None: assert b.decode('utf-8' + str()) == 'foo\x00bar' assert b.decode('latin-1' + str()) == 'foo\x00bar' assert b.decode('ascii' + str()) == 'foo\x00bar' - b2: bytes = bytearray(b'foo\x00bar\xbe') + b2 = bytearray(b'foo\x00bar\xbe') assert b2.decode('latin-1') == 'foo\x00bar\xbe' with assertRaises(UnicodeDecodeError): b2.decode('ascii') @@ -910,7 +910,7 @@ def test_decode_bytearray() -> None: b2.decode('utf-8') with assertRaises(UnicodeDecodeError): b2.decode('utf-8' + str()) - b3: bytes = bytearray(b'Z\xc3\xbcrich') + b3 = bytearray(b'Z\xc3\xbcrich') assert b3.decode("utf-8") == 'Zürich' def test_invalid_encoding() -> None: diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 681e15b58844..e37964f0be22 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -202,6 +202,9 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> options.use_builtins_fixtures = True options.show_traceback = True options.strict_optional = True + options.strict_bytes = True + options.disable_bytearray_promotion = True + options.disable_memoryview_promotion = True options.python_version = sys.version_info[:2] options.export_types = True options.preserve_asts = True diff --git a/mypyc/test/testutil.py b/mypyc/test/testutil.py index 5e485e58c9b5..de9852496f75 100644 --- a/mypyc/test/testutil.py +++ b/mypyc/test/testutil.py @@ -114,6 +114,9 @@ def build_ir_for_single_file2( options.export_types = True options.preserve_asts = True options.allow_empty_bodies = True + options.strict_bytes = True + options.disable_bytearray_promotion = True + options.disable_memoryview_promotion = True options.per_module_options["__main__"] = {"mypyc": True} source = build.BuildSource("main", "__main__", program_text)