diff --git a/vortex-python/python/vortex/arrow/expression.py b/vortex-python/python/vortex/arrow/expression.py index b306acd1874..931b2222049 100644 --- a/vortex-python/python/vortex/arrow/expression.py +++ b/vortex-python/python/vortex/arrow/expression.py @@ -28,9 +28,26 @@ def ensure_vortex_expression(expression: pc.Expression | Expr | None, *, schema: return expression +def _schema_for_substrait(schema: pa.Schema) -> pa.Schema: + # PyArrow's to_substrait doesn't support view types; map to string/binary. + # This is safe because Vortex handles both equivalently. + # If/When PyArrow to_substrait supports view types, revert. + # Workaround for: https://github.com/vortex-data/vortex/issues/5759 + fields = [] + for field in schema: + if field.type == pa.string_view(): + fields.append(field.with_type(pa.string())) + elif field.type == pa.binary_view(): + fields.append(field.with_type(pa.binary())) + else: + fields.append(field) + return pa.schema(fields) + + def arrow_to_vortex(arrow_expression: pc.Expression, schema: pa.Schema) -> Expr: + compat_schema = _schema_for_substrait(schema) substrait_object = ExtendedExpression() # pyright: ignore[reportUnknownVariableType] - substrait_object.ParseFromString(arrow_expression.to_substrait(schema)) # pyright: ignore[reportUnknownMemberType] + substrait_object.ParseFromString(arrow_expression.to_substrait(compat_schema)) # pyright: ignore[reportUnknownMemberType] expressions = extended_expression(substrait_object) # pyright: ignore[reportUnknownArgumentType] diff --git a/vortex-python/test/test_expression.py b/vortex-python/test/test_expression.py new file mode 100644 index 00000000000..21a6c77eb89 --- /dev/null +++ b/vortex-python/test/test_expression.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright the Vortex contributors + +# Tests the _schema_for_substrait workaround in vortex/arrow/expression.py + +import pyarrow as pa +import pyarrow.compute as pc +import pytest +from vortex.arrow.expression import _schema_for_substrait, arrow_to_vortex + + +class TestSchemaForSubstrait: + """Verifies mapping: string_view=>string, binary_view=>binary, else unchanged""" + + def test_string_view_mapped_to_string(self): + schema = pa.schema([("col", pa.string_view())]) + result = _schema_for_substrait(schema) + assert result.field("col").type == pa.string() + + def test_binary_view_mapped_to_binary(self): + schema = pa.schema([("col", pa.binary_view())]) + result = _schema_for_substrait(schema) + assert result.field("col").type == pa.binary() + + def test_other_types_unchanged(self): + schema = pa.schema( + [ + ("int_col", pa.int64()), + ("str_col", pa.string()), + ("bin_col", pa.binary()), + ("float_col", pa.float64()), + ] + ) + result = _schema_for_substrait(schema) + assert result == schema + + def test_mixed_schema(self): + schema = pa.schema( + [ + ("sv", pa.string_view()), + ("bv", pa.binary_view()), + ("s", pa.string()), + ("i", pa.int64()), + ] + ) + result = _schema_for_substrait(schema) + expected = pa.schema( + [ + ("sv", pa.string()), + ("bv", pa.binary()), + ("s", pa.string()), + ("i", pa.int64()), + ] + ) + assert result == expected + + +class TestArrowToVortexWithViews: + """Tests comparisons over string_views and binary_views""" + + def test_string_view_equality_expression(self): + schema = pa.schema([("name", pa.string_view())]) + expr = pc.field("name") == "alice" + vortex_expr = arrow_to_vortex(expr, schema) + assert vortex_expr is not None + + def test_binary_view_equality_expression(self): + schema = pa.schema([("data", pa.binary_view())]) + expr = pc.field("data") == b"hello" + vortex_expr = arrow_to_vortex(expr, schema) + assert vortex_expr is not None + + def test_string_view_comparison_expression(self): + schema = pa.schema([("name", pa.string_view())]) + expr = pc.field("name") > "bob" + vortex_expr = arrow_to_vortex(expr, schema) + assert vortex_expr is not None + + def test_mixed_view_and_regular_types(self): + schema = pa.schema( + [ + ("id", pa.int64()), + ("name", pa.string_view()), + ("data", pa.binary_view()), + ] + ) + expr = (pc.field("id") > 10) & (pc.field("name") == "test") + vortex_expr = arrow_to_vortex(expr, schema) + assert vortex_expr is not None + + @pytest.mark.parametrize( + "view_type,value", + [ + (pa.string_view(), "test"), + (pa.binary_view(), b"test"), + ], + ) + def test_view_types_parametrized(self, view_type, value): + schema = pa.schema([("col", view_type)]) + expr = pc.field("col") == value + vortex_expr = arrow_to_vortex(expr, schema) + assert vortex_expr is not None