diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index fc4906f41916..dc5413008a62 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -784,8 +784,11 @@ PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count); int CPyBytes_Startswith(PyObject *self, PyObject *subobj); - int CPyBytes_Compare(PyObject *left, PyObject *right); +PyObject *CPyBytes_LjustDefaultFill(PyObject *self, CPyTagged width); +PyObject *CPyBytes_RjustDefaultFill(PyObject *self, CPyTagged width); +PyObject *CPyBytes_LjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte); +PyObject *CPyBytes_RjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte); // Set operations diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index a9b694116866..ba118c2ac6c6 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -4,6 +4,7 @@ #include #include "CPy.h" +#include // Returns -1 on error, 0 on inequality, 1 on equality. // @@ -158,6 +159,101 @@ CPyTagged CPyBytes_Ord(PyObject *obj) { return CPY_INT_TAG; } + +PyObject *CPyBytes_RjustDefaultFill(PyObject *self, CPyTagged width) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memset(res_buf, ' ', pad); + memcpy(res_buf + pad, PyBytes_AsString(self), len); + return result; +} + + +PyObject *CPyBytes_RjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + if (!PyBytes_Check(fillbyte) || PyBytes_Size(fillbyte) != 1) { + PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + char fill = PyBytes_AsString(fillbyte)[0]; + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memset(res_buf, fill, pad); + memcpy(res_buf + pad, PyBytes_AsString(self), len); + return result; +} + + +PyObject *CPyBytes_LjustDefaultFill(PyObject *self, CPyTagged width) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memcpy(res_buf, PyBytes_AsString(self), len); + memset(res_buf + len, ' ', pad); + return result; +} + + +PyObject *CPyBytes_LjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + if (!PyBytes_Check(fillbyte) || PyBytes_Size(fillbyte) != 1) { + PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + char fill = PyBytes_AsString(fillbyte)[0]; + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memcpy(res_buf, PyBytes_AsString(self), len); + memset(res_buf + len, fill, pad); + return result; +} + + PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) { Py_ssize_t temp_count = CPyTagged_AsSsize_t(count); if (temp_count == -1 && PyErr_Occurred()) { diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 728da4181135..0a5f0acd36d8 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -167,3 +167,39 @@ c_function_name="CPyBytes_Ord", error_kind=ERR_MAGIC, ) + +# bytes.rjust(width) +method_op( + name="rjust", + arg_types=[bytes_rprimitive, int_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_RjustDefaultFill", + error_kind=ERR_MAGIC, +) + +# bytes.rjust(width, fillbyte) +method_op( + name="rjust", + arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_RjustCustomFill", + error_kind=ERR_MAGIC, +) + +# bytes.ljust(width) +method_op( + name="ljust", + arg_types=[bytes_rprimitive, int_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_LjustDefaultFill", + error_kind=ERR_MAGIC, +) + +# bytes.ljust(width, fillbyte) +method_op( + name="ljust", + arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_LjustCustomFill", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 592f6676e95e..c70235771fd8 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -181,6 +181,8 @@ def decode(self, encoding: str=..., errors: str=...) -> str: ... def translate(self, t: bytes) -> bytes: ... def startswith(self, t: bytes) -> bool: ... def __iter__(self) -> Iterator[int]: ... + def ljust(self, width: int, fillchar: bytes | bytearray = b" ") -> bytes: ... + def rjust(self, width: int, fillchar: bytes | bytearray = b" ") -> bytes: ... class bytearray: @overload diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 5e7c546eb25a..91c15c4b97f8 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -218,6 +218,48 @@ L3: keep_alive y return r2 +[case testBytesRjustDefault] +def f(b: bytes) -> bytes: + return b.rjust(6) +[out] +def f(b): + b, r0 :: bytes +L0: + r0 = CPyBytes_RjustDefaultFill(b, 12) + return r0 + +[case testBytesRjustCustom] +def f(b: bytes) -> bytes: + return b.rjust(8, b'0') +[out] +def f(b): + b, r0, r1 :: bytes +L0: + r0 = b'0' + r1 = CPyBytes_RjustCustomFill(b, 16, r0) + return r1 + +[case testBytesLjustDefault] +def f(b: bytes) -> bytes: + return b.ljust(7) +[out] +def f(b): + b, r0 :: bytes +L0: + r0 = CPyBytes_LjustDefaultFill(b, 14) + return r0 + +[case testBytesLjustCustom] +def f(b: bytes) -> bytes: + return b.ljust(10, b'_') +[out] +def f(b): + b, r0, r1 :: bytes +L0: + r0 = b'_' + r1 = CPyBytes_LjustCustomFill(b, 20, r0) + return r1 + [case testBytesMultiply] def b_times_i(s: bytes, n: int) -> bytes: return s * n diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 6e4b57152a4b..6ed0a04b3843 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -488,3 +488,43 @@ def test_optional_ne() -> None: assert ne_b_b_opt(b'x', b'y') assert ne_b_b_opt(b'y', b'x') assert ne_b_b_opt(b'x', None) + +[case testBytesRjustLjust] +from testutil import assertRaises + +def rjust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes: + return b.rjust(width, fill) + +def ljust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes: + return b.ljust(width, fill) + +def test_rjust_with_default_fill() -> None: + assert rjust_bytes(b'abc', 6) == b' abc', rjust_bytes(b'abc', 6) + assert rjust_bytes(b'abc', 3) == b'abc', rjust_bytes(b'abc', 3) + assert rjust_bytes(b'abc', 2) == b'abc', rjust_bytes(b'abc', 2) + assert rjust_bytes(b'', 4) == b' ', rjust_bytes(b'', 4) + +def test_rjust_with_custom_fill() -> None: + assert rjust_bytes(b'abc', 6, b'0') == b'000abc', rjust_bytes(b'abc', 6, b'0') + assert rjust_bytes(b'abc', 5, b'_') == b'__abc', rjust_bytes(b'abc', 5, b'_') + assert rjust_bytes(b'abc', 3, b'X') == b'abc', rjust_bytes(b'abc', 3, b'X') + +def test_ljust_with_default_fill() -> None: + assert ljust_bytes(b'abc', 6) == b'abc ', ljust_bytes(b'abc', 6) + assert ljust_bytes(b'abc', 3) == b'abc', ljust_bytes(b'abc', 3) + assert ljust_bytes(b'abc', 2) == b'abc', ljust_bytes(b'abc', 2) + assert ljust_bytes(b'', 4) == b' ', ljust_bytes(b'', 4) + +def test_ljust_with_custom_fill() -> None: + assert ljust_bytes(b'abc', 6, b'0') == b'abc000', ljust_bytes(b'abc', 6, b'0') + assert ljust_bytes(b'abc', 5, b'_') == b'abc__', ljust_bytes(b'abc', 5, b'_') + assert ljust_bytes(b'abc', 3, b'X') == b'abc', ljust_bytes(b'abc', 3, b'X') + +def test_edge_cases() -> None: + assert rjust_bytes(b'abc', 0) == b'abc', rjust_bytes(b'abc', 0) + assert ljust_bytes(b'abc', 0) == b'abc', ljust_bytes(b'abc', 0) + # fillbyte must be length 1 + with assertRaises(TypeError): + rjust_bytes(b'abc', 5, b'') + with assertRaises(TypeError): + ljust_bytes(b'abc', 5, b'12')