diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index 3bac018ef..23873feab 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -66,11 +66,17 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all" - Python API: `python/datafusion/functions.py` — each function wraps a call to `datafusion._internal.functions` - Rust bindings: `crates/core/src/functions.rs` — `#[pyfunction]` definitions registered via `init_module()` +**Evaluated and not requiring separate Python exposure:** +- `get_field_path` — already covered by `get_field(expr, *names)`, which takes a + variadic field path and dispatches to the same underlying + `functions::core::get_field` UDF as the upstream `get_field_path` helper. + **How to check:** 1. Fetch the upstream scalar function documentation page 2. Compare against functions listed in `python/datafusion/functions.py` (check the `__all__` list and function definitions) 3. A function is covered if it exists in the Python API — it does NOT need a dedicated Rust `#[pyfunction]`. Many functions are aliases that reuse another function's Rust binding. -4. Only report functions that are missing from the Python `__all__` list / function definitions +4. Check against the "evaluated and not requiring exposure" list before flagging as a gap +5. Only report functions that are missing from the Python `__all__` list / function definitions ### 2. Aggregate Functions diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d3c2bc59..0a212480b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - repo: https://github.com/rhysd/actionlint - rev: v1.7.6 + rev: v1.7.12 hooks: - id: actionlint-docker - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 642afeef7..6a018a85b 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -35,7 +35,6 @@ use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; use datafusion::datasource::{MemTable, TableProvider}; -use datafusion::execution::TaskContextProvider; use datafusion::execution::context::{ DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext, }; @@ -44,6 +43,7 @@ use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, Unboun use datafusion::execution::options::{ArrowReadOptions, ReadOptions}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::execution::{FunctionRegistry, TaskContextProvider}; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, JsonReadOptions, ParquetReadOptions, }; @@ -847,6 +847,13 @@ impl PySessionContext { Ok(()) } + pub fn read_batches( + &self, + batches: PyArrowType>, + ) -> PyDataFusionResult { + Ok(PyDataFrame::new(self.ctx.read_batches(batches.0)?)) + } + #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, path, table_partition_cols=vec![], parquet_pruning=true, @@ -1065,6 +1072,39 @@ impl PySessionContext { self.ctx.deregister_udwf(name); } + pub fn udf(&self, name: &str) -> PyDataFusionResult { + let function = (*self.ctx.udf(name)?).clone(); + Ok(PyScalarUDF { function }) + } + + pub fn udaf(&self, name: &str) -> PyDataFusionResult { + let function = (*self.ctx.udaf(name)?).clone(); + Ok(PyAggregateUDF { function }) + } + + pub fn udwf(&self, name: &str) -> PyDataFusionResult { + let function = (*self.ctx.udwf(name)?).clone(); + Ok(PyWindowUDF { function }) + } + + pub fn udfs(&self) -> Vec { + let mut names: Vec = self.ctx.udfs().into_iter().collect(); + names.sort(); + names + } + + pub fn udafs(&self) -> Vec { + let mut names: Vec = self.ctx.udafs().into_iter().collect(); + names.sort(); + names + } + + pub fn udwfs(&self) -> Vec { + let mut names: Vec = self.ctx.udwfs().into_iter().collect(); + names.sort(); + names + } + #[pyo3(signature = (name="datafusion"))] pub fn catalog(&self, py: Python, name: &str) -> PyResult> { let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 7feb62d79..5f47d123b 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -574,10 +574,10 @@ expr_fn!(union_tag, arg1); expr_fn!(random); #[pyfunction] -fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr { - functions::core::get_field() - .call(vec![expr.into(), name.into()]) - .into() +fn get_field(expr: PyExpr, names: Vec) -> PyExpr { + let mut args = vec![expr.into()]; + args.extend(names.into_iter().map(Into::into)); + functions::core::get_field().call(args).into() } #[pyfunction] diff --git a/examples/datafusion-ffi-example/src/table_function.rs b/examples/datafusion-ffi-example/src/table_function.rs index 79c13f64d..ed3ef142b 100644 --- a/examples/datafusion-ffi-example/src/table_function.rs +++ b/examples/datafusion-ffi-example/src/table_function.rs @@ -17,9 +17,8 @@ use std::sync::Arc; -use datafusion_catalog::{TableFunctionImpl, TableProvider}; +use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider}; use datafusion_common::error::Result as DataFusionResult; -use datafusion_expr::Expr; use datafusion_ffi::udtf::FFI_TableFunction; use datafusion_python_util::ffi_logical_codec_from_pycapsule; use pyo3::types::PyCapsule; @@ -59,7 +58,7 @@ impl MyTableFunction { } impl TableFunctionImpl for MyTableFunction { - fn call(&self, _args: &[Expr]) -> DataFusionResult> { + fn call_with_args(&self, _args: TableFunctionArgs) -> DataFusionResult> { let provider = MyTableProvider::new(4, 3, 2).create_table()?; Ok(Arc::new(provider)) } diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5c3501941..f4c056b66 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -86,6 +86,7 @@ import pandas as pd import polars as pl # type: ignore[import] + from _typeshed import CapsuleType as _PyCapsule from datafusion.catalog import CatalogProvider, Table from datafusion.common import DFSchema @@ -93,6 +94,8 @@ from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.user_defined import ( AggregateUDF, + LogicalExtensionCodecExportable, + PhysicalExtensionCodecExportable, ScalarUDF, TableFunction, WindowUDF, @@ -959,6 +962,45 @@ def register_record_batches( """ self.ctx.register_record_batches(name, partitions) + def read_batch(self, batch: pa.RecordBatch) -> DataFrame: + """Return a :py:class:`~datafusion.DataFrame` reading a single batch. + + Convenience wrapper around :py:meth:`read_batches` for the single-batch + case. Unlike :py:meth:`register_batch`, this does not register the + batch as a named table; it returns an anonymous + :py:class:`~datafusion.DataFrame` directly. + + Args: + batch: Record batch to wrap as a DataFrame. + + Examples: + >>> ctx = dfn.SessionContext() + >>> batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]}) + >>> ctx.read_batch(batch).to_pydict() + {'a': [1, 2, 3]} + """ + return self.read_batches([batch]) + + def read_batches(self, batches: list[pa.RecordBatch]) -> DataFrame: + """Return a :py:class:`~datafusion.DataFrame` reading the given batches. + + All batches must share the same schema. Unlike + :py:meth:`register_record_batches`, this does not register the batches + as a named table; it returns an anonymous + :py:class:`~datafusion.DataFrame` directly. + + Args: + batches: Record batches to wrap as a DataFrame. + + Examples: + >>> ctx = dfn.SessionContext() + >>> b1 = pa.RecordBatch.from_pydict({"a": [1, 2]}) + >>> b2 = pa.RecordBatch.from_pydict({"a": [3, 4]}) + >>> ctx.read_batches([b1, b2]).to_pydict() + {'a': [1, 2, 3, 4]} + """ + return DataFrame(self.ctx.read_batches(batches)) + def register_parquet( self, name: str, @@ -1268,6 +1310,152 @@ def deregister_udwf(self, name: str) -> None: """ self.ctx.deregister_udwf(name) + def udf(self, name: str) -> ScalarUDF: + """Look up a registered scalar UDF by name. + + Returns the same :py:class:`~datafusion.user_defined.ScalarUDF` + wrapper that :py:meth:`register_udf` accepts, so it can be invoked + as an expression in the DataFrame API or re-registered into a + different :py:class:`SessionContext`. Built-in scalar functions + from the session's function registry are also looked up. + + Args: + name: Name of the registered scalar UDF. + + Raises: + Exception: If no scalar UDF is registered under ``name``. + + Examples: + Register a UDF, then look it up by name and use it in the + DataFrame API: + + >>> ctx = dfn.SessionContext() + >>> nullcheck = dfn.udf( + ... lambda x: x.is_null(), + ... [pa.int64()], + ... pa.bool_(), + ... volatility="immutable", + ... name="nullcheck", + ... ) + >>> ctx.register_udf(nullcheck) + >>> fn = ctx.udf("nullcheck") + >>> df = ctx.from_pydict({"a": [1, None, 3]}) + >>> df.select(fn(col("a")).alias("is_null")).to_pydict() + {'is_null': [False, True, False]} + + Late-binding: the function name can come from configuration + rather than an imported symbol, which is useful when the set + of UDFs is plugin-driven or chosen at runtime: + + >>> config = {"null_check": "nullcheck"} + >>> fn = ctx.udf(config["null_check"]) + >>> df.select(fn(col("a")).alias("is_null")).to_pydict() + {'is_null': [False, True, False]} + """ + from datafusion.user_defined import ScalarUDF as _ScalarUDF # noqa: PLC0415 + + wrapper = _ScalarUDF.__new__(_ScalarUDF) + wrapper._udf = self.ctx.udf(name) + return wrapper + + def udaf(self, name: str) -> AggregateUDF: + """Look up a registered aggregate UDF by name. + + Returns the same :py:class:`~datafusion.user_defined.AggregateUDF` + wrapper that :py:meth:`register_udaf` accepts. Built-in aggregate + functions such as ``sum`` or ``avg`` are also discoverable through + this lookup. See :py:meth:`udf` for a worked late-binding example; + the pattern is identical for aggregates. + + Args: + name: Name of the registered aggregate UDF. + + Raises: + Exception: If no aggregate UDF is registered under ``name``. + + Examples: + Look up a built-in aggregate by name and use it in + :py:meth:`~datafusion.DataFrame.aggregate`: + + >>> ctx = dfn.SessionContext() + >>> sum_fn = ctx.udaf("sum") + >>> df = ctx.from_pydict({"a": [1, 2, 3]}) + >>> df.aggregate([], [sum_fn(col("a")).alias("total")]).to_pydict() + {'total': [6]} + """ + from datafusion.user_defined import ( # noqa: PLC0415 + AggregateUDF as _AggregateUDF, + ) + + wrapper = _AggregateUDF.__new__(_AggregateUDF) + wrapper._udaf = self.ctx.udaf(name) + return wrapper + + def udwf(self, name: str) -> WindowUDF: + """Look up a registered window UDF by name. + + Returns the same :py:class:`~datafusion.user_defined.WindowUDF` + wrapper that :py:meth:`register_udwf` accepts. Built-in window + functions such as ``row_number`` or ``rank`` are also discoverable + through this lookup. See :py:meth:`udf` for a worked late-binding + example; the pattern is identical for window functions. + + Args: + name: Name of the registered window UDF. + + Raises: + Exception: If no window UDF is registered under ``name``. + + Examples: + Look up a built-in window function by name and use it in + ``select``: + + >>> ctx = dfn.SessionContext() + >>> rn = ctx.udwf("row_number") + >>> df = ctx.from_pydict({"a": [10, 20, 30]}) + >>> df.select(col("a"), rn().alias("rn")).to_pydict() + {'a': [10, 20, 30], 'rn': [1, 2, 3]} + """ + from datafusion.user_defined import WindowUDF as _WindowUDF # noqa: PLC0415 + + wrapper = _WindowUDF.__new__(_WindowUDF) + wrapper._udwf = self.ctx.udwf(name) + return wrapper + + def udfs(self) -> list[str]: + """Return the sorted names of all registered scalar UDFs. + + Includes both user-registered and built-in scalar functions. Pair + with :py:meth:`udf` to drive discovery, validation, or config-based + dispatch. + + Examples: + >>> ctx = dfn.SessionContext() + >>> "abs" in ctx.udfs() + True + """ + return self.ctx.udfs() + + def udafs(self) -> list[str]: + """Return the sorted names of all registered aggregate UDFs. + + Examples: + >>> ctx = dfn.SessionContext() + >>> "sum" in ctx.udafs() + True + """ + return self.ctx.udafs() + + def udwfs(self) -> list[str]: + """Return the sorted names of all registered window UDFs. + + Examples: + >>> ctx = dfn.SessionContext() + >>> "row_number" in ctx.udwfs() + True + """ + return self.ctx.udwfs() + def catalog(self, name: str = "datafusion") -> Catalog: """Retrieve a catalog by name.""" return Catalog(self.ctx.catalog(name)) @@ -1744,11 +1932,15 @@ def __datafusion_logical_extension_codec__(self) -> Any: """Access the PyCapsule FFI_LogicalExtensionCodec.""" return self.ctx.__datafusion_logical_extension_codec__() - def with_logical_extension_codec(self, codec: Any) -> SessionContext: + def with_logical_extension_codec( + self, codec: LogicalExtensionCodecExportable | _PyCapsule + ) -> SessionContext: """Create a new session context with specified codec. This only supports codecs that have been implemented using the - FFI interface. + FFI interface. ``codec`` must either be a raw ``FFI_LogicalExtensionCodec`` + ``PyCapsule`` or an object exposing + ``__datafusion_logical_extension_codec__``. """ new_internal = self.ctx.with_logical_extension_codec(codec) new = SessionContext.__new__(SessionContext) @@ -1759,11 +1951,15 @@ def __datafusion_physical_extension_codec__(self) -> Any: """Access the PyCapsule FFI_PhysicalExtensionCodec.""" return self.ctx.__datafusion_physical_extension_codec__() - def with_physical_extension_codec(self, codec: Any) -> SessionContext: + def with_physical_extension_codec( + self, codec: PhysicalExtensionCodecExportable | _PyCapsule + ) -> SessionContext: """Create a new session context with the specified physical codec. This only supports codecs that have been implemented using the - FFI interface. + FFI interface. ``codec`` must either be a raw + ``FFI_PhysicalExtensionCodec`` ``PyCapsule`` or an object exposing + ``__datafusion_physical_extension_codec__``. """ new_internal = self.ctx.with_physical_extension_codec(codec) new = SessionContext.__new__(SessionContext) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 9761d1879..02fefd709 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -2727,14 +2727,24 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: return Expr(f.arrow_metadata(expr.expr, key.expr)) -def get_field(expr: Expr, name: Expr | str) -> Expr: - """Extracts a field from a struct or map by name. +def get_field(expr: Expr, *names: Expr | str) -> Expr: + """Extracts a (possibly nested) field from a struct or map by name. - When the field name is a static string, the bracket operator - ``expr["field"]`` is a convenient shorthand. Use ``get_field`` - when the field name is a dynamic expression. + Pass one name for a single-level lookup, or several names to walk a path + of nested struct/map fields in a single ``get_field`` call. For a single + static-string name, ``expr["field"]`` is a convenient shorthand; use + ``get_field`` when the field name is a dynamic + :py:class:`~datafusion.expr.Expr` or when traversing multiple levels at + once. + + Args: + expr: The struct or map expression to read from. + *names: One or more field names (``str``) or expressions + (:py:class:`~datafusion.expr.Expr`). Examples: + Single-level lookup: + >>> ctx = dfn.SessionContext() >>> df = ctx.from_pydict({"a": [1], "b": [2]}) >>> df = df.with_column( @@ -2756,10 +2766,26 @@ def get_field(expr: Expr, name: Expr | str) -> Expr: ... ) >>> result.collect_column("x_val")[0].as_py() 1 + + Multi-level lookup: + + >>> df = df.with_column( + ... "outer", + ... dfn.functions.named_struct([("inner", dfn.col("s"))]), + ... ) + >>> result = df.select( + ... dfn.functions.get_field( + ... dfn.col("outer"), "inner", "x" + ... ).alias("x_val") + ... ) + >>> result.collect_column("x_val")[0].as_py() + 1 """ - if isinstance(name, str): - name = Expr.string_literal(name) - return Expr(f.get_field(expr.expr, name.expr)) + if not names: + msg = "get_field requires at least one field name" + raise ValueError(msg) + resolved = [Expr.string_literal(n) if isinstance(n, str) else n for n in names] + return Expr(f.get_field(expr.expr, [n.expr for n in resolved])) def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 3eb50a094..3d43cce21 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -113,6 +113,18 @@ def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]: return value.__class__.__name__ == "PyCapsule" +class LogicalExtensionCodecExportable(Protocol): + """Type hint for objects exposing ``__datafusion_logical_extension_codec__``.""" + + def __datafusion_logical_extension_codec__(self) -> object: ... # noqa: D105 + + +class PhysicalExtensionCodecExportable(Protocol): + """Type hint for objects exposing ``__datafusion_physical_extension_codec__``.""" + + def __datafusion_physical_extension_codec__(self) -> object: ... # noqa: D105 + + class ScalarUDF: """Class for performing scalar user-defined functions (UDF). diff --git a/python/tests/test_context.py b/python/tests/test_context.py index e0ebdbae5..27a35c90c 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -905,6 +905,21 @@ def test_register_batch_empty(ctx): assert result[0].num_rows == 0 +def test_read_batch_returns_dataframe(ctx): + batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = ctx.read_batch(batch) + assert df.to_pydict() == {"a": [1, 2, 3], "b": [4, 5, 6]} + # read_batch should not register a named table. + assert ctx.catalog().schema().names() == set() + + +def test_read_batches_concatenates(ctx): + b1 = pa.RecordBatch.from_pydict({"a": [1, 2]}) + b2 = pa.RecordBatch.from_pydict({"a": [3, 4]}) + df = ctx.read_batches([b1, b2]) + assert df.to_pydict() == {"a": [1, 2, 3, 4]} + + def test_create_sql_options(): SQLOptions() diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 6bd0ce9f9..ab3992a79 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -1704,8 +1704,12 @@ def test_repr_rows_backward_compatibility(clean_formatter_state): assert formatter.max_rows == 15 assert formatter.repr_rows == 15 - # Should fail when conflicting with max_rows - with pytest.raises(ValueError, match="Cannot specify both repr_rows and max_rows"): + # Should fail when conflicting with max_rows. The deprecation warning still + # fires before the ValueError, so assert both. + with ( + pytest.raises(ValueError, match="Cannot specify both repr_rows and max_rows"), + pytest.warns(DeprecationWarning, match="repr_rows parameter is deprecated"), + ): DataFrameHtmlFormatter(repr_rows=5, max_rows=10) # Setting repr_rows via property should warn diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 5538fc33b..55d9c8ee8 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1957,6 +1957,37 @@ def test_get_field(df): assert result.column(1) == pa.array([4, 5, 6]) +def test_get_field_path(df): + df = df.with_column( + "outer", + f.named_struct( + [ + ( + "inner", + f.named_struct( + [ + ("x", column("a")), + ("y", column("b")), + ] + ), + ), + ] + ), + ) + result = df.select( + f.get_field(column("outer"), "inner", "x").alias("x_val"), + f.get_field(column("outer"), "inner", "y").alias("y_val"), + ).collect()[0] + + assert result.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view()) + assert result.column(1) == pa.array([4, 5, 6]) + + +def test_get_field_requires_a_name(): + with pytest.raises(ValueError, match="at least one field name"): + f.get_field(column("s")) + + def test_arrow_metadata(): ctx = SessionContext() field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"}) diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 1ed1746e1..924d2655c 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -450,13 +450,9 @@ def test_udf( pa.array([b"1111", b"2222", b"3333"], pa.binary(4), _null_mask), id="binary4", ), - # `timestamp[s]` does not roundtrip for pyarrow.parquet: https://github.com/apache/arrow/issues/41382 pytest.param( helpers.data_datetime("s"), id="datetime_s", - marks=pytest.mark.xfail( - reason="pyarrow.parquet does not support timestamp[s] roundtrips" - ), ), pytest.param( helpers.data_datetime("ms"), @@ -484,6 +480,16 @@ def test_simple_select(ctx, tmp_path, arr): batches = ctx.sql("SELECT a AS tt FROM t").collect() result = batches[0].column(0) + # pyarrow.parquet promotes timestamp[s] to timestamp[ms] on write + # (https://github.com/apache/arrow/issues/41382). Compensate so the + # comparison checks DataFusion reads what Arrow actually stored. + if ( + isinstance(arr, pa.Array) + and pa.types.is_timestamp(arr.type) + and arr.type.unit == "s" + ): + arr = arr.cast(pa.timestamp("ms")) + # In DF 43.0.0 we now default to having BinaryView and StringView # so the array that is saved to the parquet is slightly different # than the array read. Convert to values for comparison. diff --git a/python/tests/test_udf.py b/python/tests/test_udf.py index b2540fb57..189e8e2f3 100644 --- a/python/tests/test_udf.py +++ b/python/tests/test_udf.py @@ -76,6 +76,77 @@ def test_register_udf(ctx, df) -> None: assert result == pa.array([False, False, True]) +def test_udf_lookup(ctx, df) -> None: + is_null = udf( + lambda x: x.is_null(), + [pa.float64()], + pa.bool_(), + volatility="immutable", + name="lookup_is_null", + ) + ctx.register_udf(is_null) + + assert "lookup_is_null" in ctx.udfs() + + looked_up = ctx.udf("lookup_is_null") + df_result = df.select(looked_up(column("b"))) + result = df_result.collect()[0].column(0) + assert result == pa.array([False, False, True]) + + with pytest.raises(Exception, match="no UDF named"): + ctx.udf("does_not_exist") + + +def test_udf_late_binding_dispatch(ctx, df) -> None: + """Resolve a UDF chosen by configuration string, then invoke it.""" + late_is_null = udf( + lambda x: x.is_null(), + [pa.int64()], + pa.bool_(), + volatility="immutable", + name="late_is_null", + ) + late_is_not_null = udf( + lambda x: pc.invert(x.is_null()), + [pa.int64()], + pa.bool_(), + volatility="immutable", + name="late_is_not_null", + ) + + ctx.register_udf(late_is_null) + ctx.register_udf(late_is_not_null) + + # Pretend this came from a config file / API request — only a string. + runtime_config = {"check_fn": "late_is_not_null"} + + assert runtime_config["check_fn"] in ctx.udfs() + + fn = ctx.udf(runtime_config["check_fn"]) + result = df.select(fn(column("b")).alias("ok")).collect()[0].column(0) + assert result == pa.array([True, True, False]) + + +def test_udaf_lookup_builtin(ctx, df) -> None: + assert "sum" in ctx.udafs() + sum_fn = ctx.udaf("sum") + result = df.aggregate([], [sum_fn(column("a")).alias("total")]).collect() + assert result[0].column(0).to_pylist() == [6] + + with pytest.raises(Exception, match="no UDAF named"): + ctx.udaf("does_not_exist") + + +def test_udwf_lookup_builtin(ctx, df) -> None: + assert "row_number" in ctx.udwfs() + rn = ctx.udwf("row_number") + result = df.select(column("a"), rn().alias("rn")).collect() + assert result[0].column(1).to_pylist() == [1, 2, 3] + + with pytest.raises(Exception, match="no UDWF named"): + ctx.udwf("does_not_exist") + + class OverThresholdUDF: def __init__(self, threshold: int = 0) -> None: self.threshold = threshold