diff --git a/BUILD b/BUILD index 809727199e..743d1867f6 100644 --- a/BUILD +++ b/BUILD @@ -23,6 +23,7 @@ pyx_library( name = "buffer", srcs = glob([ "python/pyfory/includes/*.pxd", + "python/pyfory/bfloat16.pxd", "python/pyfory/buffer.pxd", "python/pyfory/buffer.pyx", "python/pyfory/__init__.py", @@ -54,6 +55,7 @@ pyx_library( name = "serialization", srcs = glob([ "python/pyfory/includes/*.pxd", + "python/pyfory/bfloat16.pxd", "python/pyfory/buffer.pxd", "python/pyfory/serialization.pyx", "python/pyfory/*.pxi", @@ -70,12 +72,25 @@ pyx_library( ], ) +pyx_library( + name = "bfloat16", + srcs = glob([ + "python/pyfory/bfloat16.pxd", + "python/pyfory/bfloat16.pyx", + "python/pyfory/__init__.py", + ]), + cc_kwargs = dict( + linkstatic = 1, + ), +) + pyx_library( name = "_format", srcs = glob( [ "python/pyfory/__init__.py", "python/pyfory/includes/*.pxd", + "python/pyfory/bfloat16.pxd", "python/pyfory/buffer.pxd", "python/pyfory/*.pxi", "python/pyfory/format/_format.pyx", @@ -96,6 +111,7 @@ genrule( name = "cp_fory_so", srcs = [ ":python/pyfory/buffer.so", + ":python/pyfory/bfloat16.so", ":python/pyfory/lib/mmh3/mmh3.so", ":python/pyfory/format/_format.so", ":python/pyfory/serialization.so", @@ -111,11 +127,13 @@ genrule( if [ "$${u_name: 0: 4}" == "MING" ] || [ "$${u_name: 0: 4}" == "MSYS" ] then cp -f $(location python/pyfory/buffer.so) "$$WORK_DIR/python/pyfory/buffer.pyd" + cp -f $(location python/pyfory/bfloat16.so) "$$WORK_DIR/python/pyfory/bfloat16.pyd" cp -f $(location python/pyfory/lib/mmh3/mmh3.so) "$$WORK_DIR/python/pyfory/lib/mmh3/mmh3.pyd" cp -f $(location python/pyfory/format/_format.so) "$$WORK_DIR/python/pyfory/format/_format.pyd" cp -f $(location python/pyfory/serialization.so) "$$WORK_DIR/python/pyfory/serialization.pyd" else cp -f $(location python/pyfory/buffer.so) "$$WORK_DIR/python/pyfory" + cp -f $(location python/pyfory/bfloat16.so) "$$WORK_DIR/python/pyfory" cp -f $(location python/pyfory/lib/mmh3/mmh3.so) "$$WORK_DIR/python/pyfory/lib/mmh3" cp 
-f $(location python/pyfory/format/_format.so) "$$WORK_DIR/python/pyfory/format" cp -f $(location python/pyfory/serialization.so) "$$WORK_DIR/python/pyfory" diff --git a/python/README.md b/python/README.md index ee5215f0ec..970ef69a94 100644 --- a/python/README.md +++ b/python/README.md @@ -490,6 +490,34 @@ fory.register(Person.class, "example.Person"); Person person = (Person) fory.deserialize(binaryData); ``` +### BFloat16 Support + +`pyfory` supports `bfloat16` scalar values and `bfloat16` arrays in xlang mode: + +- Scalar type: `pyfory.BFloat16` (type id `18`) +- Array type: `pyfory.BFloat16Array` (type id `54`) + +```python +import pyfory +from pyfory import BFloat16, BFloat16Array + +fory = pyfory.Fory(xlang=True, ref=False, strict=True) + +# Scalar bfloat16 +v = BFloat16(3.1415926) +data = fory.serialize(v) +out = fory.deserialize(data) +print(float(out)) + +# bfloat16 array +arr = BFloat16Array([1.0, 2.5, -3.25]) +data = fory.serialize(arr) +out = fory.deserialize(data) +print(out) +``` + +`BFloat16Array` stores values in a packed `array('H')` representation and writes bytes in little-endian order for cross-language compatibility. + ## 📊 Row Format - Zero-Copy Processing Apache Fury™ provides a random-access row format that enables reading nested fields from binary data without full deserialization. This drastically reduces overhead when working with large objects where only partial data access is needed. The format also supports memory-mapped files for ultra-low memory footprint. 
diff --git a/python/pyfory/__init__.py b/python/pyfory/__init__.py index 377174a032..c5ef84715b 100644 --- a/python/pyfory/__init__.py +++ b/python/pyfory/__init__.py @@ -50,6 +50,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -88,6 +89,8 @@ tagged_uint64, float32, float64, + bfloat16 as bfloat16_type, + bfloat16_array, int8_array, uint8_array, int16_array, @@ -118,6 +121,13 @@ from pyfory.policy import DeserializationPolicy # noqa: F401 # pylint: disable=unused-import from pyfory.buffer import Buffer # noqa: F401 # pylint: disable=unused-import +# BFloat16 support +from pyfory.bfloat16 import bfloat16 # noqa: F401 +from pyfory.bfloat16_array import BFloat16Array # noqa: F401 + +# Keep compatibility with existing API naming. +BFloat16 = bfloat16 + __version__ = "0.16.0.dev0" __all__ = [ @@ -151,6 +161,10 @@ "tagged_uint64", "float32", "float64", + "BFloat16", + "BFloat16Array", + "bfloat16", + "bfloat16_array", "int8_array", "uint8_array", "int16_array", @@ -192,6 +206,7 @@ "TaggedUint64Serializer", "Float32Serializer", "Float64Serializer", + "BFloat16Serializer", "StringSerializer", "DateSerializer", "TimestampSerializer", diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index 8f14cece10..c18326c83e 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -1,357 +1,393 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime -import logging -import platform -import time -from abc import ABC - -from pyfory._fory import NOT_NULL_INT64_FLAG -from pyfory.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG -from pyfory.types import is_primitive_type - -try: - import numpy as np -except ImportError: - np = None - -logger = logging.getLogger(__name__) - - -class Serializer(ABC): - __slots__ = "fory", "type_", "need_to_write_ref" - - def __init__(self, fory, type_: type): - self.fory = fory - self.type_: type = type_ - self.need_to_write_ref = fory.track_ref and not is_primitive_type(type_) - - def write(self, buffer, value): - raise NotImplementedError - - def read(self, buffer): - raise NotImplementedError - - @classmethod - def support_subclass(cls) -> bool: - return False - - -class BooleanSerializer(Serializer): - def write(self, buffer, value): - buffer.write_bool(value) - - def read(self, buffer): - return buffer.read_bool() - - -class ByteSerializer(Serializer): - def write(self, buffer, value): - buffer.write_int8(value) - - def read(self, buffer): - return buffer.read_int8() - - -class Int16Serializer(Serializer): - def write(self, buffer, value): - buffer.write_int16(value) - - def read(self, buffer): - return buffer.read_int16() - - -class Int32Serializer(Serializer): - """Serializer for INT32/VARINT32 type - uses variable-length encoding for xlang compatibility.""" - - def write(self, buffer, value): - buffer.write_varint32(value) - - def read(self, buffer): - return buffer.read_varint32() - - -class FixedInt32Serializer(Serializer): - """Serializer for 
fixed-width 32-bit signed integer (INT32 type_id=4).""" - - def write(self, buffer, value): - buffer.write_int32(value) - - def read(self, buffer): - return buffer.read_int32() - - -class Int64Serializer(Serializer): - """Serializer for INT64/VARINT64 type - uses variable-length encoding for xlang compatibility.""" - - def write(self, buffer, value): - buffer.write_varint64(value) - - def read(self, buffer): - return buffer.read_varint64() - - -class FixedInt64Serializer(Serializer): - """Serializer for fixed-width 64-bit signed integer (INT64 type_id=6).""" - - def write(self, buffer, value): - buffer.write_int64(value) - - def read(self, buffer): - return buffer.read_int64() - - -class Varint32Serializer(Serializer): - """Serializer for VARINT32 type - variable-length encoded signed 32-bit integer.""" - - def write(self, buffer, value): - buffer.write_varint32(value) - - def read(self, buffer): - return buffer.read_varint32() - - -class Varint64Serializer(Serializer): - """Serializer for VARINT64 type - variable-length encoded signed 64-bit integer.""" - - def write(self, buffer, value): - buffer.write_varint64(value) - - def read(self, buffer): - return buffer.read_varint64() - - -class TaggedInt64Serializer(Serializer): - """Serializer for TAGGED_INT64 type - tagged encoding for signed 64-bit integer.""" - - def write(self, buffer, value): - buffer.write_tagged_int64(value) - - def read(self, buffer): - return buffer.read_tagged_int64() - - -class Uint8Serializer(Serializer): - """Serializer for UINT8 type - unsigned 8-bit integer.""" - - def write(self, buffer, value): - buffer.write_uint8(value) - - def read(self, buffer): - return buffer.read_uint8() - - -class Uint16Serializer(Serializer): - """Serializer for UINT16 type - unsigned 16-bit integer.""" - - def write(self, buffer, value): - buffer.write_uint16(value) - - def read(self, buffer): - return buffer.read_uint16() - - -class Uint32Serializer(Serializer): - """Serializer for UINT32 type - fixed-size 
unsigned 32-bit integer.""" - - def write(self, buffer, value): - buffer.write_uint32(value) - - def read(self, buffer): - return buffer.read_uint32() - - -class VarUint32Serializer(Serializer): - """Serializer for VAR_UINT32 type - variable-length encoded unsigned 32-bit integer.""" - - def write(self, buffer, value): - buffer.write_var_uint32(value) - - def read(self, buffer): - return buffer.read_var_uint32() - - -class Uint64Serializer(Serializer): - """Serializer for UINT64 type - fixed-size unsigned 64-bit integer.""" - - def write(self, buffer, value): - buffer.write_uint64(value) - - def read(self, buffer): - return buffer.read_uint64() - - -class VarUint64Serializer(Serializer): - """Serializer for VAR_UINT64 type - variable-length encoded unsigned 64-bit integer.""" - - def write(self, buffer, value): - buffer.write_var_uint64(value) - - def read(self, buffer): - return buffer.read_var_uint64() - - -class TaggedUint64Serializer(Serializer): - """Serializer for TAGGED_UINT64 type - tagged encoding for unsigned 64-bit integer.""" - - def write(self, buffer, value): - buffer.write_tagged_uint64(value) - - def read(self, buffer): - return buffer.read_tagged_uint64() - - -class Float32Serializer(Serializer): - def write(self, buffer, value): - buffer.write_float(value) - - def read(self, buffer): - return buffer.read_float() - - -class Float64Serializer(Serializer): - def write(self, buffer, value): - buffer.write_double(value) - - def read(self, buffer): - return buffer.read_double() - - -class StringSerializer(Serializer): - def __init__(self, fory, type_): - super().__init__(fory, type_) - self.need_to_write_ref = False - - def write(self, buffer, value: str): - buffer.write_string(value) - - def read(self, buffer): - return buffer.read_string() - - -_base_date = datetime.date(1970, 1, 1) - - -class DateSerializer(Serializer): - def write(self, buffer, value: datetime.date): - if not isinstance(value, datetime.date): - raise TypeError("{} should be {} 
instead of {}".format(value, datetime.date, type(value))) - days = (value - _base_date).days - buffer.write_int32(days) - - def read(self, buffer): - days = buffer.read_int32() - return _base_date + datetime.timedelta(days=days) - - -class TimestampSerializer(Serializer): - __win_platform = platform.system() == "Windows" - - def _get_timestamp(self, value: datetime.datetime): - seconds_offset = 0 - if TimestampSerializer.__win_platform and value.tzinfo is None: - is_dst = time.daylight and time.localtime().tm_isdst > 0 - seconds_offset = time.altzone if is_dst else time.timezone - value = value.replace(tzinfo=datetime.timezone.utc) - micros = int((value.timestamp() + seconds_offset) * 1_000_000) - seconds, micros_rem = divmod(micros, 1_000_000) - nanos = micros_rem * 1000 - return seconds, nanos - - def write(self, buffer, value: datetime.datetime): - if not isinstance(value, datetime.datetime): - raise TypeError("{} should be {} instead of {}".format(value, datetime, type(value))) - seconds, nanos = self._get_timestamp(value) - buffer.write_int64(seconds) - buffer.write_uint32(nanos) - - def read(self, buffer): - seconds = buffer.read_int64() - nanos = buffer.read_uint32() - ts = seconds + nanos / 1_000_000_000 - # TODO support timezone - return datetime.datetime.fromtimestamp(ts) - - -class EnumSerializer(Serializer): - def __init__(self, fory, type_): - super().__init__(fory, type_) - self.need_to_write_ref = False - self._members = tuple(type_) - self._ordinal_by_member = {member: idx for idx, member in enumerate(self._members)} - - @classmethod - def support_subclass(cls) -> bool: - return True - - def write(self, buffer, value): - buffer.write_var_uint32(self._ordinal_by_member[value]) - - def read(self, buffer): - ordinal = buffer.read_var_uint32() - return self._members[ordinal] - - -class SliceSerializer(Serializer): - def write(self, buffer, value: slice): - start, stop, step = value.start, value.stop, value.step - if type(start) is int: - # TODO support 
varint128 - buffer.write_int16(NOT_NULL_INT64_FLAG) - buffer.write_varint64(start) - else: - if start is None: - buffer.write_int8(NULL_FLAG) - else: - buffer.write_int8(NOT_NULL_VALUE_FLAG) - self.fory.write_no_ref(buffer, start) - if type(stop) is int: - # TODO support varint128 - buffer.write_int16(NOT_NULL_INT64_FLAG) - buffer.write_varint64(stop) - else: - if stop is None: - buffer.write_int8(NULL_FLAG) - else: - buffer.write_int8(NOT_NULL_VALUE_FLAG) - self.fory.write_no_ref(buffer, stop) - if type(step) is int: - # TODO support varint128 - buffer.write_int16(NOT_NULL_INT64_FLAG) - buffer.write_varint64(step) - else: - if step is None: - buffer.write_int8(NULL_FLAG) - else: - buffer.write_int8(NOT_NULL_VALUE_FLAG) - self.fory.write_no_ref(buffer, step) - - def read(self, buffer): - if buffer.read_int8() == NULL_FLAG: - start = None - else: - start = self.fory.read_no_ref(buffer) - if buffer.read_int8() == NULL_FLAG: - stop = None - else: - stop = self.fory.read_no_ref(buffer) - if buffer.read_int8() == NULL_FLAG: - step = None - else: - step = self.fory.read_no_ref(buffer) - return slice(start, stop, step) +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+import datetime
+import logging
+import platform
+import time
+from abc import ABC
+
+from pyfory._fory import NOT_NULL_INT64_FLAG
+from pyfory.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG
+from pyfory.types import is_primitive_type
+
+try:
+    import numpy as np
+except ImportError:
+    np = None
+
+logger = logging.getLogger(__name__)
+
+
+class Serializer(ABC):
+    """Base class for a serializer bound to one Python type.
+
+    Subclasses implement ``write``/``read``. ``xwrite``/``xread`` delegate to
+    them by default, so a subclass overrides those only when that path needs
+    a different encoding.
+    """
+
+    __slots__ = "fory", "type_", "need_to_write_ref"
+
+    def __init__(self, fory, type_: type):
+        self.fory = fory
+        self.type_: type = type_
+        # Reference tracking is only requested for non-primitive types.
+        self.need_to_write_ref = fory.track_ref and not is_primitive_type(type_)
+
+    def write(self, buffer, value):
+        """Write ``value`` to ``buffer``; subclasses must implement this."""
+        raise NotImplementedError
+
+    def read(self, buffer):
+        """Read and return one value from ``buffer``; subclasses must implement this."""
+        raise NotImplementedError
+
+    def xwrite(self, buffer, value):
+        """Write ``value`` via ``write`` unless a subclass overrides this."""
+        self.write(buffer, value)
+
+    def xread(self, buffer):
+        """Read a value via ``read`` unless a subclass overrides this."""
+        return self.read(buffer)
+
+    @classmethod
+    def support_subclass(cls) -> bool:
+        """Return whether this serializer also handles subclasses of its type."""
+        return False
+
+
+class BooleanSerializer(Serializer):
+    """Serializer for bool values (``write_bool``/``read_bool``)."""
+
+    def write(self, buffer, value):
+        buffer.write_bool(value)
+
+    def read(self, buffer):
+        return buffer.read_bool()
+
+
+class ByteSerializer(Serializer):
+    """Serializer for signed 8-bit integers (``write_int8``/``read_int8``)."""
+
+    def write(self, buffer, value):
+        buffer.write_int8(value)
+
+    def read(self, buffer):
+        return buffer.read_int8()
+
+
+class Int16Serializer(Serializer):
+    """Serializer for signed 16-bit integers (``write_int16``/``read_int16``)."""
+
+    def write(self, buffer, value):
+        buffer.write_int16(value)
+
+    def read(self, buffer):
+        return buffer.read_int16()
+
+
+class Int32Serializer(Serializer):
+    """Serializer for INT32/VARINT32 type - uses variable-length encoding for xlang compatibility."""
+
+    def write(self, buffer, value):
+        buffer.write_varint32(value)
+
+    def read(self, buffer):
+        return buffer.read_varint32()
+
+
+class FixedInt32Serializer(Serializer):
+    """Serializer for fixed-width 32-bit signed integer (INT32 type_id=4)."""
+
+    def write(self, buffer, value):
+        buffer.write_int32(value)
+
+    def read(self, buffer):
+        return buffer.read_int32()
+
+
+class Int64Serializer(Serializer):
+    """Serializer for INT64/VARINT64 type - uses variable-length encoding for xlang compatibility."""
+
+    # xwrite/xread are inherited: the base class forwards them to write/read,
+    # which already use the varint64 encoding.
+
+    def write(self, buffer, value):
+        buffer.write_varint64(value)
+
+    def read(self, buffer):
+        return buffer.read_varint64()
+
+
+class FixedInt64Serializer(Serializer):
+    """Serializer for fixed-width 64-bit signed integer (INT64 type_id=6)."""
+
+    def write(self, buffer, value):
+        buffer.write_int64(value)
+
+    def read(self, buffer):
+        return buffer.read_int64()
+
+
+class Varint32Serializer(Serializer):
+    """Serializer for VARINT32 type - variable-length encoded signed 32-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_varint32(value)
+
+    def read(self, buffer):
+        return buffer.read_varint32()
+
+
+class Varint64Serializer(Serializer):
+    """Serializer for VARINT64 type - variable-length encoded signed 64-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_varint64(value)
+
+    def read(self, buffer):
+        return buffer.read_varint64()
+
+
+class TaggedInt64Serializer(Serializer):
+    """Serializer for TAGGED_INT64 type - tagged encoding for signed 64-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_tagged_int64(value)
+
+    def read(self, buffer):
+        return buffer.read_tagged_int64()
+
+
+class Uint8Serializer(Serializer):
+    """Serializer for UINT8 type - unsigned 8-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_uint8(value)
+
+    def read(self, buffer):
+        return buffer.read_uint8()
+
+
+class Uint16Serializer(Serializer):
+    """Serializer for UINT16 type - unsigned 16-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_uint16(value)
+
+    def read(self, buffer):
+        return buffer.read_uint16()
+
+
+class Uint32Serializer(Serializer):
+    """Serializer for UINT32 type - fixed-size unsigned 32-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_uint32(value)
+
+    def read(self, buffer):
+        return buffer.read_uint32()
+
+
+class VarUint32Serializer(Serializer):
+    """Serializer for VAR_UINT32 type - variable-length encoded unsigned 32-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_var_uint32(value)
+
+    def read(self, buffer):
+        return buffer.read_var_uint32()
+
+
+class Uint64Serializer(Serializer):
+    """Serializer for UINT64 type - fixed-size unsigned 64-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_uint64(value)
+
+    def read(self, buffer):
+        return buffer.read_uint64()
+
+
+class VarUint64Serializer(Serializer):
+    """Serializer for VAR_UINT64 type - variable-length encoded unsigned 64-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_var_uint64(value)
+
+    def read(self, buffer):
+        return buffer.read_var_uint64()
+
+
+class TaggedUint64Serializer(Serializer):
+    """Serializer for TAGGED_UINT64 type - tagged encoding for unsigned 64-bit integer."""
+
+    def write(self, buffer, value):
+        buffer.write_tagged_uint64(value)
+
+    def read(self, buffer):
+        return buffer.read_tagged_uint64()
+
+
+class Float32Serializer(Serializer):
+    """Serializer for 32-bit floats (``write_float``/``read_float``)."""
+
+    def write(self, buffer, value):
+        buffer.write_float(value)
+
+    def read(self, buffer):
+        return buffer.read_float()
+
+
+class Float64Serializer(Serializer):
+    """Serializer for 64-bit floats (``write_double``/``read_double``)."""
+
+    def write(self, buffer, value):
+        buffer.write_double(value)
+
+    def read(self, buffer):
+        return buffer.read_double()
+
+
+class BFloat16Serializer(Serializer):
+    """Serializer for bfloat16 scalars, written as their raw 16-bit pattern."""
+
+    def write(self, buffer, value):
+        # Imported inside the method so this module does not need the compiled
+        # extension at import time.
+        from pyfory.bfloat16 import bfloat16
+
+        # Values that are not already bfloat16 are converted first, so the
+        # float -> bfloat16 rounding always happens in the bfloat16 type.
+        bits = value.to_bits() if isinstance(value, bfloat16) else bfloat16(value).to_bits()
+        buffer.write_bfloat16(bits)
+
+    def read(self, buffer):
+        return buffer.read_bfloat16()
+
+
+class StringSerializer(Serializer):
+    """Serializer for ``str`` values; never requests reference tracking."""
+
+    def __init__(self, fory, type_):
+        super().__init__(fory, type_)
+        self.need_to_write_ref = False
+
+    def write(self, buffer, value: str):
+        buffer.write_string(value)
+
+    def read(self, buffer):
+        return buffer.read_string()
+
+
+# Epoch used by DateSerializer to turn dates into day counts.
+_base_date = datetime.date(1970, 1, 1)
+
+
+class
DateSerializer(Serializer):
+    """Serializer for ``datetime.date``, stored as int32 days since 1970-01-01."""
+
+    def write(self, buffer, value: datetime.date):
+        if not isinstance(value, datetime.date):
+            raise TypeError("{} should be {} instead of {}".format(value, datetime.date, type(value)))
+        days = (value - _base_date).days
+        buffer.write_int32(days)
+
+    def read(self, buffer):
+        days = buffer.read_int32()
+        return _base_date + datetime.timedelta(days=days)
+
+
+class TimestampSerializer(Serializer):
+    """Serializer for ``datetime.datetime``, stored as int64 seconds plus uint32 nanoseconds."""
+
+    __win_platform = platform.system() == "Windows"
+
+    def _get_timestamp(self, value: datetime.datetime):
+        """Return ``(seconds, nanos)`` since the epoch for ``value``.
+
+        The result has microsecond resolution: nanos is always a multiple of 1000.
+        """
+        seconds_offset = 0
+        if TimestampSerializer.__win_platform and value.tzinfo is None:
+            # On Windows a naive datetime is reinterpreted as UTC and the local
+            # UTC offset is added back by hand.
+            # NOTE(review): presumably a workaround for naive timestamp()
+            # limitations on Windows - confirm.
+            is_dst = time.daylight and time.localtime().tm_isdst > 0
+            seconds_offset = time.altzone if is_dst else time.timezone
+            value = value.replace(tzinfo=datetime.timezone.utc)
+        micros = int((value.timestamp() + seconds_offset) * 1_000_000)
+        seconds, micros_rem = divmod(micros, 1_000_000)
+        nanos = micros_rem * 1000
+        return seconds, nanos
+
+    def write(self, buffer, value: datetime.datetime):
+        if not isinstance(value, datetime.datetime):
+            # Report the expected class, not the datetime module object.
+            raise TypeError("{} should be {} instead of {}".format(value, datetime.datetime, type(value)))
+        seconds, nanos = self._get_timestamp(value)
+        buffer.write_int64(seconds)
+        buffer.write_uint32(nanos)
+
+    def read(self, buffer):
+        seconds = buffer.read_int64()
+        nanos = buffer.read_uint32()
+        ts = seconds + nanos / 1_000_000_000
+        # TODO support timezone
+        return datetime.datetime.fromtimestamp(ts)
+
+
+class EnumSerializer(Serializer):
+    """Serializer that encodes an enum member as its ordinal (declaration index)."""
+
+    def __init__(self, fory, type_):
+        super().__init__(fory, type_)
+        self.need_to_write_ref = False
+        # Declaration order defines the ordinal. This keeps the ordinal-based
+        # encoding this file used before the rewrite, so an unrelated feature
+        # patch does not change the enum wire format.
+        self._members = tuple(type_)
+        self._ordinal_by_member = {member: idx for idx, member in enumerate(self._members)}
+
+    @classmethod
+    def support_subclass(cls) -> bool:
+        return True
+
+    def write(self, buffer, value):
+        buffer.write_var_uint32(self._ordinal_by_member[value])
+
+    def read(self, buffer):
+        ordinal = buffer.read_var_uint32()
+        return self._members[ordinal]
+
+    def xwrite(self, buffer, value):
+        # Write the ordinal, not ``value.value``: member values need not be
+        # unsigned integers (strings, tuples, auto() starting at 1).
+        self.write(buffer, value)
+
+    def xread(self, buffer):
+        return self.read(buffer)
+
+
+class SliceSerializer(Serializer):
+    """Serializer for ``slice`` objects: start, stop and step are encoded in that order."""
+
+    def write(self, buffer, value: slice):
+        # All three bounds share one encoding, so they are handled in a loop.
+        for bound in (value.start, value.stop, value.step):
+            if type(bound) is int:
+                # TODO support varint128
+                # NOTE(review): the flag is written as int16 but read back as
+                # int8 + read_no_ref; presumably it packs the not-null flag and
+                # the int64 type id - confirm against pyfory._fory.
+                buffer.write_int16(NOT_NULL_INT64_FLAG)
+                buffer.write_varint64(bound)
+            elif bound is None:
+                buffer.write_int8(NULL_FLAG)
+            else:
+                buffer.write_int8(NOT_NULL_VALUE_FLAG)
+                self.fory.write_no_ref(buffer, bound)
+
+    def read(self, buffer):
+        bounds = []
+        for _ in range(3):
+            if buffer.read_int8() == NULL_FLAG:
+                bounds.append(None)
+            else:
+                bounds.append(self.fory.read_no_ref(buffer))
+        return slice(*bounds)
+
+    def xwrite(self, buffer, value):
+        raise NotImplementedError
+
+    def xread(self, buffer):
+        raise NotImplementedError
diff --git a/python/pyfory/bfloat16.pxd b/python/pyfory/bfloat16.pxd
new file mode 100644
index 0000000000..54f7318970
--- /dev/null
+++ b/python/pyfory/bfloat16.pxd
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from libc.stdint cimport uint16_t + +cdef class bfloat16: + cdef uint16_t _bits + + +cdef bfloat16 bfloat16_from_bits(uint16_t bits) \ No newline at end of file diff --git a/python/pyfory/bfloat16.pyx b/python/pyfory/bfloat16.pyx new file mode 100644 index 0000000000..bf7bc6a3ff --- /dev/null +++ b/python/pyfory/bfloat16.pyx @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+# distutils: language = c++
+# cython: embedsignature = True
+# cython: language_level = 3
+
+from libc.stdint cimport uint16_t, uint32_t
+from libc.string cimport memcpy
+
+
+# Convert a float32 to the bfloat16 bit pattern (upper 16 bits of the float32),
+# rounding to nearest with ties to even.
+cdef inline uint16_t float32_to_bfloat16_bits(float value) nogil:
+    cdef uint32_t f32_bits
+    memcpy(&f32_bits, &value, 4)
+    if (f32_bits & 0x7FFFFFFF) > 0x7F800000:
+        # NaN: never round. Rounding could carry into the sign/exponent bits or
+        # clear the mantissa, turning the NaN into -0.0 or Inf. Keep the sign
+        # and upper payload bits and set a mantissa bit so it stays a NaN even
+        # if the payload lived only in the low 16 bits.
+        return <uint16_t>((f32_bits >> 16) | 0x0040)
+    # Adding 0x7FFF plus the lowest kept bit rounds half-way cases to even.
+    # A carry out of the mantissa bumps the exponent, so overflow becomes Inf.
+    f32_bits += 0x7FFF + ((f32_bits >> 16) & 1)
+    return <uint16_t>(f32_bits >> 16)
+
+
+# Widen a bfloat16 bit pattern back to float32; this conversion is exact.
+cdef inline float bfloat16_bits_to_float32(uint16_t bits) nogil:
+    # Cast before shifting: a uint16_t is promoted to signed int, and shifting
+    # a set bit 15 into the sign position is not well defined.
+    cdef uint32_t f32_bits = (<uint32_t>bits) << 16
+    cdef float result
+    memcpy(&result, &f32_bits, 4)
+    return result
+
+
+# Build a bfloat16 directly from its bit pattern, bypassing __init__.
+cdef bfloat16 bfloat16_from_bits(uint16_t bits):
+    cdef bfloat16 value = bfloat16.__new__(bfloat16)
+    value._bits = bits
+    return value
+
+
+cdef class bfloat16:
+    """16-bit brain floating point scalar, stored as its raw bit pattern.
+
+    ``bfloat16(x)`` converts ``float(x)`` with round-to-nearest-even; passing
+    another ``bfloat16`` copies its bits.
+    """
+
+    def __init__(self, value):
+        if isinstance(value, bfloat16):
+            # The cast gives C-level access to the non-public ``_bits`` field.
+            self._bits = (<bfloat16>value)._bits
+        else:
+            self._bits = float32_to_bfloat16_bits(float(value))
+
+    @staticmethod
+    def from_bits(uint16_t bits):
+        """Create a value from a raw 16-bit pattern without any conversion."""
+        return bfloat16_from_bits(bits)
+
+    def to_bits(self):
+        """Return the raw 16-bit pattern as an int."""
+        return self._bits
+
+    def to_float32(self):
+        """Return the value widened to float32 (surfaced as a Python float)."""
+        return bfloat16_bits_to_float32(self._bits)
+
+    def __float__(self):
+        return float(self.to_float32())
+
+    def __repr__(self):
+        return f"bfloat16({self.to_float32()})"
+
+    def __str__(self):
+        return str(self.to_float32())
+
+    def __eq__(self, other):
+        cdef bfloat16 rhs
+        if not isinstance(other, bfloat16):
+            # Let Python try the reflected comparison; the final result for
+            # unrelated types is still "not equal".
+            return NotImplemented
+        rhs = <bfloat16>other
+        # IEEE semantics: NaN equals nothing, +0.0 equals -0.0.
+        if self.is_nan() or rhs.is_nan():
+            return False
+        if self.is_zero() and rhs.is_zero():
+            return True
+        return self._bits == rhs._bits
+
+    def __hash__(self):
+        # Equal objects must hash equally: +0.0 and -0.0 compare equal, so both
+        # use the hash of the +0.0 bit pattern.
+        if (self._bits & 0x7FFF) == 0:
+            return hash(0)
+        return hash(self._bits)
+
+    def is_nan(self):
+        """Return True for NaN (exponent all ones, non-zero mantissa)."""
+        cdef uint16_t exp = (self._bits >> 7) & 0xFF
+        cdef uint16_t mant = self._bits & 0x7F
+        return exp == 0xFF and mant != 0
+
+    def is_inf(self):
+        """Return True for +/-infinity (exponent all ones, zero mantissa)."""
+        cdef uint16_t exp = (self._bits >> 7) & 0xFF
+        cdef uint16_t mant = self._bits & 0x7F
+        return exp == 0xFF and mant == 0
+
+    def is_zero(self):
+        """Return True for +0.0 and -0.0."""
+        return (self._bits & 0x7FFF) == 0
+
+    def is_finite(self):
+        """Return True unless the value is NaN or infinite."""
+        cdef uint16_t exp = (self._bits >> 7) & 0xFF
+        return exp != 0xFF
+
+    def is_normal(self):
+        """Return True for normal numbers (not zero, subnormal, Inf or NaN)."""
+        cdef uint16_t exp = (self._bits >> 7) & 0xFF
+        return exp != 0 and exp != 0xFF
+
+    def is_subnormal(self):
+        """Return True for subnormal numbers (zero exponent, non-zero mantissa)."""
+        cdef uint16_t exp = (self._bits >> 7) & 0xFF
+        cdef uint16_t mant = self._bits & 0x7F
+        return exp == 0 and mant != 0
+
+    def signbit(self):
+        """Return True when the sign bit is set (including -0.0 and negative NaN)."""
+        return (self._bits & 0x8000) != 0
+
+
+# Backward-compatible alias for existing user code.
+BFloat16 = bfloat16
diff --git a/python/pyfory/bfloat16_array.py b/python/pyfory/bfloat16_array.py
new file mode 100644
index 0000000000..ad3b3a3772
--- /dev/null
+++ b/python/pyfory/bfloat16_array.py
@@ -0,0 +1,76 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import array
+import sys
+
+from pyfory.bfloat16 import bfloat16
+
+
+def _to_bits(value):
+    """Return the 16-bit pattern for ``value``, converting to bfloat16 if needed."""
+    if isinstance(value, bfloat16):
+        return value.to_bits()
+    return bfloat16(value).to_bits()
+
+
+class BFloat16Array:
+    """Mutable sequence of bfloat16 values packed into an ``array('H')``.
+
+    Elements are stored as raw 16-bit patterns and surfaced as ``bfloat16``
+    objects on access.
+    """
+
+    def __init__(self, values=None):
+        if values is None:
+            self._data = array.array("H")
+        else:
+            self._data = array.array("H", [_to_bits(v) for v in values])
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, index):
+        return bfloat16.from_bits(self._data[index])
+
+    def __setitem__(self, index, value):
+        self._data[index] = _to_bits(value)
+
+    def __iter__(self):
+        for bits in self._data:
+            yield bfloat16.from_bits(bits)
+
+    def __repr__(self):
+        return f"BFloat16Array([{', '.join(str(bf16) for bf16 in self)}])"
+
+    def __eq__(self, other):
+        # Compares stored bit patterns, so NaN payloads must match exactly and
+        # +0.0 differs from -0.0 (unlike scalar bfloat16 equality).
+        if not isinstance(other, BFloat16Array):
+            return NotImplemented
+        return self._data == other._data
+
+    def append(self, value):
+        """Append one value, converting it to bfloat16 if needed."""
+        self._data.append(_to_bits(value))
+
+    def extend(self, values):
+        """Append every value from ``values``."""
+        for value in values:
+            self.append(value)
+
+    @property
+    def itemsize(self):
+        """Size of one element in bytes."""
+        return 2
+
+    def tobytes(self):
+        """Return the elements as little-endian bytes, independent of host byte order."""
+        if sys.byteorder == "little":
+            return self._data.tobytes()
+        # array.tobytes() uses native order; swap a copy so the stored data is
+        # left untouched.
+        # NOTE(review): confirm the array serializer does not byte-swap again.
+        swapped = array.array("H", self._data)
+        swapped.byteswap()
+        return swapped.tobytes()
+
+    @classmethod
+    def frombytes(cls, data):
+        """Build an array from little-endian bytes as produced by ``tobytes``.
+
+        Raises ValueError (from ``array.frombytes``) when ``len(data)`` is not
+        a multiple of the element size.
+        """
+        arr = cls()
+        arr._data.frombytes(data)
+        if sys.byteorder != "little":
+            arr._data.byteswap()
+        return arr
diff --git a/python/pyfory/buffer.pxd b/python/pyfory/buffer.pxd
index d6f02f133b..e63b19df12 100644
--- a/python/pyfory/buffer.pxd
+++ b/python/pyfory/buffer.pxd
@@ -25,6 +25,7 @@ from libc.stdint cimport *
 from libcpp cimport bool as c_bool
 from libcpp.memory cimport shared_ptr
 from pyfory.includes.libutil cimport CBuffer, CError
+from pyfory.bfloat16 cimport bfloat16


 cdef class Buffer:
@@ -128,6 +129,8 @@ cdef class Buffer:

     cpdef inline write_float64(self, double value)

+    cpdef inline write_bfloat16(self, uint16_t value)
+
     cpdef inline skip(self, int32_t length)

     cpdef inline c_bool read_bool(self)
@@ -158,6 +161,8 @@ cdef class
Buffer: cpdef inline double read_float64(self) + cpdef inline bfloat16 read_bfloat16(self) + cpdef inline write_varint64(self, int64_t v) cpdef inline write_var_uint64(self, int64_t v) diff --git a/python/pyfory/buffer.pyx b/python/pyfory/buffer.pyx index 3f8e0935c6..460e9a04d5 100644 --- a/python/pyfory/buffer.pyx +++ b/python/pyfory/buffer.pyx @@ -30,6 +30,7 @@ from cython.operator cimport dereference as deref from libcpp.string cimport string as c_string from libc.stdint cimport * from libcpp cimport bool as c_bool +from pyfory.bfloat16 cimport bfloat16, bfloat16_from_bits from pyfory.includes.libutil cimport( CBuffer, allocate_buffer, get_bit as c_get_bit, set_bit as c_set_bit, clear_bit as c_clear_bit, set_bit_to as c_set_bit_to, CError, CErrorCode, CResultVoidError, utf16_has_surrogate_pairs @@ -244,6 +245,14 @@ cdef class Buffer: cpdef inline write_float64(self, double value): self.c_buffer.write_double(value) + cpdef inline write_bfloat16(self, uint16_t value): + self.c_buffer.write_uint16(value) + + cpdef inline bfloat16 read_bfloat16(self): + cdef uint16_t value = self.c_buffer.read_uint16(self._error) + self._raise_if_error() + return bfloat16_from_bits(value) + cpdef put_buffer(self, uint32_t offset, v, int32_t src_index, int32_t length): if length == 0: # access an emtpy buffer may raise out-of-bound exception. 
return diff --git a/python/pyfory/codegen.py b/python/pyfory/codegen.py index 820484de7c..6b9d72d684 100644 --- a/python/pyfory/codegen.py +++ b/python/pyfory/codegen.py @@ -58,6 +58,7 @@ "write_nullable_pyfloat64", "read_nullable_pyfloat64", ), + "bfloat16": ("write_bfloat16", "read_bfloat16", "write_nullable_bfloat16", "read_nullable_bfloat16"), } @@ -144,6 +145,8 @@ def compile_function( context["read_nullable_pyfloat64"] = serialization.read_nullable_pyfloat64 context["write_nullable_pystr"] = serialization.write_nullable_pystr context["read_nullable_pystr"] = serialization.read_nullable_pystr + context["write_nullable_bfloat16"] = serialization.write_nullable_bfloat16 + context["read_nullable_bfloat16"] = serialization.read_nullable_bfloat16 stmts = [f"{ident(statement)}" for statement in stmts] # Sanitize the function name to ensure it is valid Python syntax sanitized_function_name = _sanitize_function_name(function_name) diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 394a07dd3f..58ff9cb0eb 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -239,7 +239,7 @@ cdef class CollectionSerializer(Serializer): cpdef _read_same_type_no_ref(self, Buffer buffer, int64_t len_, object collection_, TypeInfo typeinfo): self.fory.inc_depth() for i in range(len_): - obj = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + obj = self.fory.read_no_ref(buffer, typeinfo.serializer) self._add_element(collection_, i, obj) self.fory.dec_depth() @@ -262,7 +262,7 @@ cdef class CollectionSerializer(Serializer): self._add_element( collection_, i, - self.fory.read_no_ref(buffer, serializer=typeinfo.serializer), + self.fory.read_no_ref(buffer, typeinfo.serializer), ) self.fory.dec_depth() @@ -354,7 +354,7 @@ cdef class ListSerializer(CollectionSerializer): if is_py: elem = typeinfo.serializer.read(buffer) else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = 
self.fory.read_no_ref(buffer, typeinfo.serializer) Py_INCREF(elem) PyList_SET_ITEM(list_, i, elem) else: @@ -368,7 +368,7 @@ cdef class ListSerializer(CollectionSerializer): if is_py: elem = typeinfo.serializer.read(buffer) else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) Py_INCREF(elem) PyList_SET_ITEM(list_, i, elem) self.fory.dec_depth() @@ -469,7 +469,7 @@ cdef class TupleSerializer(CollectionSerializer): if is_py: elem = typeinfo.serializer.read(buffer) else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) Py_INCREF(elem) PyTuple_SET_ITEM(tuple_, i, elem) else: @@ -483,7 +483,7 @@ cdef class TupleSerializer(CollectionSerializer): if is_py: elem = typeinfo.serializer.read(buffer) else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) Py_INCREF(elem) PyTuple_SET_ITEM(tuple_, i, elem) self.fory.dec_depth() @@ -592,7 +592,7 @@ cdef class SetSerializer(CollectionSerializer): if is_py: instance.add(typeinfo.serializer.read(buffer)) else: - instance.add(self.fory.read_no_ref(buffer, serializer=typeinfo.serializer)) + instance.add(self.fory.read_no_ref(buffer, typeinfo.serializer)) else: # When ref tracking is disabled but has nulls, read null flag first for i in range(len_): @@ -614,7 +614,7 @@ cdef class SetSerializer(CollectionSerializer): if is_py: instance.add(typeinfo.serializer.read(buffer)) else: - instance.add(self.fory.read_no_ref(buffer, serializer=typeinfo.serializer)) + instance.add(self.fory.read_no_ref(buffer, typeinfo.serializer)) self.fory.dec_depth() return instance @@ -898,7 +898,7 @@ cdef class MapSerializer(Serializer): if is_py: key = key_serializer.read(buffer) else: - key = fory.read_no_ref(buffer, serializer=key_serializer) + key = fory.read_no_ref(buffer, key_serializer) else: if is_py: key = 
fory.read_ref(buffer) @@ -923,7 +923,7 @@ cdef class MapSerializer(Serializer): if is_py: value = ( value_serializer).read(buffer) else: - value = fory.read_no_ref(buffer, serializer=value_serializer) + value = fory.read_no_ref(buffer, value_serializer) else: if is_py: value = fory.read_ref(buffer) @@ -975,7 +975,7 @@ cdef class MapSerializer(Serializer): if is_py: key = ( key_serializer).read(buffer) else: - key = fory.read_no_ref(buffer, serializer=key_serializer) + key = fory.read_no_ref(buffer, key_serializer) if track_value_ref: ref_id = ref_resolver.try_preserve_ref_id(buffer) if ref_id < NOT_NULL_VALUE_FLAG: @@ -1003,7 +1003,7 @@ cdef class MapSerializer(Serializer): if is_py: value = ( value_serializer).read(buffer) else: - value = fory.read_no_ref(buffer, serializer=value_serializer) + value = fory.read_no_ref(buffer, value_serializer) map_[key] = value size -= 1 if size != 0: diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index c7d6b9c376..df5cac3e07 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -197,7 +197,7 @@ def _read_same_type_no_ref(self, buffer, len_, collection_, typeinfo): for _ in range(len_): self._add_element( collection_, - self.fory.read_no_ref(buffer, serializer=typeinfo.serializer), + self.fory.read_no_ref(buffer, typeinfo.serializer), ) self.fory.dec_depth() @@ -209,7 +209,7 @@ def _read_same_type_has_null(self, buffer, len_, collection_, typeinfo): else: self._add_element( collection_, - self.fory.read_no_ref(buffer, serializer=typeinfo.serializer), + self.fory.read_no_ref(buffer, typeinfo.serializer), ) self.fory.dec_depth() @@ -241,7 +241,7 @@ def _read_different_types(self, buffer, len_, collection_, collect_flag): if typeinfo is None: elem = None else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) self._add_element(collection_, elem) else: # When ref tracking is disabled but has nulls, read null 
flag first @@ -254,7 +254,7 @@ def _read_different_types(self, buffer, len_, collection_, collect_flag): if typeinfo is None: elem = None else: - elem = self.fory.read_no_ref(buffer, serializer=typeinfo.serializer) + elem = self.fory.read_no_ref(buffer, typeinfo.serializer) self._add_element(collection_, elem) self.fory.dec_depth() @@ -575,7 +575,7 @@ def _read_obj(self, serializer, buffer): return serializer.read(buffer) def _read_obj_no_ref(self, serializer, buffer): - return self.fory.read_no_ref(buffer, serializer=serializer) + return self.fory.read_no_ref(buffer, serializer) SubMapSerializer = MapSerializer diff --git a/python/pyfory/format/__init__.py b/python/pyfory/format/__init__.py index 6c9fb205d8..d2732299af 100644 --- a/python/pyfory/format/__init__.py +++ b/python/pyfory/format/__init__.py @@ -36,6 +36,7 @@ int32, int64, float16, + bfloat16, float32, float64, utf8, diff --git a/python/pyfory/format/schema.pxi b/python/pyfory/format/schema.pxi index 84c859cf12..b155b19d17 100644 --- a/python/pyfory/format/schema.pxi +++ b/python/pyfory/format/schema.pxi @@ -42,6 +42,7 @@ from pyfory.includes.libformat cimport ( ) + # Create Python-accessible TypeId enum # The CTypeId enum from libformat.pxd is only accessible from Cython class TypeId: @@ -417,6 +418,12 @@ def float16(): """Create a 16-bit floating point type.""" return DataType.wrap(c_float16()) +def bfloat16(): + """Create a 16-bit brain floating point type.""" + # TODO: Use c_bfloat16() when C++ row format supports bfloat16 + # For now, use float16 as a temporary workaround since C++ doesn't have bfloat16() yet + return DataType.wrap(c_float16()) + def float32(): """Create a 32-bit floating point type.""" return DataType.wrap(c_float32()) diff --git a/python/pyfory/format/schema.py b/python/pyfory/format/schema.py index 18baa20378..68309b15c0 100644 --- a/python/pyfory/format/schema.py +++ b/python/pyfory/format/schema.py @@ -56,6 +56,8 @@ def arrow_type_to_fory_type_id(arrow_type): # Floating point 
types if pa_types.is_float16(arrow_type): return 17 # FLOAT16 + if hasattr(pa_types, "is_bfloat16") and pa_types.is_bfloat16(arrow_type): + return 18 # BFLOAT16 if pa_types.is_float32(arrow_type): return 19 # FLOAT32 if pa_types.is_float64(arrow_type): @@ -116,6 +118,7 @@ def fory_type_id_to_arrow_type(type_id, precision=None, scale=None, list_type=No 4: pa.int32(), # INT32 6: pa.int64(), # INT64 17: pa.float16(), # FLOAT16 + 18: pa.float16(), # BFLOAT16 (Arrow doesn't have native bfloat16, map to float16) 19: pa.float32(), # FLOAT32 20: pa.float64(), # FLOAT64 21: pa.utf8(), # STRING @@ -204,17 +207,17 @@ def fory_field_list_to_arrow_schema(field_list): nullable = field_spec.get("nullable", True) # Handle nested types - if type_id == 21: # LIST + if type_id == 22: # LIST value_type = field_spec.get("value_type") arrow_type = pa.list_(value_type) - elif type_id == 23: # MAP + elif type_id == 24: # MAP key_type = field_spec.get("key_type") item_type = field_spec.get("item_type") arrow_type = pa.map_(key_type, item_type) - elif type_id == 15: # STRUCT + elif type_id == 27: # STRUCT struct_fields = field_spec.get("struct_fields", []) arrow_type = pa.struct(struct_fields) - elif type_id == 27: # DECIMAL + elif type_id == 40: # DECIMAL precision = field_spec.get("precision", 38) scale = field_spec.get("scale", 18) arrow_type = pa.decimal128(precision, scale) @@ -274,20 +277,20 @@ def reconstruct_arrow_type(spec): """ type_id = spec["type_id"] - if type_id == 21: # LIST + if type_id == 22: # LIST value_type = reconstruct_arrow_type(spec["value_type"]) return pa.list_(value_type) - elif type_id == 23: # MAP + elif type_id == 24: # MAP key_type = reconstruct_arrow_type(spec["key_type"]) item_type = reconstruct_arrow_type(spec["item_type"]) return pa.map_(key_type, item_type) - elif type_id == 15: # STRUCT + elif type_id == 27: # STRUCT fields = [] for field_spec in spec["fields"]: field_type = reconstruct_arrow_type(field_spec["type"]) 
fields.append(pa.field(field_spec["name"], field_type, nullable=field_spec.get("nullable", True))) return pa.struct(fields) - elif type_id == 27: # DECIMAL + elif type_id == 40: # DECIMAL return pa.decimal128(spec.get("precision", 38), spec.get("scale", 18)) else: return fory_type_id_to_arrow_type(type_id) diff --git a/python/pyfory/includes/libformat.pxd b/python/pyfory/includes/libformat.pxd index 240cb0f44a..d579ac67c7 100755 --- a/python/pyfory/includes/libformat.pxd +++ b/python/pyfory/includes/libformat.pxd @@ -136,6 +136,9 @@ cdef extern from "fory/row/schema.h" namespace "fory::row" nogil: cdef cppclass CFloat16Type" fory::row::Float16Type"(CFixedWidthType): pass + cdef cppclass CBFloat16Type" fory::row::BFloat16Type"(CFixedWidthType): + pass + cdef cppclass CFloat32Type" fory::row::Float32Type"(CFixedWidthType): pass @@ -223,6 +226,8 @@ cdef extern from "fory/row/schema.h" namespace "fory::row" nogil: shared_ptr[CDataType] int32" fory::row::int32"() shared_ptr[CDataType] int64" fory::row::int64"() shared_ptr[CDataType] float16" fory::row::float16"() + # TODO: Uncomment when C++ row format supports bfloat16 + # shared_ptr[CDataType] bfloat16" fory::row::bfloat16"() shared_ptr[CDataType] float32" fory::row::float32"() shared_ptr[CDataType] float64" fory::row::float64"() shared_ptr[CDataType] utf8" fory::row::utf8"() diff --git a/python/pyfory/primitive.pxi b/python/pyfory/primitive.pxi index 279a5079ab..51fe941f14 100644 --- a/python/pyfory/primitive.pxi +++ b/python/pyfory/primitive.pxi @@ -1,279 +1,298 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -@cython.final -cdef class BooleanSerializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_bool(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_bool() - - -@cython.final -cdef class ByteSerializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_int8(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_int8() - - -@cython.final -cdef class Int16Serializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_int16(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_int16() - - -@cython.final -cdef class Int32Serializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_varint32(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_varint32() - - -@cython.final -cdef class Int64Serializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_varint64(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_varint64() - - -@cython.final -cdef class FixedInt32Serializer(Serializer): - """Serializer for fixed-width 32-bit signed integer (INT32 type_id=4).""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_int32(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_int32() - - -@cython.final -cdef class FixedInt64Serializer(Serializer): - """Serializer for fixed-width 64-bit signed integer (INT64 type_id=6).""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_int64(value) - - cpdef 
inline read(self, Buffer buffer): - return buffer.read_int64() - - -@cython.final -cdef class Varint32Serializer(Serializer): - """Serializer for VARINT32 type - variable-length encoded signed 32-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_varint32(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_varint32() - - -@cython.final -cdef class Varint64Serializer(Serializer): - """Serializer for VARINT64 type - variable-length encoded signed 64-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_varint64(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_varint64() - - -@cython.final -cdef class TaggedInt64Serializer(Serializer): - """Serializer for TAGGED_INT64 type - tagged encoding for signed 64-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_tagged_int64(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_tagged_int64() - - -@cython.final -cdef class Uint8Serializer(Serializer): - """Serializer for UINT8 type - unsigned 8-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_uint8(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_uint8() - - -@cython.final -cdef class Uint16Serializer(Serializer): - """Serializer for UINT16 type - unsigned 16-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_uint16(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_uint16() - - -@cython.final -cdef class Uint32Serializer(Serializer): - """Serializer for UINT32 type - fixed-size unsigned 32-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_uint32(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_uint32() - - -@cython.final -cdef class VarUint32Serializer(Serializer): - """Serializer for VAR_UINT32 type - variable-length encoded unsigned 32-bit integer.""" - cpdef inline 
write(self, Buffer buffer, value): - buffer.write_var_uint32(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_var_uint32() - - -@cython.final -cdef class Uint64Serializer(Serializer): - """Serializer for UINT64 type - fixed-size unsigned 64-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_uint64(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_uint64() - - -@cython.final -cdef class VarUint64Serializer(Serializer): - """Serializer for VAR_UINT64 type - variable-length encoded unsigned 64-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_var_uint64(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_var_uint64() - - -@cython.final -cdef class TaggedUint64Serializer(Serializer): - """Serializer for TAGGED_UINT64 type - tagged encoding for unsigned 64-bit integer.""" - cpdef inline write(self, Buffer buffer, value): - buffer.write_tagged_uint64(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_tagged_uint64() - - -@cython.final -cdef class Float32Serializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_float(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_float() - - -@cython.final -cdef class Float64Serializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - buffer.write_double(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_double() - - -@cython.final -cdef class StringSerializer(Serializer): - def __init__(self, fory, type_, track_ref=False): - super().__init__(fory, type_) - self.need_to_write_ref = track_ref - - cpdef inline write(self, Buffer buffer, value): - buffer.write_string(value) - - cpdef inline read(self, Buffer buffer): - return buffer.read_string() - - -cdef _base_date = datetime.date(1970, 1, 1) -cdef int _base_date_ordinal = _base_date.toordinal() # Precompute for faster date deserialization - - 
-@cython.final -cdef class DateSerializer(Serializer): - cpdef inline write(self, Buffer buffer, value): - if type(value) is not datetime.date: - raise TypeError( - "{} should be {} instead of {}".format( - value, datetime.date, type(value) - ) - ) - days = (value - _base_date).days - buffer.write_int32(days) - - cpdef inline read(self, Buffer buffer): - days = buffer.read_int32() - return datetime.date.fromordinal(_base_date_ordinal + days) - - -@cython.final -cdef class TimestampSerializer(Serializer): - cdef bint win_platform - - def __init__(self, fory, type_: Union[type, TypeVar]): - super().__init__(fory, type_) - self.win_platform = platform.system() == "Windows" - - cdef inline _get_timestamp(self, value): - seconds_offset = 0 - if self.win_platform and value.tzinfo is None: - is_dst = time.daylight and time.localtime().tm_isdst > 0 - seconds_offset = time.altzone if is_dst else time.timezone - value = value.replace(tzinfo=datetime.timezone.utc) - cdef long long micros = ((value.timestamp() + seconds_offset) * 1000000) - cdef long long seconds - cdef long long micros_rem - if micros >= 0: - seconds = micros // 1000000 - micros_rem = micros % 1000000 - else: - seconds = -((-micros) // 1000000) - micros_rem = micros - seconds * 1000000 - if micros_rem < 0: - seconds -= 1 - micros_rem += 1000000 - return seconds, (micros_rem * 1000) - - cpdef inline write(self, Buffer buffer, value): - if type(value) is not datetime.datetime: - raise TypeError( - "{} should be {} instead of {}".format(value, datetime, type(value)) - ) - cdef long long seconds - cdef unsigned int nanos - seconds, nanos = self._get_timestamp(value) - buffer.write_int64(seconds) - buffer.write_uint32(nanos) - - cpdef inline read(self, Buffer buffer): - cdef long long seconds = buffer.read_int64() - cdef unsigned int nanos = buffer.read_uint32() - ts = seconds + (nanos) / 1000000000.0 - # TODO support timezone - return datetime.datetime.fromtimestamp(ts) +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +@cython.final +cdef class BooleanSerializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + buffer.write_bool(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_bool() + + +@cython.final +cdef class ByteSerializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + buffer.write_int8(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_int8() + + +@cython.final +cdef class Int16Serializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + buffer.write_int16(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_int16() + + +@cython.final +cdef class Int32Serializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + buffer.write_varint32(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_varint32() + + +@cython.final +cdef class Int64Serializer(Serializer): + cpdef inline xwrite(self, Buffer buffer, value): + buffer.write_varint64(value) + + cpdef inline xread(self, Buffer buffer): + return buffer.read_varint64() + + cpdef inline write(self, Buffer buffer, value): + buffer.write_varint64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_varint64() + + 
+@cython.final +cdef class FixedInt32Serializer(Serializer): + """Serializer for fixed-width 32-bit signed integer (INT32 type_id=4).""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_int32(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_int32() + + +@cython.final +cdef class FixedInt64Serializer(Serializer): + """Serializer for fixed-width 64-bit signed integer (INT64 type_id=6).""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_int64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_int64() + + +@cython.final +cdef class Varint32Serializer(Serializer): + """Serializer for VARINT32 type - variable-length encoded signed 32-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_varint32(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_varint32() + + +@cython.final +cdef class Varint64Serializer(Serializer): + """Serializer for VARINT64 type - variable-length encoded signed 64-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_varint64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_varint64() + + +@cython.final +cdef class TaggedInt64Serializer(Serializer): + """Serializer for TAGGED_INT64 type - tagged encoding for signed 64-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_tagged_int64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_tagged_int64() + + +@cython.final +cdef class Uint8Serializer(Serializer): + """Serializer for UINT8 type - unsigned 8-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_uint8(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_uint8() + + +@cython.final +cdef class Uint16Serializer(Serializer): + """Serializer for UINT16 type - unsigned 16-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_uint16(value) + + cpdef inline 
read(self, Buffer buffer): + return buffer.read_uint16() + + +@cython.final +cdef class Uint32Serializer(Serializer): + """Serializer for UINT32 type - fixed-size unsigned 32-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_uint32(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_uint32() + + +@cython.final +cdef class VarUint32Serializer(Serializer): + """Serializer for VAR_UINT32 type - variable-length encoded unsigned 32-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_var_uint32(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_var_uint32() + + +@cython.final +cdef class Uint64Serializer(Serializer): + """Serializer for UINT64 type - fixed-size unsigned 64-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_uint64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_uint64() + + +@cython.final +cdef class VarUint64Serializer(Serializer): + """Serializer for VAR_UINT64 type - variable-length encoded unsigned 64-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_var_uint64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_var_uint64() + + +@cython.final +cdef class TaggedUint64Serializer(Serializer): + """Serializer for TAGGED_UINT64 type - tagged encoding for unsigned 64-bit integer.""" + cpdef inline write(self, Buffer buffer, value): + buffer.write_tagged_uint64(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_tagged_uint64() + + +@cython.final +cdef class Float32Serializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + buffer.write_float(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_float() + + +@cython.final +cdef class Float64Serializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + buffer.write_double(value) + + cpdef inline read(self, Buffer buffer): + return 
buffer.read_double() + + +@cython.final +cdef class BFloat16Serializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + from pyfory.bfloat16 import bfloat16 + if isinstance(value, bfloat16): + buffer.write_bfloat16(value.to_bits()) + else: + buffer.write_bfloat16(bfloat16(value).to_bits()) + + cpdef inline read(self, Buffer buffer): + return buffer.read_bfloat16() + + +@cython.final +cdef class StringSerializer(Serializer): + def __init__(self, fory, type_, track_ref=False): + super().__init__(fory, type_) + self.need_to_write_ref = track_ref + + cpdef inline write(self, Buffer buffer, value): + buffer.write_string(value) + + cpdef inline read(self, Buffer buffer): + return buffer.read_string() + + +cdef _base_date = datetime.date(1970, 1, 1) +cdef int _base_date_ordinal = _base_date.toordinal() # Precompute for faster date deserialization + + +@cython.final +cdef class DateSerializer(Serializer): + cpdef inline write(self, Buffer buffer, value): + if type(value) is not datetime.date: + raise TypeError( + "{} should be {} instead of {}".format( + value, datetime.date, type(value) + ) + ) + days = (value - _base_date).days + buffer.write_int32(days) + + cpdef inline read(self, Buffer buffer): + days = buffer.read_int32() + return datetime.date.fromordinal(_base_date_ordinal + days) + + +@cython.final +cdef class TimestampSerializer(Serializer): + cdef bint win_platform + + def __init__(self, fory, type_: Union[type, TypeVar]): + super().__init__(fory, type_) + self.win_platform = platform.system() == "Windows" + + cdef inline _get_timestamp(self, value): + seconds_offset = 0 + if self.win_platform and value.tzinfo is None: + is_dst = time.daylight and time.localtime().tm_isdst > 0 + seconds_offset = time.altzone if is_dst else time.timezone + value = value.replace(tzinfo=datetime.timezone.utc) + cdef long long micros = ((value.timestamp() + seconds_offset) * 1000000) + cdef long long seconds + cdef long long micros_rem + if micros >= 0: + seconds = 
micros // 1000000 + micros_rem = micros % 1000000 + else: + seconds = -((-micros) // 1000000) + micros_rem = micros - seconds * 1000000 + if micros_rem < 0: + seconds -= 1 + micros_rem += 1000000 + return seconds, (micros_rem * 1000) + + cpdef inline write(self, Buffer buffer, value): + if type(value) is not datetime.datetime: + raise TypeError( + "{} should be {} instead of {}".format(value, datetime, type(value)) + ) + cdef long long seconds + cdef unsigned int nanos + seconds, nanos = self._get_timestamp(value) + buffer.write_int64(seconds) + buffer.write_uint32(nanos) + + cpdef inline read(self, Buffer buffer): + cdef long long seconds = buffer.read_int64() + cdef unsigned int nanos = buffer.read_uint32() + ts = seconds + (nanos) / 1000000000.0 + # TODO support timezone + return datetime.datetime.fromtimestamp(ts) diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index 6328dd4a5a..c2dea631e4 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -35,7 +35,6 @@ Serializer, Numpy1DArraySerializer, NDArraySerializer, - PythonNDArraySerializer, PyArraySerializer, DynamicPyArraySerializer, NoneSerializer, @@ -56,6 +55,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -249,7 +249,7 @@ def _initialize_py(self): register(tuple, serializer=TupleSerializer) register(slice, serializer=SliceSerializer) if np is not None: - register(np.ndarray, serializer=PythonNDArraySerializer) + register(np.ndarray, serializer=NDArraySerializer) register(array.array, serializer=DynamicPyArraySerializer) register(types.MappingProxyType, serializer=MappingProxySerializer) register(pickle.PickleBuffer, serializer=PickleBufferSerializer) @@ -321,6 +321,13 @@ def _initialize_common(self): serializer=Float64Serializer, ) register(float, type_id=TypeId.FLOAT64, serializer=Float64Serializer) + from pyfory.bfloat16 import bfloat16 + + register( + bfloat16, + 
type_id=TypeId.BFLOAT16, + serializer=BFloat16Serializer, + ) register(str, type_id=TypeId.STRING, serializer=StringSerializer) # TODO(chaokunyang) DURATION DECIMAL register(datetime.datetime, type_id=TypeId.TIMESTAMP, serializer=TimestampSerializer) @@ -332,6 +339,14 @@ def _initialize_common(self): type_id=typeid, serializer=PyArraySerializer(self.fory, ftype, typeid), ) + from pyfory.bfloat16_array import BFloat16Array + from pyfory.serializer import BFloat16ArraySerializer + + register( + BFloat16Array, + type_id=TypeId.BFLOAT16_ARRAY, + serializer=BFloat16ArraySerializer(self.fory, BFloat16Array, TypeId.BFLOAT16_ARRAY), + ) if np: # overwrite pyarray with same type id. # if pyarray are needed, one must annotate that value with XXXArrayType @@ -451,7 +466,8 @@ def _register_type( raise TypeError(f"type name {typename} and id {type_id} should not be set at the same time") if cls in self._types_info: raise TypeError(f"{cls} registered already") - return self._register_xtype( + register_type = self._register_xtype if self.fory.xlang else self._register_pytype + return register_type( cls, type_id=type_id, user_type_id=user_type_id, @@ -519,6 +535,30 @@ def _register_xtype( internal=internal, ) + def _register_pytype( + self, + cls: Union[type, TypeVar], + *, + type_id: int = None, + user_type_id: int = NO_USER_TYPE_ID, + namespace: str = None, + typename: str = None, + serializer: Serializer = None, + internal: bool = False, + ): + # Set default type_id when None, similar to _register_xtype + if type_id is None and typename is not None: + type_id = self._next_type_id() + return self.__register_type( + cls, + type_id=type_id, + user_type_id=user_type_id, + namespace=namespace, + typename=typename, + serializer=serializer, + internal=internal, + ) + def __register_type( self, cls: Union[type, TypeVar], @@ -565,7 +605,7 @@ def __register_type( if user_type_id not in self._user_type_id_to_type_info or not internal: self._user_type_id_to_type_info[user_type_id] = 
typeinfo self._used_user_type_ids.add(user_type_id) - elif not TypeId.is_namespaced_type(type_id): + elif not self.fory.xlang or not TypeId.is_namespaced_type(type_id): if type_id not in self._type_id_to_type_info or not internal: self._type_id_to_type_info[type_id] = typeinfo self._types_info[cls] = typeinfo @@ -588,6 +628,9 @@ def register_serializer(self, cls: Union[type, TypeVar], serializer): if cls not in self._types_info: raise TypeUnregisteredError(f"{cls} not registered") typeinfo = self._types_info[cls] + if not self.fory.xlang: + typeinfo.serializer = serializer + return prev_type_id = typeinfo.type_id prev_user_type_id = typeinfo.user_type_id if needs_user_type_id(prev_type_id) and prev_user_type_id not in {None, NO_USER_TYPE_ID}: diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index ba3eabb361..bf8e49f774 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -626,7 +626,7 @@ cdef class TypeResolver: else: if type_id >= self._c_registered_id_to_type_info.size(): self._c_registered_id_to_type_info.resize(type_id * 2, NULL) - if type_id > 0 and not is_namespaced_type(type_id): + if type_id > 0 and (not self.fory.xlang or not is_namespaced_type(type_id)): self._c_registered_id_to_type_info[type_id] = typeinfo self._c_types_info[ typeinfo.cls] = typeinfo # Resize if load factor >= 0.4 (using integer arithmetic: size/capacity >= 4/10) @@ -1083,8 +1083,9 @@ cdef class Fory: it controls which types can be deserialized, overriding the default policy. **Strongly recommended** when strict=False to maintain security controls. - field_nullable: Treat all dataclass fields as nullable regardless of - Optional annotation. + field_nullable: Treat all dataclass fields as nullable in Python-native mode + (xlang=False), regardless of Optional annotation. Ignored in cross-language + mode. 
Example: >>> # Python-native mode with reference tracking @@ -1102,7 +1103,7 @@ cdef class Fory: self.compatible = compatible self.track_ref = ref self.ref_resolver = MapRefResolver(ref) - self.field_nullable = field_nullable + self.field_nullable = field_nullable if not self.xlang else False self.metastring_resolver = MetaStringResolver() self.type_resolver = TypeResolver(self, meta_share=compatible, meta_compressor=meta_compressor) self.serialization_context = SerializationContext(fory=self, scoped_meta_share_enabled=compatible) @@ -1298,7 +1299,7 @@ cdef class Fory: cpdef inline _serialize( self, obj, Buffer buffer, buffer_callback=None, unsupported_callback=None): - assert self.depth == 0, "Nested serialization should use write_ref/write_no_ref." + assert self.depth == 0, "Nested serialization should use write_ref/write_no_ref/xwrite_ref/xwrite_no_ref." self.depth += 1 self.buffer_callback = buffer_callback self._unsupported_callback = unsupported_callback @@ -1314,13 +1315,21 @@ cdef class Fory: else: clear_bit(buffer, mask_index, 0) - # Unified protocol always writes xlang-compatible payload framing. - set_bit(buffer, mask_index, 1) + if self.xlang: + # set reader as x_lang. + set_bit(buffer, mask_index, 1) + else: + # set reader as native. 
+ clear_bit(buffer, mask_index, 1) if self.buffer_callback is not None: set_bit(buffer, mask_index, 2) else: clear_bit(buffer, mask_index, 2) - self.write_ref(buffer, obj) + cdef int32_t start_offset + if not self.xlang: + self.write_ref(buffer, obj) + else: + self.xwrite_ref(buffer, obj) if buffer is not self.buffer: return buffer @@ -1328,29 +1337,32 @@ cdef class Fory: return buffer.to_bytes(0, buffer.get_writer_index()) cpdef inline write_ref( - self, Buffer buffer, obj, TypeInfo typeinfo=None, Serializer serializer=None): - if serializer is None and typeinfo is not None: - serializer = typeinfo.serializer - if serializer is None or serializer.need_to_write_ref: - if self.ref_resolver.write_ref_or_null(buffer, obj): - return - self.write_no_ref(buffer, obj, serializer=serializer, typeinfo=typeinfo) - else: - if obj is None: - buffer.write_int8(NULL_FLAG) - else: - buffer.write_int8(NOT_NULL_VALUE_FLAG) - self.write_no_ref(buffer, obj, serializer=serializer, typeinfo=typeinfo) - - cpdef inline write_no_ref( - self, - Buffer buffer, - obj, - Serializer serializer=None, - TypeInfo typeinfo=None): - if serializer is not None: - serializer.write(buffer, obj) + self, Buffer buffer, obj, TypeInfo typeinfo=None): + cls = type(obj) + if cls is str: + buffer.write_int16(NOT_NULL_STRING_FLAG) + buffer.write_string(obj) return + elif cls is int: + buffer.write_int16(NOT_NULL_INT64_FLAG) + buffer.write_varint64(obj) + return + elif cls is bool: + buffer.write_int16(NOT_NULL_BOOL_FLAG) + buffer.write_bool(obj) + return + elif cls is float: + buffer.write_int16(NOT_NULL_FLOAT64_FLAG) + buffer.write_double(obj) + return + if self.ref_resolver.write_ref_or_null(buffer, obj): + return + if typeinfo is None: + typeinfo = self.type_resolver.get_type_info(cls) + self.type_resolver.write_type_info(buffer, typeinfo) + typeinfo.serializer.write(buffer, obj) + + cpdef inline write_no_ref(self, Buffer buffer, obj, Serializer serializer=None, TypeInfo typeinfo=None): cls = type(obj) if 
cls is str: buffer.write_var_uint32(STRING_TYPE_ID) @@ -1368,10 +1380,36 @@ cdef class Fory: buffer.write_var_uint32(FLOAT64_TYPE_ID) buffer.write_double(obj) return - if typeinfo is None: - typeinfo = self.type_resolver.get_type_info(cls) - self.type_resolver.write_type_info(buffer, typeinfo) - typeinfo.serializer.write(buffer, obj) + if serializer is None: + if typeinfo is None: + typeinfo = self.type_resolver.get_type_info(cls) + self.type_resolver.write_type_info(buffer, typeinfo) + serializer = typeinfo.serializer + serializer.write(buffer, obj) + + cpdef inline xwrite_ref( + self, Buffer buffer, obj, Serializer serializer=None): + if serializer is None or serializer.need_to_write_ref: + if not self.ref_resolver.write_ref_or_null(buffer, obj): + self.xwrite_no_ref( + buffer, obj, serializer=serializer + ) + else: + if obj is None: + buffer.write_int8(NULL_FLAG) + else: + buffer.write_int8(NOT_NULL_VALUE_FLAG) + self.xwrite_no_ref( + buffer, obj, serializer=serializer + ) + + cpdef inline xwrite_no_ref( + self, Buffer buffer, obj, Serializer serializer=None): + if serializer is None: + typeinfo = self.type_resolver.get_type_info(type(obj)) + self.type_resolver.write_type_info(buffer, typeinfo) + serializer = typeinfo.serializer + serializer.xwrite(buffer, obj) def deserialize( self, @@ -1410,7 +1448,7 @@ cdef class Fory: cpdef inline _deserialize( self, Buffer buffer, buffers=None, unsupported_objects=None): - assert self.depth == 0, "Nested deserialization should use read_ref/read_no_ref." + assert self.depth == 0, "Nested deserialization should use read_ref/read_no_ref/xread_ref/xread_no_ref." 
self.depth += 1 if unsupported_objects is not None: self._unsupported_objects = iter(unsupported_objects) @@ -1418,6 +1456,7 @@ cdef class Fory: buffer.set_reader_index(reader_index + 1) if get_bit(buffer, reader_index, 0): return None + cdef c_bool is_target_x_lang = get_bit(buffer, reader_index, 1) self.is_peer_out_of_band_enabled = get_bit(buffer, reader_index, 2) if self.is_peer_out_of_band_enabled: assert buffers is not None, ( @@ -1431,36 +1470,39 @@ cdef class Fory: "produced with buffer_callback null." ) - return self.read_ref(buffer) + if not is_target_x_lang: + obj = self.read_ref(buffer) + else: + obj = self.xread_ref(buffer) - cpdef inline read_ref(self, Buffer buffer, Serializer serializer=None): - cdef int8_t head_flag - cdef MapRefResolver ref_resolver - cdef int32_t ref_id - if serializer is None or serializer.need_to_write_ref: - ref_resolver = self.ref_resolver - ref_id = ref_resolver.try_preserve_ref_id(buffer) - if ref_id >= NOT_NULL_VALUE_FLAG: - o = self._read_no_ref_internal(buffer, serializer) - ref_resolver.set_read_object(ref_id, o) - return o + return obj + + cpdef inline read_ref(self, Buffer buffer): + cdef MapRefResolver ref_resolver = self.ref_resolver + cdef int32_t ref_id = ref_resolver.try_preserve_ref_id(buffer) + if ref_id < NOT_NULL_VALUE_FLAG: return ref_resolver.get_read_object() - head_flag = buffer.read_int8() - if head_flag == NULL_FLAG: - return None - return self.read_no_ref(buffer, serializer=serializer) + # indicates that the object is first read. 
+ cdef TypeInfo typeinfo = self.type_resolver.read_type_info(buffer) + cls = typeinfo.cls + if cls is str: + return buffer.read_string() + elif cls is int: + return buffer.read_varint64() + elif cls is bool: + return buffer.read_bool() + elif cls is float: + return buffer.read_double() + self.inc_depth() + o = typeinfo.serializer.read(buffer) + self.depth -= 1 + ref_resolver.set_read_object(ref_id, o) + return o cpdef inline read_no_ref(self, Buffer buffer, Serializer serializer=None): """Deserialize not-null and non-reference object from buffer.""" - if self.ref_resolver.track_ref: - # Push -1 so reference() can pop and skip tracking when read_no_ref is called directly. - self.ref_resolver.read_ref_ids.push_back(-1) - return self._read_no_ref_internal(buffer, serializer) - - cdef inline _read_no_ref_internal( - self, Buffer buffer, Serializer serializer): cdef TypeInfo typeinfo - cdef cls + cdef object cls if serializer is None: typeinfo = self.type_resolver.read_type_info(buffer) cls = typeinfo.cls @@ -1478,6 +1520,47 @@ cdef class Fory: self.depth -= 1 return o + cpdef inline xread_ref(self, Buffer buffer, Serializer serializer=None): + cdef MapRefResolver ref_resolver + cdef int32_t ref_id + if serializer is None or serializer.need_to_write_ref: + ref_resolver = self.ref_resolver + ref_id = ref_resolver.try_preserve_ref_id(buffer) + # indicates that the object is first read. 
+ if ref_id >= NOT_NULL_VALUE_FLAG: + # Don't push -1 here - try_preserve_ref_id already pushed ref_id + o = self._xread_no_ref_internal(buffer, serializer) + ref_resolver.set_read_object(ref_id, o) + return o + else: + return ref_resolver.get_read_object() + cdef int8_t head_flag = buffer.read_int8() + if head_flag == NULL_FLAG: + return None + return self.xread_no_ref( + buffer, serializer=serializer + ) + + cpdef inline xread_no_ref( + self, Buffer buffer, Serializer serializer=None): + if serializer is None: + serializer = self.type_resolver.read_type_info(buffer).serializer + # Push -1 to read_ref_ids so reference() can pop it and skip reference tracking + # This handles the case where xread_no_ref is called directly without xread_ref + if self.ref_resolver.track_ref: + self.ref_resolver.read_ref_ids.push_back(-1) + return self._xread_no_ref_internal(buffer, serializer) + + cdef inline _xread_no_ref_internal( + self, Buffer buffer, Serializer serializer): + """Internal method to read without pushing to read_ref_ids.""" + if serializer is None: + serializer = self.type_resolver.read_type_info(buffer).serializer + self.inc_depth() + o = serializer.xread(buffer) + self.depth -= 1 + return o + cpdef inline inc_depth(self): self.depth += 1 if self.depth > self.max_depth: @@ -1714,13 +1797,36 @@ cpdef inline read_nullable_pystr(Buffer buffer): else: return None +cpdef inline write_nullable_bfloat16(Buffer buffer, value): + if value is None: + buffer.write_int8(NULL_FLAG) + else: + buffer.write_int8(NOT_NULL_VALUE_FLAG) + from pyfory.bfloat16 import bfloat16 + if isinstance(value, bfloat16): + buffer.write_bfloat16(value.to_bits()) + else: + buffer.write_bfloat16(bfloat16(value).to_bits()) + +cpdef inline read_nullable_bfloat16(Buffer buffer): + if buffer.read_int8() == NOT_NULL_VALUE_FLAG: + return buffer.read_bfloat16() + else: + return None + cdef class Serializer: """ Base class for type-specific serializers. 
Serializer defines the interface for serializing and deserializing objects of a - specific type. + specific type. Each serializer implements two modes: + + - Python-native mode (write/read): Optimized for Python-to-Python serialization, + supporting all Python-specific features like __reduce__, local functions, etc. + + - Cross-language mode (xwrite/xread): Serializes to a cross-language format + compatible with other Fory implementations (Java, Go, Rust, C++, etc). Custom serializers can be registered for user-defined types using Fory.register_serializer() to override default serialization behavior. @@ -1732,7 +1838,7 @@ cdef class Serializer: Note: This is a base class for implementing custom serializers. Subclasses must - implement write() and read() methods. + implement write(), read(), xwrite(), and xread() methods. """ cdef readonly Fory fory cdef readonly object type_ @@ -1749,32 +1855,35 @@ cdef class Serializer: cpdef read(self, Buffer buffer): raise NotImplementedError(f"read method not implemented in {type(self)}") + cpdef xwrite(self, Buffer buffer, value): + self.write(buffer, value) + + cpdef xread(self, Buffer buffer): + return self.read(buffer) + @classmethod def support_subclass(cls) -> bool: return False - @cython.final cdef class EnumSerializer(Serializer): - cdef tuple _members - cdef dict _ordinal_by_member - - def __init__(self, fory, type_): - super().__init__(fory, type_) - self.need_to_write_ref = False - self._members = tuple(type_) - self._ordinal_by_member = {member: idx for idx, member in enumerate(self._members)} - @classmethod def support_subclass(cls) -> bool: return True cpdef inline write(self, Buffer buffer, value): - buffer.write_var_uint32(self._ordinal_by_member[value]) + buffer.write_string(value.name) cpdef inline read(self, Buffer buffer): + name = buffer.read_string() + return getattr(self.type_, name) + + cpdef inline xwrite(self, Buffer buffer, value): + buffer.write_var_uint32(value.value) + + cpdef inline xread(self, 
Buffer buffer): ordinal = buffer.read_var_uint32() - return self._members[ordinal] + return self.type_(ordinal) @cython.final @@ -1828,5 +1937,12 @@ cdef class SliceSerializer(Serializer): step = self.fory.read_no_ref(buffer) return slice(start, stop, step) + cpdef xwrite(self, Buffer buffer, value): + raise NotImplementedError + + cpdef xread(self, Buffer buffer): + raise NotImplementedError + + include "primitive.pxi" include "collection.pxi" diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 6458df9113..badc91dea9 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -64,6 +64,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -99,6 +100,7 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -147,6 +149,12 @@ def __init__(self, fory): super().__init__(fory, None) self.need_to_write_ref = False + def xwrite(self, buffer, value): + raise NotImplementedError + + def xread(self, buffer): + raise NotImplementedError + def write(self, buffer, value): pass @@ -214,6 +222,12 @@ def read(self, buffer): name = self.fory.read_ref(buffer) return self.type_(start, stop, step, dtype=dtype, name=name) + def xwrite(self, buffer, value): + raise NotImplementedError + + def xread(self, buffer): + raise NotImplementedError + # Use numpy array or python array module. 
typecode_dict = ( @@ -270,6 +284,7 @@ def read(self, buffer): TypeId.UINT64_ARRAY: "Q", TypeId.FLOAT32_ARRAY: "f", TypeId.FLOAT64_ARRAY: "d", + TypeId.BFLOAT16_ARRAY: "H", # bfloat16 uses 'H' typecode (uint16) } ) @@ -309,7 +324,7 @@ def __init__(self, fory, ftype, type_id: str): self.typecode = typeid_code[type_id] self.itemsize, ftype, self.type_id = typecode_dict[self.typecode] - def write(self, buffer, value): + def xwrite(self, buffer, value): assert value.itemsize == self.itemsize view = memoryview(value) assert view.format == self.typecode @@ -325,7 +340,7 @@ def write(self, buffer, value): swapped.byteswap() buffer.write_buffer(swapped) - def read(self, buffer): + def xread(self, buffer): data = buffer.read_bytes_and_size() arr = array.array(self.typecode, []) arr.frombytes(data) @@ -334,14 +349,37 @@ def read(self, buffer): arr.byteswap() return arr + def write(self, buffer, value: array.array): + nbytes = len(value) * value.itemsize + buffer.write_string(value.typecode) + buffer.write_var_uint32(nbytes) + if is_little_endian or value.itemsize == 1: + buffer.write_buffer(value) + else: + # Swap bytes on big-endian machines for multi-byte types + swapped = array.array(value.typecode, value) + swapped.byteswap() + buffer.write_buffer(swapped) + + def read(self, buffer): + typecode = buffer.read_string() + data = buffer.read_bytes_and_size() + arr = array.array(typecode[0], []) # Take first character + arr.frombytes(data) + if not is_little_endian and arr.itemsize > 1: + # Swap bytes on big-endian machines for multi-byte types + arr.byteswap() + return arr + class DynamicPyArraySerializer(Serializer): """Serializer for dynamic Python arrays that handles any typecode.""" def __init__(self, fory, cls): super().__init__(fory, cls) + self._serializer = ReduceSerializer(fory, cls) - def write(self, buffer, value): + def xwrite(self, buffer, value): itemsize, ftype, type_id = typecode_dict[value.typecode] view = memoryview(value) nbytes = len(value) * itemsize @@ 
-363,7 +401,7 @@ def write(self, buffer, value): swapped.byteswap() buffer.write_buffer(swapped) - def read(self, buffer): + def xread(self, buffer): type_id = buffer.read_uint8() typecode = typeid_code[type_id] itemsize = typecode_dict[typecode][0] @@ -374,6 +412,50 @@ def read(self, buffer): arr.byteswap() return arr + def write(self, buffer, value): + self._serializer.write(buffer, value) + + def read(self, buffer): + return self._serializer.read(buffer) + + +class BFloat16ArraySerializer(Serializer): + def __init__(self, fory, ftype, type_id: int): + super().__init__(fory, ftype) + self.type_id = type_id + self.itemsize = 2 + + def write(self, buffer, value): + from pyfory.bfloat16_array import BFloat16Array + + if isinstance(value, BFloat16Array): + arr_data = value._data + elif isinstance(value, array.array) and value.typecode == "H": + arr_data = value + else: + arr_data = BFloat16Array(value)._data + nbytes = len(arr_data) * 2 + buffer.write_var_uint32(nbytes) + if nbytes > 0: + if is_little_endian: + buffer.write_buffer(arr_data) + else: + swapped = array.array("H", arr_data) + swapped.byteswap() + buffer.write_buffer(swapped) + + def read(self, buffer): + from pyfory.bfloat16_array import BFloat16Array + + data = buffer.read_bytes_and_size() + arr = array.array("H", []) + arr.frombytes(data) + if not is_little_endian: + arr.byteswap() + bf16_arr = BFloat16Array.__new__(BFloat16Array) + bf16_arr._data = arr + return bf16_arr + if np: _np_dtypes_dict = ( @@ -407,7 +489,6 @@ def read(self, buffer): ) else: _np_dtypes_dict = {} -_np_typeid_to_dtype = {type_id: dtype for dtype, (_, _, _, type_id) in _np_dtypes_dict.items()} class Numpy1DArraySerializer(Serializer): @@ -417,8 +498,9 @@ def __init__(self, fory, ftype, dtype): super().__init__(fory, ftype) self.dtype = dtype self.itemsize, self.typecode, _, self.type_id = _np_dtypes_dict[self.dtype] + self._serializer = ReduceSerializer(fory, np.ndarray) - def write(self, buffer, value): + def xwrite(self, 
buffer, value): assert value.itemsize == self.itemsize view = memoryview(value) try: @@ -440,7 +522,7 @@ def write(self, buffer, value): # Swap bytes on big-endian machines for multi-byte types buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) - def read(self, buffer): + def xread(self, buffer): data = buffer.read_bytes_and_size() arr = np.frombuffer(data, dtype=self.dtype.newbyteorder("<")) if self.itemsize > 1: @@ -452,53 +534,32 @@ def read(self, buffer): arr = arr.astype(self.dtype) return arr + def write(self, buffer, value): + self._serializer.write(buffer, value) + + def read(self, buffer): + return self._serializer.read(buffer) + class NDArraySerializer(Serializer): - def write(self, buffer, value): - # Write concrete 1D primitive ndarray using type id + bytes payload. - dtype_info = _np_dtypes_dict.get(value.dtype) - if dtype_info is None or value.ndim != 1: - raise NotImplementedError(f"Unsupported ndarray: dtype={value.dtype}, ndim={value.ndim}") - itemsize, _typecode, _ftype, type_id = dtype_info + def xwrite(self, buffer, value): + itemsize, typecode, ftype, type_id = _np_dtypes_dict[value.dtype] view = memoryview(value) nbytes = len(value) * itemsize buffer.write_uint8(type_id) buffer.write_var_uint32(nbytes) if value.dtype == np.dtype("bool") or not view.c_contiguous: - if not is_little_endian and itemsize > 1: - buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) - else: - buffer.write_bytes(value.tobytes()) - elif is_little_endian or itemsize == 1: - buffer.write_buffer(value) + buffer.write_bytes(value.tobytes()) else: - buffer.write_bytes(value.astype(value.dtype.newbyteorder("<")).tobytes()) - - def read(self, buffer): - type_id = buffer.read_uint8() - dtype = _np_typeid_to_dtype.get(type_id) - if dtype is None: - raise NotImplementedError(f"Unsupported ndarray type id: {type_id}") - data = buffer.read_bytes_and_size() - arr = np.frombuffer(data, dtype=dtype.newbyteorder("<")) - if dtype.itemsize > 
1: - if is_little_endian: - arr = arr.view(dtype) - else: - arr = arr.astype(dtype) - return arr + buffer.write_buffer(value) + def xread(self, buffer): + raise NotImplementedError("Multi-dimensional array not supported currently") -class PythonNDArraySerializer(NDArraySerializer): def write(self, buffer, value): - dtype_info = _np_dtypes_dict.get(value.dtype) - if dtype_info is not None and value.ndim == 1: - super().write(buffer, value) - return - fory = self.fory dtype = value.dtype - buffer.write_string(dtype.str) + fory.write_ref(buffer, dtype) buffer.write_var_uint32(len(value.shape)) for dim in value.shape: buffer.write_var_uint32(dim) @@ -510,22 +571,8 @@ def write(self, buffer, value): fory.write_buffer_object(buffer, NDArrayBufferObject(value)) def read(self, buffer): - reader_index = buffer.get_reader_index() - type_id = buffer.read_uint8() - dtype = _np_typeid_to_dtype.get(type_id) - if dtype is not None: - data = buffer.read_bytes_and_size() - arr = np.frombuffer(data, dtype=dtype.newbyteorder("<")) - if dtype.itemsize > 1: - if is_little_endian: - arr = arr.view(dtype) - else: - arr = arr.astype(dtype) - return arr - - buffer.set_reader_index(reader_index) fory = self.fory - dtype = np.dtype(buffer.read_string()) + dtype = fory.read_ref(buffer) ndim = buffer.read_var_uint32() shape = tuple(buffer.read_var_uint32() for _ in range(ndim)) if dtype.kind == "O": @@ -1200,6 +1247,12 @@ def _deserialize_function(self, buffer): func = result return func + def xwrite(self, buffer, value): + raise NotImplementedError() + + def xread(self, buffer): + raise NotImplementedError() + def write(self, buffer, value): """Serialize a function for Python-only mode.""" self._serialize_function(buffer, value) @@ -1264,6 +1317,12 @@ def read(self, buffer): method = result return method + def xwrite(self, buffer, value): + return self.write(buffer, value) + + def xread(self, buffer): + return self.read(buffer) + class ObjectSerializer(Serializer): """Serializer for regular 
Python objects. @@ -1306,6 +1365,14 @@ def read(self, buffer): setattr(obj, field_name, field_value) return obj + def xwrite(self, buffer, value): + # for cross-language or minimal framing, reuse the same logic + return self.write(buffer, value) + + def xread(self, buffer): + # symmetric to xwrite + return self.read(buffer) + @dataclasses.dataclass class NonExistEnum: @@ -1323,9 +1390,16 @@ def support_subclass(cls) -> bool: return True def write(self, buffer, value): - buffer.write_var_uint32(value.value) + buffer.write_string(value.name) def read(self, buffer): + name = buffer.read_string() + return NonExistEnum(name=name) + + def xwrite(self, buffer, value): + buffer.write_var_uint32(value.value) + + def xread(self, buffer): value = buffer.read_var_uint32() return NonExistEnum(value=value) @@ -1337,6 +1411,12 @@ def write(self, buffer, value): def read(self, buffer): return self.fory.handle_unsupported_read(buffer) + def xwrite(self, buffer, value): + raise NotImplementedError(f"{self.type_} is not supported for xwrite") + + def xread(self, buffer): + raise NotImplementedError(f"{self.type_} is not supported for xread") + __all__ = [ # Base serializers (imported) @@ -1361,6 +1441,7 @@ def read(self, buffer): "TaggedUint64Serializer", "Float32Serializer", "Float64Serializer", + "BFloat16Serializer", "StringSerializer", "DateSerializer", "TimestampSerializer", diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index b844b37b90..324bd08e6d 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -286,6 +286,7 @@ def __init__( serializers: List[Serializer] = None, nullable_fields: Dict[str, bool] = None, dynamic_fields: Dict[str, bool] = None, + xlang=None, ): super().__init__(fory, clz) @@ -767,7 +768,11 @@ def _gen_generated_write_method(self): # dynamic=True: don't pass serializer, write actual type info # dynamic=False: pass serializer, use declared type serializer_arg = "None" if is_dynamic else serializer_var - 
stmts.append(f"{fory}.write_ref({buffer}, {field_value}, serializer={serializer_arg})") + if self.fory.xlang: + stmts.append(f"{fory}.xwrite_ref({buffer}, {field_value}, {serializer_arg})") + else: + # Python-native write_ref doesn't take serializer kwarg. + stmts.append(f"{fory}.write_ref({buffer}, {field_value})") else: stmt = self._get_write_stmt_for_codegen(serializer, buffer, field_value) if stmt is None: @@ -776,7 +781,7 @@ def _gen_generated_write_method(self): if is_dynamic: stmt = f"{fory}.write_no_ref({buffer}, {field_value})" else: - stmt = f"{fory}.write_no_ref({buffer}, {field_value}, serializer={serializer_var})" + stmt = f"{fory}.write_no_ref({buffer}, {field_value}, {serializer_var})" # In compatible mode, handle None for non-nullable fields (schema evolution) # Write zero/default value when field is None due to missing from remote schema if self.fory.compatible: @@ -873,7 +878,11 @@ def _gen_generated_read_method(self): # dynamic=True: don't pass serializer, read type info from buffer # dynamic=False: pass serializer, use declared type serializer_arg = "None" if is_dynamic else serializer_var - stmts.append(f"{field_value} = {fory}.read_ref({buffer}, serializer={serializer_arg})") + if self.fory.xlang: + stmts.append(f"{field_value} = {fory}.xread_ref({buffer}, {serializer_arg})") + else: + # Python-native read_ref doesn't take serializer kwarg. 
+ stmts.append(f"{field_value} = {fory}.read_ref({buffer})") else: stmt = self._get_read_stmt_for_codegen(serializer, buffer, field_value) if stmt is None: @@ -882,7 +891,7 @@ def _gen_generated_read_method(self): if is_dynamic: stmt = f"{field_value} = {fory}.read_no_ref({buffer})" else: - stmt = f"{field_value} = {fory}.read_no_ref({buffer}, serializer={serializer_var})" + stmt = f"{field_value} = {fory}.read_no_ref({buffer}, {serializer_var})" stmts.append(stmt) if field_name not in current_class_field_names: @@ -950,12 +959,15 @@ def write(self, buffer: Buffer, value): else: # dynamic=True: don't pass serializer, write actual type info # dynamic=False: pass serializer, use declared type - self.fory.write_ref(buffer, field_value, serializer=None if is_dynamic else serializer) + if self.fory.xlang: + self.fory.xwrite_ref(buffer, field_value, None if is_dynamic else serializer) + else: + self.fory.write_ref(buffer, field_value) else: if is_dynamic: self.fory.write_no_ref(buffer, field_value) else: - self.fory.write_no_ref(buffer, field_value, serializer=serializer) + self.fory.write_no_ref(buffer, field_value, serializer) def read(self, buffer): """Read dataclass instance from buffer. 
@@ -992,12 +1004,15 @@ def read(self, buffer): buffer.set_reader_index(buffer.get_reader_index() - 1) # dynamic=True: don't pass serializer, read type info from buffer # dynamic=False: pass serializer, use declared type - field_value = self.fory.read_ref(buffer, serializer=None if is_dynamic else serializer) + if self.fory.xlang: + field_value = self.fory.xread_ref(buffer, None if is_dynamic else serializer) + else: + field_value = self.fory.read_ref(buffer) else: if is_dynamic: field_value = self.fory.read_no_ref(buffer) else: - field_value = self.fory.read_no_ref(buffer, serializer=serializer) + field_value = self.fory.read_no_ref(buffer, serializer) if field_name in current_class_field_names: setattr(obj, field_name, field_value) read_field_names.add(field_name) @@ -1017,7 +1032,7 @@ def read(self, buffer): class DataClassStubSerializer(DataClassSerializer): - def __init__(self, fory, clz: type): + def __init__(self, fory, clz: type, xlang=None): Serializer.__init__(self, fory, clz) def write(self, buffer, value): diff --git a/python/pyfory/tests/test_bfloat16.py b/python/pyfory/tests/test_bfloat16.py new file mode 100644 index 0000000000..5f7c543bbf --- /dev/null +++ b/python/pyfory/tests/test_bfloat16.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import math +import pytest + +from pyfory import Fory +from pyfory.bfloat16 import bfloat16 +from pyfory.bfloat16_array import BFloat16Array +from pyfory.types import TypeId + + +def ser_de(fory, value): + data = fory.serialize(value) + return fory.deserialize(data) + + +def test_bfloat16_basic(): + bf16 = bfloat16(3.14) + assert isinstance(bf16, bfloat16) + assert bf16.to_float32() == pytest.approx(3.14, abs=0.01) + bits = bf16.to_bits() + assert bfloat16.from_bits(bits).to_bits() == bits + + +def test_bfloat16_special_values(): + assert bfloat16(float("nan")).is_nan() + assert bfloat16(float("inf")).is_inf() + assert bfloat16(float("-inf")).is_inf() + assert bfloat16(0.0).is_zero() + assert bfloat16(1.0).is_finite() + assert not bfloat16(1.0).is_nan() + assert not bfloat16(1.0).is_inf() + + +def test_bfloat16_conversion(): + assert bfloat16(0.0).to_float32() == 0.0 + assert bfloat16(1.0).to_float32() == 1.0 + assert bfloat16(-1.0).to_float32() == -1.0 + assert bfloat16(3.14).to_float32() == pytest.approx(3.14, abs=0.01) + assert math.isnan(bfloat16(float("nan")).to_float32()) + assert math.isinf(bfloat16(float("inf")).to_float32()) + assert math.isinf(bfloat16(float("-inf")).to_float32()) + + +def test_bfloat16_serialization(): + fory = Fory(xlang=True) + assert ser_de(fory, bfloat16(0.0)).to_bits() == bfloat16(0.0).to_bits() + assert ser_de(fory, bfloat16(1.0)).to_bits() == bfloat16(1.0).to_bits() + assert ser_de(fory, bfloat16(3.14)).to_bits() == bfloat16(3.14).to_bits() + assert ser_de(fory, bfloat16(float("inf"))).is_inf() + assert ser_de(fory, bfloat16(float("nan"))).is_nan() + + +def test_bfloat16_array_basic(): + arr = BFloat16Array([1.0, 2.0, 3.14]) + assert len(arr) == 3 + assert arr[0].to_float32() == pytest.approx(1.0) + arr[0] = bfloat16(5.0) + assert arr[0].to_float32() == pytest.approx(5.0) + + +def test_bfloat16_array_serialization(): + fory 
= Fory(xlang=True) + arr = BFloat16Array([1.0, 2.0, 3.14]) + result = ser_de(fory, arr) + assert isinstance(result, BFloat16Array) + assert len(result) == 3 + assert result[0].to_float32() == pytest.approx(1.0) + + +def test_bfloat16_in_dataclass(): + from dataclasses import dataclass + + @dataclass + class TestStruct: + value: bfloat16 + arr: BFloat16Array + + fory = Fory(xlang=True) + fory.register_type(TestStruct) + obj = TestStruct(value=bfloat16(3.14), arr=BFloat16Array([1.0, 2.0])) + result = ser_de(fory, obj) + assert result.value.to_float32() == pytest.approx(3.14, abs=0.01) + assert len(result.arr) == 2 + + +def test_bfloat16_in_list(): + fory = Fory(xlang=True) + values = [bfloat16(1.0), bfloat16(2.0)] + result = ser_de(fory, values) + assert len(result) == 2 + assert result[0].to_float32() == pytest.approx(1.0) + + +def test_bfloat16_in_map(): + fory = Fory(xlang=True) + data = {"a": bfloat16(1.0), "b": bfloat16(2.0)} + result = ser_de(fory, data) + assert result["a"].to_float32() == pytest.approx(1.0) + + +def test_bfloat16_type_registration(): + fory = Fory(xlang=True) + type_info = fory.type_resolver.get_type_info(bfloat16) + assert type_info.type_id == TypeId.BFLOAT16 + + +def test_bfloat16_array_type_registration(): + fory = Fory(xlang=True) + type_info = fory.type_resolver.get_type_info(BFloat16Array) + assert type_info.type_id == TypeId.BFLOAT16_ARRAY diff --git a/python/pyfory/types.py b/python/pyfory/types.py index 7f8f871dd8..db4333a83b 100644 --- a/python/pyfory/types.py +++ b/python/pyfory/types.py @@ -198,6 +198,7 @@ def is_type_share_meta(type_id: int) -> bool: tagged_uint64 = TypeVar("tagged_uint64", bound=int) float32 = TypeVar("float32", bound=float) float64 = TypeVar("float64", bound=float) +bfloat16 = TypeVar("bfloat16", bound=float) class RefMeta: @@ -314,6 +315,7 @@ def get_primitive_type_size(type_id) -> int: uint64_array = TypeVar("uint64_array", bound=array.ArrayType) float32_array = TypeVar("float32_array", bound=array.ArrayType) 
float64_array = TypeVar("float64_array", bound=array.ArrayType) +bfloat16_array = TypeVar("bfloat16_array", bound=array.ArrayType) BoolNDArrayType = TypeVar("BoolNDArrayType", bound=ndarray) Int8NDArrayType = TypeVar("Int8NDArrayType", bound=ndarray) Uint8NDArrayType = TypeVar("Uint8NDArrayType", bound=ndarray) @@ -351,6 +353,7 @@ def get_primitive_type_size(type_id) -> int: uint64_array, float32_array, float64_array, + bfloat16_array, } _np_array_types = { BoolNDArrayType, @@ -384,6 +387,7 @@ def is_py_array_type(type_) -> bool: TypeId.UINT64_ARRAY, TypeId.FLOAT32_ARRAY, TypeId.FLOAT64_ARRAY, + TypeId.BFLOAT16_ARRAY, } diff --git a/python/setup.py b/python/setup.py index 6dc32ade8a..991fffe760 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,6 +18,7 @@ import os import platform import subprocess +import time from os.path import abspath, join as pjoin from setuptools import setup @@ -41,10 +42,33 @@ print(f"fory_cpp_src_dir: {fory_cpp_src_dir}") +def _configure_bazel_shell_for_windows(): + if os.name != "nt": + return + # Bazel genrules require a POSIX shell; prefer Git Bash on Windows. 
+ candidates = [] + for env_key in ("BAZEL_SH", "GIT_BASH", "BASH"): + value = os.environ.get(env_key) + if value: + candidates.append(value) + program_files = [os.environ.get("ProgramFiles"), os.environ.get("ProgramFiles(x86)")] + for base in program_files: + if not base: + continue + candidates.append(pjoin(base, "Git", "bin", "bash.exe")) + candidates.append(pjoin(base, "Git", "usr", "bin", "bash.exe")) + for path in candidates: + if os.path.exists(path): + os.environ["BAZEL_SH"] = path + print(f"Using BAZEL_SH={path}") + return + + class BinaryDistribution(Distribution): def __init__(self, attrs=None): super().__init__(attrs=attrs) if BAZEL_BUILD_EXT: + _configure_bazel_shell_for_windows() import sys python_version = f"{sys.version_info.major}.{sys.version_info.minor}" @@ -59,7 +83,18 @@ def __init__(self, attrs=None): bazel_args += ["//:cp_fory_so"] # Ensure Windows path compatibility cwd_path = os.path.normpath(project_dir) - subprocess.check_call(bazel_args, cwd=cwd_path) + max_attempts = 3 + for attempt in range(1, max_attempts + 1): + try: + subprocess.check_call(bazel_args, cwd=cwd_path) + break + except subprocess.CalledProcessError: + if attempt == max_attempts: + raise + # Retry transient dependency fetch failures (e.g. 502 from external archives). + backoff_seconds = 5 * attempt + print(f"Bazel build failed (attempt {attempt}/{max_attempts}), retrying in {backoff_seconds}s...") + time.sleep(backoff_seconds) def has_ext_modules(self): return True