diff --git a/examples/dataframe_example.py b/examples/dataframe_example.py
new file mode 100644
index 0000000..f2d14d6
--- /dev/null
+++ b/examples/dataframe_example.py
@@ -0,0 +1,21 @@
+"""Build a Substrait plan with the substrait.dataframe API and print it."""
+
+from substrait.builders.plan import read_named_table
+from substrait.builders.type import i64, boolean, struct, named_struct
+from substrait.extension_registry import ExtensionRegistry
+import substrait.dataframe as sdf
+
+registry = ExtensionRegistry(load_default_extensions=True)
+
+# Schema of the table being read: a non-nullable i64 "id" and a boolean "is_applicable".
+ns = named_struct(
+    names=["id", "is_applicable"],
+    struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
+)
+
+# Build the read relation once and hand that same object to the DataFrame wrapper.
+table = read_named_table("example_table", ns)
+
+frame = sdf.DataFrame(table)
+frame = frame.select(sdf.col("id"))
+print(frame.to_substrait(registry))
diff --git a/examples/narwhals_example.py b/examples/narwhals_example.py
new file mode 100644
index 0000000..736af04
--- /dev/null
+++ b/examples/narwhals_example.py
@@ -0,0 +1,36 @@
+# Install duckdb and pyarrow before running this example
+# /// script
+# dependencies = [
+#   "narwhals==2.9.0",
+#   "substrait[extensions] @ file:///${PROJECT_ROOT}/"
+# ]
+# ///
+
+from substrait.builders.plan import read_named_table
+from substrait.builders.type import i64, boolean, struct, named_struct
+from substrait.extension_registry import ExtensionRegistry
+
+from narwhals.typing import FrameT
+import narwhals as nw
+import substrait.dataframe as sdf
+
+
+registry = ExtensionRegistry(load_default_extensions=True)
+
+ns = named_struct(
+    names=["id", "is_applicable"],
+    struct=struct(types=[i64(nullable=False), boolean()], nullable=False),
+)
+
+table = read_named_table("example_table", ns)
+
+# Wrap the read relation once and hand it to narwhals.
+lazy_frame: FrameT = nw.from_native(sdf.DataFrame(table))
+
+# Expressions here are written against the narwhals API, not substrait.dataframe.
+lazy_frame = lazy_frame.select(nw.col("id").abs(), new_id=nw.col("id"))
+
+# Unwrap back to the native substrait DataFrame to emit the plan.
+df: sdf.DataFrame = lazy_frame.to_native()
+
+print(df.to_substrait(registry))
diff --git a/src/substrait/dataframe/__init__.py 
b/src/substrait/dataframe/__init__.py
new file mode 100644
index 0000000..8f2271b
--- /dev/null
+++ b/src/substrait/dataframe/__init__.py
@@ -0,0 +1,25 @@
+"""DataFrame-style API that builds Substrait plans."""
+
+import substrait.dataframe
+from substrait.builders.extended_expression import column
+
+from substrait.dataframe.dataframe import DataFrame
+from substrait.dataframe.expression import Expression
+
+# __all__ entries must be strings; listing the objects themselves makes
+# `from substrait.dataframe import *` raise TypeError.
+__all__ = ["DataFrame", "Expression", "col"]
+
+
+def col(name: str) -> Expression:
+    """Return an Expression that selects the column called *name*."""
+    return Expression(column(name))
+
+
+# TODO handle str_as_lit argument
+def parse_into_expr(expr, str_as_lit: bool):
+    """Translate *expr* by calling its `_to_compliant_expr` with this package.
+
+    `str_as_lit` is accepted but currently ignored.
+    """
+    return expr._to_compliant_expr(substrait.dataframe)
diff --git a/src/substrait/dataframe/dataframe.py b/src/substrait/dataframe/dataframe.py
new file mode 100644
index 0000000..f5d25d2
--- /dev/null
+++ b/src/substrait/dataframe/dataframe.py
@@ -0,0 +1,55 @@
+"""Substrait-backed DataFrame that narwhals can use as a lazy frame."""
+
+from typing import Union, Iterable
+import substrait.dataframe
+from substrait.builders.plan import select
+from substrait.dataframe.expression import Expression
+
+
+class DataFrame:
+    """Wrap an unresolved Substrait plan behind a DataFrame-style API."""
+
+    def __init__(self, plan):
+        # `plan` is called with an extension registry in to_substrait().
+        self.plan = plan
+        self._native_frame = self
+
+    def to_substrait(self, registry):
+        """Resolve the wrapped plan against *registry* and return the result."""
+        return self.plan(registry)
+
+    def __narwhals_lazyframe__(self) -> "DataFrame":
+        """Return the object narwhals should use as the lazy frame (self)."""
+        return self
+
+    def __narwhals_namespace__(self):
+        """
+        Return the namespace object that contains functions like col.
+        This is how Narwhals knows which backend's functions to use.
+        """
+        return substrait.dataframe
+
+    def select(
+        self, *exprs: Union[Expression, Iterable[Expression]], **named_exprs: Expression
+    ) -> "DataFrame":
+        """Project expressions into a new DataFrame.
+
+        Positional arguments may be single expressions or iterables of
+        expressions; keyword arguments are projected under the keyword name.
+        """
+        expressions = []
+        for item in exprs:
+            if isinstance(item, Expression):
+                expressions.append(item.expr)
+            else:
+                # The signature also allows an iterable of expressions.
+                expressions.extend(e.expr for e in item)
+        expressions.extend(
+            expr.alias(alias).expr for alias, expr in named_exprs.items()
+        )
+        return DataFrame(select(self.plan, expressions=expressions))
+
+    # TODO handle version
+    def _with_version(self, version):
+        """Return self unchanged; *version* is currently ignored."""
+        return self
diff --git a/src/substrait/dataframe/expression.py b/src/substrait/dataframe/expression.py
new file mode 100644
index 0000000..011b625
--- /dev/null
+++ b/src/substrait/dataframe/expression.py
@@ -0,0 +1,57 @@
+"""Expression wrapper used by the substrait.dataframe API."""
+
+from typing import Optional
+
+from substrait.builders.extended_expression import (
+    UnboundExtendedExpression,
+    ExtendedExpressionOrUnbound,
+    resolve_expression,
+    scalar_function,
+)
+import substrait.gen.proto.type_pb2 as stp
+import substrait.gen.proto.extended_expression_pb2 as stee
+from substrait.extension_registry import ExtensionRegistry
+
+
+def _alias(
+    expr: ExtendedExpressionOrUnbound,
+    alias: Optional[str] = None,
+):
+    """Return an unbound expression that resolves *expr* and renames its output.
+
+    When *alias* is None the resolved expression keeps its own output name.
+    """
+
+    def resolve(
+        base_schema: stp.NamedStruct, registry: ExtensionRegistry
+    ) -> stee.ExtendedExpression:
+        bound_expression = resolve_expression(expr, base_schema, registry)
+        if alias is not None:
+            # Rename the first output of the first referred expression.
+            bound_expression.referred_expr[0].output_names[0] = alias
+        return bound_expression
+
+    return resolve
+
+
+class Expression:
+    """Wrapper around an unbound extended expression.
+
+    Every operation returns a new Expression, so an instance can be reused in
+    several places without one use changing another.
+    """
+
+    def __init__(self, expr: UnboundExtendedExpression):
+        self.expr = expr
+
+    def alias(self, alias: str) -> "Expression":
+        """Return a new Expression whose output is named *alias*."""
+        return Expression(_alias(self.expr, alias))
+
+    def abs(self) -> "Expression":
+        """Return a new Expression applying the `abs` arithmetic function."""
+        return Expression(
+            scalar_function(
+                "functions_arithmetic.yaml", "abs", expressions=[self.expr]
+            )
+        )
diff --git a/src/substrait/gen/json/simple_extensions.py b/src/substrait/gen/json/simple_extensions.py index 765fbef..c1c8bcf 100644 --- a/src/substrait/gen/json/simple_extensions.py +++ b/src/substrait/gen/json/simple_extensions.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, TypeAlias, Union +from typing import Any, 
TypeAlias class Functions(Enum): @@ -13,7 +13,7 @@ class Functions(Enum): SEPARATE = 'SEPARATE' -Type: TypeAlias = Union[str, Dict[str, Any]] +Type: TypeAlias = str | dict[str, Any] class Type1(Enum): @@ -24,41 +24,41 @@ class Type1(Enum): string = 'string' -EnumOptions: TypeAlias = List[str] +EnumOptions: TypeAlias = list[str] @dataclass class EnumerationArg: options: EnumOptions - name: Optional[str] = None - description: Optional[str] = None + name: str | None = None + description: str | None = None @dataclass class ValueArg: value: Type - name: Optional[str] = None - description: Optional[str] = None - constant: Optional[bool] = None + name: str | None = None + description: str | None = None + constant: bool | None = None @dataclass class TypeArg: type: str - name: Optional[str] = None - description: Optional[str] = None + name: str | None = None + description: str | None = None -Arguments: TypeAlias = List[Union[EnumerationArg, ValueArg, TypeArg]] +Arguments: TypeAlias = list[EnumerationArg | ValueArg | TypeArg] @dataclass class Options1: - values: List[str] - description: Optional[str] = None + values: list[str] + description: str | None = None -Options: TypeAlias = Dict[str, Options1] +Options: TypeAlias = dict[str, Options1] class ParameterConsistency(Enum): @@ -68,9 +68,9 @@ class ParameterConsistency(Enum): @dataclass class VariadicBehavior: - min: Optional[float] = None - max: Optional[float] = None - parameterConsistency: Optional[ParameterConsistency] = None + min: float | None = None + max: float | None = None + parameterConsistency: ParameterConsistency | None = None Deterministic: TypeAlias = bool @@ -88,7 +88,7 @@ class NullabilityHandling(Enum): ReturnValue: TypeAlias = Type -Implementation: TypeAlias = Dict[str, str] +Implementation: TypeAlias = dict[str, str] Intermediate: TypeAlias = Type @@ -109,43 +109,43 @@ class Decomposable(Enum): @dataclass class Impl: return_: ReturnValue - args: Optional[Arguments] = None - options: Optional[Options] = 
None - variadic: Optional[VariadicBehavior] = None - sessionDependent: Optional[SessionDependent] = None - deterministic: Optional[Deterministic] = None - nullability: Optional[NullabilityHandling] = None - implementation: Optional[Implementation] = None + args: Arguments | None = None + options: Options | None = None + variadic: VariadicBehavior | None = None + sessionDependent: SessionDependent | None = None + deterministic: Deterministic | None = None + nullability: NullabilityHandling | None = None + implementation: Implementation | None = None @dataclass class ScalarFunction: name: str - impls: List[Impl] - description: Optional[str] = None + impls: list[Impl] + description: str | None = None @dataclass class Impl1: return_: ReturnValue - args: Optional[Arguments] = None - options: Optional[Options] = None - variadic: Optional[VariadicBehavior] = None - sessionDependent: Optional[SessionDependent] = None - deterministic: Optional[Deterministic] = None - nullability: Optional[NullabilityHandling] = None - implementation: Optional[Implementation] = None - intermediate: Optional[Intermediate] = None - ordered: Optional[Ordered] = None - maxset: Optional[Maxset] = None - decomposable: Optional[Decomposable] = None + args: Arguments | None = None + options: Options | None = None + variadic: VariadicBehavior | None = None + sessionDependent: SessionDependent | None = None + deterministic: Deterministic | None = None + nullability: NullabilityHandling | None = None + implementation: Implementation | None = None + intermediate: Intermediate | None = None + ordered: Ordered | None = None + maxset: Maxset | None = None + decomposable: Decomposable | None = None @dataclass class AggregateFunction: name: str - impls: List[Impl1] - description: Optional[str] = None + impls: list[Impl1] + description: str | None = None class WindowType(Enum): @@ -156,64 +156,64 @@ class WindowType(Enum): @dataclass class Impl2: return_: ReturnValue - args: Optional[Arguments] = None - 
options: Optional[Options] = None - variadic: Optional[VariadicBehavior] = None - sessionDependent: Optional[SessionDependent] = None - deterministic: Optional[Deterministic] = None - nullability: Optional[NullabilityHandling] = None - implementation: Optional[Implementation] = None - intermediate: Optional[Intermediate] = None - ordered: Optional[Ordered] = None - maxset: Optional[Maxset] = None - decomposable: Optional[Decomposable] = None - window_type: Optional[WindowType] = None + args: Arguments | None = None + options: Options | None = None + variadic: VariadicBehavior | None = None + sessionDependent: SessionDependent | None = None + deterministic: Deterministic | None = None + nullability: NullabilityHandling | None = None + implementation: Implementation | None = None + intermediate: Intermediate | None = None + ordered: Ordered | None = None + maxset: Maxset | None = None + decomposable: Decomposable | None = None + window_type: WindowType | None = None @dataclass class WindowFunction: name: str - impls: List[Impl2] - description: Optional[str] = None + impls: list[Impl2] + description: str | None = None @dataclass class TypeVariation: parent: Type name: str - description: Optional[str] = None - functions: Optional[Functions] = None + description: str | None = None + functions: Functions | None = None @dataclass class TypeParamDef: type: Type1 - name: Optional[str] = None - description: Optional[str] = None - min: Optional[float] = None - max: Optional[float] = None - options: Optional[EnumOptions] = None - optional: Optional[bool] = None + name: str | None = None + description: str | None = None + min: float | None = None + max: float | None = None + options: EnumOptions | None = None + optional: bool | None = None -TypeParamDefs: TypeAlias = List[TypeParamDef] +TypeParamDefs: TypeAlias = list[TypeParamDef] @dataclass class TypeModel: name: str - description: Optional[str] = None - structure: Optional[Type] = None - parameters: Optional[TypeParamDefs] = 
None - variadic: Optional[bool] = None + description: str | None = None + structure: Type | None = None + parameters: TypeParamDefs | None = None + variadic: bool | None = None @dataclass class SimpleExtensions: urn: str - dependencies: Optional[Dict[str, str]] = None - types: Optional[List[TypeModel]] = None - type_variations: Optional[List[TypeVariation]] = None - scalar_functions: Optional[List[ScalarFunction]] = None - aggregate_functions: Optional[List[AggregateFunction]] = None - window_functions: Optional[List[WindowFunction]] = None + dependencies: dict[str, str] | None = None + types: list[TypeModel] | None = None + type_variations: list[TypeVariation] | None = None + scalar_functions: list[ScalarFunction] | None = None + aggregate_functions: list[AggregateFunction] | None = None + window_functions: list[WindowFunction] | None = None
diff --git a/tests/dataframe/test_df_project.py b/tests/dataframe/test_df_project.py
new file mode 100644
index 0000000..8eeecbd
--- /dev/null
+++ b/tests/dataframe/test_df_project.py
@@ -0,0 +1,50 @@
+import substrait.gen.proto.type_pb2 as stt
+import substrait.gen.proto.plan_pb2 as stp
+import substrait.gen.proto.algebra_pb2 as stalg
+from substrait.builders.type import boolean, i64
+from substrait.builders.plan import read_named_table, default_version
+from substrait.extension_registry import ExtensionRegistry
+import substrait.dataframe as sdf
+
+
+registry = ExtensionRegistry(load_default_extensions=False)
+
+struct = stt.Type.Struct(
+    types=[i64(nullable=False), boolean()], nullability=stt.Type.NULLABILITY_REQUIRED
+)
+
+named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
+
+
+def test_project():
+    """select() with a keyword expression yields a ProjectRel named by the keyword."""
+    df = sdf.DataFrame(read_named_table("table", named_struct))
+
+    actual = df.select(id=sdf.col("id")).to_substrait(registry)
+
+    # Reference to field 0 ("id") of the input row.
+    id_reference = stalg.Expression(
+        selection=stalg.Expression.FieldReference(
+            direct_reference=stalg.Expression.ReferenceSegment(
+                struct_field=stalg.Expression.ReferenceSegment.StructField(field=0)
+            ),
+            root_reference=stalg.Expression.FieldReference.RootReference(),
+        )
+    )
+    project = stalg.ProjectRel(
+        # The input has two fields, so the single projected expression sits at
+        # index 2 and is the only column emitted.
+        common=stalg.RelCommon(emit=stalg.RelCommon.Emit(output_mapping=[2])),
+        input=df.to_substrait(registry).relations[-1].root.input,
+        expressions=[id_reference],
+    )
+    expected = stp.Plan(
+        relations=[
+            stp.PlanRel(
+                root=stalg.RelRoot(input=stalg.Rel(project=project), names=["id"])
+            )
+        ],
+        version=default_version,
+    )
+
+    assert actual == expected