From 79fc511f479da7d5a0503951289802d7fdb6728e Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Wed, 31 Dec 2025 15:05:56 -0800 Subject: [PATCH 1/6] Replace pandas.eval with native implementation This commit removes the dependency on pandas.eval() and implements a native expression evaluator in Dataset.eval() using Python's ast module. The new implementation provides better support for multi-dimensional arrays and maintains backward compatibility with deprecated operators through automatic transformation. Key changes: - Remove pd.eval() call and replace with custom _eval_expression() method - Add _LogicalOperatorTransformer to convert deprecated operators (and/or/not) to bitwise operators (&/|/~) that work element-wise on arrays - Implement automatic transformation of chained comparisons to explicit bitwise AND operations - Add security validation to block lambda expressions and private attributes - Emit FutureWarning for deprecated constructs (logical operators, chained comparisons, parser= argument) - Support assignment statements (target = expression) in eval() - Make data variables and coordinates take priority in namespace resolution - Provide safe builtins (abs, min, max, round, len, sum, pow, any, all, type constructors, iteration helpers) while blocking __import__, open, etc. - Add comprehensive test coverage including edge cases, error messages, dask compatibility, and security validation --- xarray/core/dataset.py | 260 +++++++++++++++- xarray/tests/test_dataset.py | 580 ++++++++++++++++++++++++++++++++++- 2 files changed, 813 insertions(+), 27 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 01baa9aed3d..35a9e24d6ab 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,6 +1,8 @@ from __future__ import annotations +import ast import asyncio +import builtins import copy import datetime import io @@ -72,7 +74,6 @@ Self, T_ChunkDim, T_ChunksFreq, - T_DataArray, T_DataArrayOrSet, ZarrWriteModes, ) @@ -9533,19 +9534,202 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: "Dataset.argmin() with a sequence or ... for dim" ) + # Base namespace for eval expressions (modules added lazily in _eval_expression + # to avoid circular imports for xarray). + # We add common builtins back since we block __builtins__ for security. + # Note: builtins.map is used explicitly because 'map' in class scope refers + # to the Dataset.map method defined earlier in this class body. + _EVAL_NAMESPACE_BUILTINS: dict[str, Any] = { + # Numeric/aggregation functions + "abs": abs, + "min": min, + "max": max, + "round": round, + "len": len, + "sum": sum, + "pow": pow, + "any": any, + "all": all, + # Type constructors + "int": int, + "float": float, + "bool": bool, + "str": str, + "list": list, + "tuple": tuple, + "dict": dict, + "set": set, + "slice": slice, + # Iteration helpers + "range": range, + "zip": zip, + "enumerate": enumerate, + "map": builtins.map, + "filter": filter, + } + + # ------------------------------------------------------------------------- + # eval() Implementation Notes (for future maintainers): + # + # This implementation uses native AST-based evaluation instead of pd.eval() + # to support N-dimensional arrays (N > 2). See GitHub issue #11062. + # + # We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~', + # and chained comparisons) for consistency with query(), which still uses + # pd.eval(). We don't migrate query() to this implementation because: + # - query() typically works fine (expressions usually compare 1D coordinates) + # - pd.eval() with numexpr is faster and well-tested for query's use case + # ------------------------------------------------------------------------- + + class _LogicalOperatorTransformer(ast.NodeTransformer): + """Transform operators for consistency with query(). + + query() uses pd.eval() which transforms these operators automatically. + We replicate that behavior here so syntax that works in query() also + works in eval(). + + Transformations: + 1. 'and'/'or'/'not' -> '&'/'|'/'~' + 2. 'a < b < c' -> '(a < b) & (b < c)' + + These constructs fail on arrays in standard Python because they call + __bool__(), which is ambiguous for multi-element arrays. + """ + + def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST: + # Transform: a and b -> a & b, a or b -> a | b + self.generic_visit(node) + op: ast.BitAnd | ast.BitOr + if isinstance(node.op, ast.And): + op = ast.BitAnd() + elif isinstance(node.op, ast.Or): + op = ast.BitOr() + else: + return node + + # BoolOp can have multiple values: a and b and c + # Transform to chained BinOp: (a & b) & c + result = node.values[0] + for value in node.values[1:]: + result = ast.BinOp(left=result, op=op, right=value) + return ast.fix_missing_locations(result) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: + # Transform: not a -> ~a + self.generic_visit(node) + if isinstance(node.op, ast.Not): + return ast.fix_missing_locations( + ast.UnaryOp(op=ast.Invert(), operand=node.operand) + ) + return node + + def visit_Compare(self, node: ast.Compare) -> ast.AST: + # Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5) + # Python's chained comparisons use short-circuit evaluation at runtime, + # which calls __bool__ on intermediate results. This fails for arrays. + # We transform to bitwise AND which works element-wise. + self.generic_visit(node) + + if len(node.ops) == 1: + # Simple comparison, no transformation needed + return node + + # Build individual comparisons and chain with BitAnd + # For: a < b < c < d + # We need: (a < b) & (b < c) & (c < d) + comparisons = [] + left = node.left + for op, comparator in zip(node.ops, node.comparators, strict=True): + comp = ast.Compare(left=left, ops=[op], comparators=[comparator]) + comparisons.append(comp) + left = comparator + + # Chain with BitAnd: (a < b) & (b < c) & ... + result: ast.Compare | ast.BinOp = comparisons[0] + for comp in comparisons[1:]: + result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp) + return ast.fix_missing_locations(result) + + def _validate_eval_expression(self, tree: ast.AST) -> None: + """Validate that an AST doesn't contain unsafe patterns. + + This provides basic protection against common attack vectors but is NOT + designed to be a robust security boundary. Eval with untrusted user input + should always be treated with caution. + + Security measures: + - Empty __builtins__ dict blocks __import__, open, exec, etc. + - Blocking private/dunder attributes prevents class hierarchy traversal + attacks (e.g., x.__class__.__bases__[0].__subclasses__()) + - Limited namespace: data variables, coordinates, np/pd/xr modules, and + safe builtins: + - Numeric/aggregation: abs, min, max, round, len, sum, pow, any, all + - Type constructors: int, float, bool, str, list, tuple, dict, set, slice + - Iteration helpers: range, zip, enumerate, map, filter + + Known limitations: + - Format strings (e.g., "{0.__class__}".format(x)) can access dunder + attributes at runtime, bypassing AST-level checks. This allows + information disclosure but not direct code execution. + + We welcome contributions to improve the security model. + """ + for node in ast.walk(tree): + # Block lambda expressions to reduce attack surface + if isinstance(node, ast.Lambda): + raise ValueError( + "Lambda expressions are not allowed in eval(). " + "Use direct operations on data variables instead." + ) + # Block private/dunder attributes to prevent class hierarchy traversal + if isinstance(node, ast.Attribute) and node.attr.startswith("_"): + raise ValueError( + f"Access to private attributes is not allowed: '{node.attr}'. " + f"For security, attributes starting with '_' are blocked." + ) + + def _eval_expression(self, expr: str) -> DataArray: + """Evaluate an expression string using xarray's native operations.""" + try: + tree = ast.parse(expr, mode="eval") + except SyntaxError as e: + raise ValueError(f"Invalid expression syntax: {expr}") from e + + # Transform logical operators for consistency with query(). + # See _LogicalOperatorTransformer docstring for details. + tree = self._LogicalOperatorTransformer().visit(tree) + ast.fix_missing_locations(tree) + + self._validate_eval_expression(tree) + + # Build namespace: data variables, coordinates, modules, and safe builtins. + # Empty __builtins__ blocks dangerous functions like __import__, exec, open. + # Priority order (highest to lowest): data variables > coordinates > modules > builtins + # This ensures user data always wins when names collide with builtins. + import xarray as xr # Lazy import to avoid circular dependency + + namespace: dict[str, Any] = dict(self._EVAL_NAMESPACE_BUILTINS) + namespace.update({"np": np, "pd": pd, "xr": xr}) + namespace.update({str(name): self.coords[name] for name in self.coords}) + namespace.update({str(name): self[name] for name in self.data_vars}) + + code = compile(tree, "", "eval") + return builtins.eval(code, {"__builtins__": {}}, namespace) + def eval( self, statement: str, *, - parser: QueryParserOptions = "pandas", - ) -> Self | T_DataArray: + parser: QueryParserOptions | Default = _default, + ) -> Self | DataArray: """ Calculate an expression supplied as a string in the context of the dataset. This is currently experimental; the API may change particularly around assignments, which currently return a ``Dataset`` with the additional variable. - Currently only the ``python`` engine is supported, which has the same - performance as executing in python. + + Logical operators (``and``, ``or``, ``not``) are automatically transformed + to bitwise operators (``&``, ``|``, ``~``) which work element-wise on arrays. Parameters ---------- @@ -9555,7 +9739,14 @@ def eval( Returns ------- result : Dataset or DataArray, depending on whether ``statement`` contains an - assignment. + assignment. + + Warning + ------- + This method evaluates Python expressions and should not be used with + untrusted input. While basic security measures are in place (empty + ``__builtins__``, blocked private attributes, limited namespace), they + are not designed to be a robust security sandbox. Examples -------- @@ -9584,16 +9775,55 @@ def eval( b (x) float64 40B 0.0 0.25 0.5 0.75 1.0 c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 """ + if parser is not _default: + emit_user_level_warning( + "The 'parser' argument to Dataset.eval() is deprecated and will be " + "removed in a future version. Logical operators (and/or/not) are now " + "always transformed to bitwise operators (&/|/~) for array compatibility.", + FutureWarning, + ) - return pd.eval( # type: ignore[return-value] - statement, - resolvers=[self], - target=self, - parser=parser, - # Because numexpr returns a numpy array, using that engine results in - # different behavior. We'd be very open to a contribution handling this. - engine="python", - ) + statement = statement.strip() + + # Check for assignment: "target = expr" + # Must handle compound operators like ==, !=, <=, >= + # Use ast to detect assignment properly + try: + tree = ast.parse(statement, mode="exec") + except SyntaxError as e: + raise ValueError(f"Invalid statement syntax: {statement}") from e + + if len(tree.body) != 1: + raise ValueError("Only single statements are supported") + + stmt = tree.body[0] + + if isinstance(stmt, ast.Assign): + # Assignment: "c = a + b" + if len(stmt.targets) != 1: + raise ValueError("Only single assignment targets are supported") + target = stmt.targets[0] + if not isinstance(target, ast.Name): + raise ValueError( + f"Assignment target must be a simple name, got {type(target).__name__}" + ) + target_name = target.id + + # Get the expression source + expr_source = ast.unparse(stmt.value) + result: DataArray = self._eval_expression(expr_source) + return self.assign({target_name: result}) + + elif isinstance(stmt, ast.Expr): + # Expression: "a + b" + expr_source = ast.unparse(stmt.value) + return self._eval_expression(expr_source) + + else: + raise ValueError( + f"Unsupported statement type: {type(stmt).__name__}. " + f"Only expressions and assignments are supported." + ) def query( self, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d25ef5a2771..ae83fb66be4 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7636,21 +7636,577 @@ def test_query(self, backend, engine, parser) -> None: # pytest tests — new tests should go here, rather than in the class. -@pytest.mark.parametrize("parser", ["pandas", "python"]) -def test_eval(ds, parser) -> None: - """Currently much more minimal testing that `query` above, and much of the setup - isn't used. But the risks are fairly low — `query` shares much of the code, and - the method is currently experimental.""" - - actual = ds.eval("z1 + 5", parser=parser) +def test_eval(ds) -> None: + """Test basic eval functionality.""" + actual = ds.eval("z1 + 5") expect = ds["z1"] + 5 assert_identical(expect, actual) - # check pandas query syntax is supported - if parser == "pandas": - actual = ds.eval("(z1 > 5) and (z2 > 0)", parser=parser) - expect = (ds["z1"] > 5) & (ds["z2"] > 0) - assert_identical(expect, actual) + # Use bitwise operators for element-wise operations on arrays + actual = ds.eval("(z1 > 5) & (z2 > 0)") + expect = (ds["z1"] > 5) & (ds["z2"] > 0) + assert_identical(expect, actual) + + +def test_eval_parser_deprecated(ds) -> None: + """Test that passing parser= raises a FutureWarning.""" + with pytest.warns(FutureWarning, match="parser.*deprecated"): + ds.eval("z1 + 5", parser="pandas") + + +def test_eval_logical_operators(ds) -> None: + """Test that 'and'/'or'/'not' are transformed for query() consistency. + + These operators are transformed to '&'/'|'/'~' to match pd.eval() behavior, + which query() uses. This ensures syntax that works in query() also works in + eval(). + """ + # 'and' transformed to '&' + actual = ds.eval("(z1 > 5) and (z2 > 0)") + expect = (ds["z1"] > 5) & (ds["z2"] > 0) + assert_identical(expect, actual) + + # 'or' transformed to '|' + actual = ds.eval("(z1 > 5) or (z2 > 0)") + expect = (ds["z1"] > 5) | (ds["z2"] > 0) + assert_identical(expect, actual) + + # 'not' transformed to '~' + actual = ds.eval("not (z1 > 5)") + expect = ~(ds["z1"] > 5) + assert_identical(expect, actual) + + +def test_eval_ndimensional() -> None: + """Test that eval works with N-dimensional data where N > 2.""" + # Create a 3D dataset - this previously failed with pd.eval + rng = np.random.default_rng(42) + ds = Dataset( + { + "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), + "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), + } + ) + + # Basic arithmetic + actual = ds.eval("x + y") + expect = ds["x"] + ds["y"] + assert_identical(expect, actual) + + # Assignment + actual = ds.eval("z = x + y") + assert "z" in actual.data_vars + assert_equal(ds["x"] + ds["y"], actual["z"]) + + # Complex expression + actual = ds.eval("x * 2 + y ** 2") + expect = ds["x"] * 2 + ds["y"] ** 2 + assert_identical(expect, actual) + + # Comparison + actual = ds.eval("x > y") + expect = ds["x"] > ds["y"] + assert_identical(expect, actual) + + # Use bitwise operators for element-wise boolean operations + actual = ds.eval("(x > 0.5) & (y < 0.5)") + expect = (ds["x"] > 0.5) & (ds["y"] < 0.5) + assert_identical(expect, actual) + + +def test_eval_chained_comparisons() -> None: + """Test that chained comparisons are transformed for query() consistency. + + Chained comparisons like 'a < b < c' are transformed to '(a < b) & (b < c)' + to match pd.eval() behavior, which query() uses. + """ + ds = Dataset({"x": ("dim", np.arange(10))}) + + # Basic chained comparison: 2 < x < 7 + actual = ds.eval("2 < x < 7") + expect = (ds["x"] > 2) & (ds["x"] < 7) + assert_identical(expect, actual) + + # Mixed operators: 0 <= x < 5 + actual = ds.eval("0 <= x < 5") + expect = (ds["x"] >= 0) & (ds["x"] < 5) + assert_identical(expect, actual) + + # Explicit bitwise operators also work + actual = ds.eval("(x > 2) & (x < 7)") + expect = (ds["x"] > 2) & (ds["x"] < 7) + assert_identical(expect, actual) + + +def test_eval_security() -> None: + """Test that eval blocks unsafe operations.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + + # Dunder/private attribute access should be blocked (sandbox escape vector) + with pytest.raises(ValueError, match="Access to private attributes is not allowed"): + ds.eval("a.__class__") + + with pytest.raises(ValueError, match="Access to private attributes is not allowed"): + ds.eval("a._private") + + # Lambda expressions should be blocked to reduce attack surface + with pytest.raises(ValueError, match="Lambda expressions are not allowed"): + ds.eval("(lambda x: x + 1)(a)") + + # Dangerous builtins should not be available + with pytest.raises(NameError): + ds.eval("__import__('os')") + + with pytest.raises(NameError): + ds.eval("open('file.txt')") + + +def test_eval_unsupported_statements() -> None: + """Test that unsupported statement types produce clear errors.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + + # Augmented assignment is not supported + with pytest.raises(ValueError, match="Unsupported statement type"): + ds.eval("a += 1") + + +def test_eval_functions() -> None: + """Test that numpy and other functions work in eval.""" + ds = Dataset({"a": ("x", [0.0, 1.0, 4.0])}) + + # numpy functions via np namespace should work + result = ds.eval("np.sqrt(a)") + assert_equal(result, np.sqrt(ds["a"])) + + result = ds.eval("np.sin(a) + np.cos(a)") + assert_equal(result, np.sin(ds["a"]) + np.cos(ds["a"])) + + # pandas namespace should work + result = ds.eval("pd.isna(a)") + np.testing.assert_array_equal(result, pd.isna(ds["a"])) + + # xarray namespace should work + result = ds.eval("xr.where(a > 1, a, 0)") + import xarray as xr + + assert_equal(result, xr.where(ds["a"] > 1, ds["a"], 0)) + + # Common builtins should work (we block __builtins__ for security) + result = ds.eval("abs(a - 2)") + assert_equal(result, abs(ds["a"] - 2)) + + result = ds.eval("round(float(a.mean()))") + assert result == round(float(ds["a"].mean())) + + result = ds.eval("len(a)") + assert result == 3 + + result = ds.eval("pow(a, 2)") + assert_equal(result, ds["a"] ** 2) + + # Attribute access on DataArrays should work + result = ds.eval("a.values") + assert isinstance(result, np.ndarray) + + # Method calls on DataArrays should work + result = ds.eval("a.mean()") + assert float(result) == np.mean([0.0, 1.0, 4.0]) + + +def test_eval_extended_builtins() -> None: + """Test extended builtins available in eval namespace. + + These builtins are safe (no I/O, no code execution) and commonly needed + for typical xarray operations like slicing, type conversion, and iteration. + """ + ds = Dataset( + {"a": ("x", [1.0, 2.0, 3.0, 4.0, 5.0])}, + coords={"time": pd.date_range("2019-01-01", periods=5)}, + ) + + # slice - essential for .sel() with ranges + result = ds.eval("a.sel(x=slice(1, 3))") + expected = ds["a"].sel(x=slice(1, 3)) + assert_equal(result, expected) + + # str - type constructor + result = ds.eval("str(int(a.mean()))") + assert result == "3" + + # list, tuple - type constructors + result = ds.eval("list(range(3))") + assert result == [0, 1, 2] + + result = ds.eval("tuple(range(3))") + assert result == (0, 1, 2) + + # dict, set - type constructors + result = ds.eval("dict(x=1, y=2)") + assert result == {"x": 1, "y": 2} + + result = ds.eval("set([1, 2, 2, 3])") + assert result == {1, 2, 3} + + # range - iteration + result = ds.eval("list(range(3))") + assert result == [0, 1, 2] + + # zip, enumerate - iteration helpers + result = ds.eval("list(zip([1, 2], [3, 4]))") + assert result == [(1, 3), (2, 4)] + + result = ds.eval("list(enumerate(['a', 'b']))") + assert result == [(0, "a"), (1, "b")] + + # map, filter - functional helpers + result = ds.eval("list(map(abs, [-1, -2, 3]))") + assert result == [1, 2, 3] + + result = ds.eval("list(filter(bool, [0, 1, 0, 2]))") + assert result == [1, 2] + + # any, all - aggregation + result = ds.eval("any([False, True, False])") + assert result is True + + result = ds.eval("all([True, True, True])") + assert result is True + + result = ds.eval("all([True, False, True])") + assert result is False + + +def test_eval_data_variable_priority() -> None: + """Test that data variables take priority over builtin functions. + + Users may have data variables named 'sum', 'abs', 'min', etc. When they + reference these in eval(), they should get their data, not the Python builtins. + The builtins should still be accessible via the np namespace (np.sum, np.abs). + """ + # Create dataset with data variables that shadow builtins + ds = Dataset( + { + "sum": ("x", [10.0, 20.0, 30.0]), # shadows builtin sum + "abs": ("x", [1.0, 2.0, 3.0]), # shadows builtin abs + "min": ("x", [100.0, 200.0, 300.0]), # shadows builtin min + "other": ("x", [5.0, 10.0, 15.0]), + } + ) + + # Data variables should take priority - user data wins + result = ds.eval("sum + other") + expected = ds["sum"] + ds["other"] + assert_equal(result, expected) + + # Should get the data variable, not builtin sum applied to something + result = ds.eval("sum * 2") + expected = ds["sum"] * 2 + assert_equal(result, expected) + + # abs as data variable should work + result = ds.eval("abs + 1") + expected = ds["abs"] + 1 + assert_equal(result, expected) + + # min as data variable should work + result = ds.eval("min - 50") + expected = ds["min"] - 50 + assert_equal(result, expected) + + # np namespace should still provide access to actual functions + result = ds.eval("np.abs(other - 10)") + expected = np.abs(ds["other"] - 10) + assert_equal(result, expected) + + # np.sum should work even when 'sum' is a data variable + result = ds.eval("np.sum(other)") + expected = np.sum(ds["other"]) + assert result == expected + + +def test_eval_coordinate_priority() -> None: + """Test that coordinates also take priority over builtins.""" + ds = Dataset( + {"data": ("x", [1.0, 2.0, 3.0])}, + coords={"sum": ("x", [10.0, 20.0, 30.0])}, # coordinate named 'sum' + ) + + # Coordinate should be accessible and take priority over builtin + result = ds.eval("data + sum") + expected = ds["data"] + ds.coords["sum"] + assert_equal(result, expected) + + +class TestEvalErrorMessages: + """Test that eval() produces clear error messages for common mistakes.""" + + def test_undefined_variable(self) -> None: + """Test error message when referencing an undefined variable.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(NameError, match="undefined_var"): + ds.eval("undefined_var + a") + + def test_syntax_error(self) -> None: + """Test error message for malformed expressions.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match="Invalid"): + ds.eval("a +") + + def test_invalid_assignment(self) -> None: + """Test error message when assignment target is invalid.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # "1 = a" should fail during parsing - can't assign to a literal + with pytest.raises(ValueError, match="Invalid"): + ds.eval("1 = a") + + def test_dunder_access(self) -> None: + """Test error message when trying to access dunder attributes.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match="private attributes"): + ds.eval("a.__class__") + + def test_missing_method(self) -> None: + """Test error message when calling a nonexistent method.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # This should raise AttributeError from the DataArray + with pytest.raises(AttributeError, match="nonexistent_method"): + ds.eval("a.nonexistent_method()") + + def test_type_error_in_expression(self) -> None: + """Test error message when types are incompatible.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # Adding string to numeric array should raise TypeError or similar + with pytest.raises((TypeError, np.exceptions.DTypePromotionError)): + ds.eval("a + 'string'") + + +class TestEvalEdgeCases: + """Test edge cases for eval().""" + + def test_empty_expression(self) -> None: + """Test handling of empty expression string.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError): + ds.eval("") + + def test_whitespace_only_expression(self) -> None: + """Test handling of whitespace-only expression.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError): + ds.eval(" ") + + def test_just_variable_name(self) -> None: + """Test that just a variable name returns the variable.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + result = ds.eval("a") + expected = ds["a"] + assert_equal(result, expected) + + def test_unicode_variable_names(self) -> None: + """Test that unicode variable names work in expressions.""" + # Greek letters are valid Python identifiers + ds = Dataset({"α": ("x", [1.0, 2.0, 3.0]), "β": ("x", [4.0, 5.0, 6.0])}) + result = ds.eval("α + β") + expected = ds["α"] + ds["β"] + assert_equal(result, expected) + + def test_long_expression(self) -> None: + """Test that very long expressions work correctly.""" + ds = Dataset({"a": ("x", [1.0, 2.0, 3.0])}) + # Build a long expression: a + a + a + ... (50 times) + long_expr = " + ".join(["a"] * 50) + result = ds.eval(long_expr) + expected = ds["a"] * 50 + assert_equal(result, expected) + + +class TestEvalDask: + """Test Dataset.eval() with dask-backed arrays.""" + + @requires_dask + def test_basic_arithmetic_preserves_dask(self) -> None: + """Test that basic arithmetic with dask arrays returns dask-backed result.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("a + b") + + assert isinstance(result, DataArray) + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = ds["a"] + ds["b"] + assert_equal(result, expected) + + @requires_dask + def test_assignment_preserves_dask(self) -> None: + """Test that assignments with dask arrays preserve lazy evaluation.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("z = a + b") + + assert isinstance(result, Dataset) + assert "z" in result.data_vars + assert is_duck_dask_array(result["z"].data) + + # Verify correctness when computed + expected = ds["a"] + ds["b"] + assert_equal(result["z"], expected) + + @requires_dask + def test_method_chaining_with_compute(self) -> None: + """Test that method chaining works with dask arrays.""" + ds = Dataset({"a": (("x", "y"), np.arange(20.0).reshape(4, 5))}).chunk( + {"x": 2, "y": 5} + ) + + # Calling .mean() should still be lazy + result = ds.eval("a.mean(dim='x')") + # Calling .compute() should return numpy-backed result + computed = result.compute() + + expected = ds["a"].mean(dim="x").compute() + assert_equal(computed, expected) + + @requires_dask + def test_xr_where_preserves_dask(self) -> None: + """Test that xr.where() with dask arrays preserves lazy evaluation.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"a": ("x", np.arange(-5, 5, dtype=float))}).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("xr.where(a > 0, a, 0)") + + assert isinstance(result, DataArray) + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = xr.where(ds["a"] > 0, ds["a"], 0) + assert_equal(result, expected) + + @requires_dask + def test_complex_expression_preserves_dask(self) -> None: + """Test that complex expressions preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + rng = np.random.default_rng(42) + ds = Dataset( + { + "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), + "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), + } + ).chunk({"time": 1, "lat": 2, "lon": 5}) + + with raise_if_dask_computes(): + result = ds.eval("x * 2 + y ** 2") + + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = ds["x"] * 2 + ds["y"] ** 2 + assert_equal(result, expected) + + @requires_dask + def test_mixed_dask_and_numpy(self) -> None: + """Test expressions with mixed dask and numpy arrays.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + { + "dask_var": ("x", np.arange(10.0)), + "numpy_var": ("x", np.linspace(0, 1, 10)), + } + ) + # Only chunk one variable + ds["dask_var"] = ds["dask_var"].chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("dask_var + numpy_var") + + # Result should be dask-backed when any input is dask + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = ds["dask_var"] + ds["numpy_var"] + assert_equal(result, expected) + + @requires_dask + def test_np_functions_preserve_dask(self) -> None: + """Test that numpy functions via np namespace preserve dask.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"a": ("x", np.arange(1.0, 11.0))}).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("np.sqrt(a)") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = np.sqrt(ds["a"]) + assert_equal(result, expected) + + @requires_dask + def test_comparison_preserves_dask(self) -> None: + """Test that comparison operations preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("a > b") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = ds["a"] > ds["b"] + assert_equal(result, expected) + + @requires_dask + def test_boolean_operators_preserve_dask(self) -> None: + """Test that bitwise boolean operators preserve dask.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("(a > 3) & (b < 7)") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = (ds["a"] > 3) & (ds["b"] < 7) + assert_equal(result, expected) + + @requires_dask + def test_chained_comparisons_preserve_dask(self) -> None: + """Test that chained comparisons preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"x": ("dim", np.arange(10.0))}).chunk({"dim": 5}) + + with raise_if_dask_computes(): + result = ds.eval("2 < x < 7") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = (ds["x"] > 2) & (ds["x"] < 7) + assert_equal(result, expected) @pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) From cac8e5a912d4ba56576848a2977acd15fa6f454a Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Thu, 1 Jan 2026 12:41:29 -0800 Subject: [PATCH 2/6] Fix mypy errors in eval tests - Use pd.isna(ds["a"].values) instead of pd.isna(ds["a"]) since pandas type stubs don't have overloads for DataArray - Use abs() instead of np.abs() to get DataArray return type Co-authored-by: Claude --- xarray/tests/test_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ae83fb66be4..9dda44da068 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7783,7 +7783,8 @@ def test_eval_functions() -> None: # pandas namespace should work result = ds.eval("pd.isna(a)") - np.testing.assert_array_equal(result, pd.isna(ds["a"])) + # pd.isna returns ndarray, not DataArray + np.testing.assert_array_equal(result, pd.isna(ds["a"].values)) # xarray namespace should work result = ds.eval("xr.where(a > 1, a, 0)") @@ -7915,7 +7916,7 @@ def test_eval_data_variable_priority() -> None: # np namespace should still provide access to actual functions result = ds.eval("np.abs(other - 10)") - expected = np.abs(ds["other"] - 10) + expected = abs(ds["other"] - 10) assert_equal(result, expected) # np.sum should work even when 'sum' is a data variable From 67f27c2c556583b5214e6f0b8094e460118d2814 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Thu, 1 Jan 2026 13:28:29 -0800 Subject: [PATCH 3/6] Remove security framing, frame restrictions as pd.eval() compatibility The lambda and dunder restrictions emulate pd.eval() behavior rather than providing security guarantees. Pandas explicitly doesn't claim these as security measures. Co-authored-by: Claude --- xarray/core/dataset.py | 37 +++++++----------------------------- xarray/tests/test_dataset.py | 12 ++++++------ 2 files changed, 13 insertions(+), 36 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 35a9e24d6ab..7e38be737b8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9536,7 +9536,7 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: # Base namespace for eval expressions (modules added lazily in _eval_expression # to avoid circular imports for xarray). - # We add common builtins back since we block __builtins__ for security. + # We add common builtins back since we use an empty __builtins__ dict. # Note: builtins.map is used explicitly because 'map' in class scope refers # to the Dataset.map method defined earlier in this class body. _EVAL_NAMESPACE_BUILTINS: dict[str, Any] = { @@ -9651,41 +9651,21 @@ def visit_Compare(self, node: ast.Compare) -> ast.AST: return ast.fix_missing_locations(result) def _validate_eval_expression(self, tree: ast.AST) -> None: - """Validate that an AST doesn't contain unsafe patterns. + """Validate that an AST doesn't contain patterns we don't support. - This provides basic protection against common attack vectors but is NOT - designed to be a robust security boundary. Eval with untrusted user input - should always be treated with caution. - - Security measures: - - Empty __builtins__ dict blocks __import__, open, exec, etc. - - Blocking private/dunder attributes prevents class hierarchy traversal - attacks (e.g., x.__class__.__bases__[0].__subclasses__()) - - Limited namespace: data variables, coordinates, np/pd/xr modules, and - safe builtins: - - Numeric/aggregation: abs, min, max, round, len, sum, pow, any, all - - Type constructors: int, float, bool, str, list, tuple, dict, set, slice - - Iteration helpers: range, zip, enumerate, map, filter - - Known limitations: - - Format strings (e.g., "{0.__class__}".format(x)) can access dunder - attributes at runtime, bypassing AST-level checks. This allows - information disclosure but not direct code execution. - - We welcome contributions to improve the security model. + These restrictions emulate pd.eval() behavior for consistency. """ for node in ast.walk(tree): - # Block lambda expressions to reduce attack surface + # Block lambda expressions (pd.eval: "Only named functions are supported") if isinstance(node, ast.Lambda): raise ValueError( "Lambda expressions are not allowed in eval(). " "Use direct operations on data variables instead." ) - # Block private/dunder attributes to prevent class hierarchy traversal + # Block private/dunder attributes (consistent with pd.eval restrictions) if isinstance(node, ast.Attribute) and node.attr.startswith("_"): raise ValueError( - f"Access to private attributes is not allowed: '{node.attr}'. " - f"For security, attributes starting with '_' are blocked." + f"Access to private attributes is not allowed: '{node.attr}'" ) def _eval_expression(self, expr: str) -> DataArray: @@ -9743,10 +9723,7 @@ def eval( Warning ------- - This method evaluates Python expressions and should not be used with - untrusted input. While basic security measures are in place (empty - ``__builtins__``, blocked private attributes, limited namespace), they - are not designed to be a robust security sandbox. + Like ``pd.eval()``, this method should not be used with untrusted input. Examples -------- diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9dda44da068..ab26ab0f984 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7738,22 +7738,22 @@ def test_eval_chained_comparisons() -> None: assert_identical(expect, actual) -def test_eval_security() -> None: - """Test that eval blocks unsafe operations.""" +def test_eval_restricted_syntax() -> None: + """Test that eval blocks certain syntax to emulate pd.eval() behavior.""" ds = Dataset({"a": ("x", [1, 2, 3])}) - # Dunder/private attribute access should be blocked (sandbox escape vector) + # Private attribute access is not allowed (consistent with pd.eval) with pytest.raises(ValueError, match="Access to private attributes is not allowed"): ds.eval("a.__class__") with pytest.raises(ValueError, match="Access to private attributes is not allowed"): ds.eval("a._private") - # Lambda expressions should be blocked to reduce attack surface + # Lambda expressions are not allowed (pd.eval: "Only named functions are supported") with pytest.raises(ValueError, match="Lambda expressions are not allowed"): ds.eval("(lambda x: x + 1)(a)") - # Dangerous builtins should not be available + # These builtins are not in the namespace with pytest.raises(NameError): ds.eval("__import__('os')") @@ -7792,7 +7792,7 @@ def test_eval_functions() -> None: assert_equal(result, xr.where(ds["a"] > 1, ds["a"], 0)) - # Common builtins should work (we block __builtins__ for security) + # Common builtins should work result = ds.eval("abs(a - 2)") assert_equal(result, abs(ds["a"] - 2)) From 64b2601c29eee4ed813271bb56fca46f28f84307 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 2 Jan 2026 10:50:47 -0800 Subject: [PATCH 4/6] Move eval implementation to dedicated module Extract AST-based expression evaluation code to xarray/core/eval.py: - EVAL_BUILTINS dict - LogicalOperatorTransformer class - validate_expression function This addresses the review feedback to keep the Dataset class focused. Co-authored-by: Claude --- xarray/core/dataset.py | 147 +++-------------------------------------- xarray/core/eval.py | 138 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 138 deletions(-) create mode 100644 xarray/core/eval.py diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7e38be737b8..5c18736a9d5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -53,6 +53,11 @@ from xarray.core.dataset_utils import _get_virtual_variable, _LocIndexer from xarray.core.dataset_variables import DataVariables from xarray.core.duck_array_ops import datetime_to_numeric +from xarray.core.eval import ( + EVAL_BUILTINS, + LogicalOperatorTransformer, + validate_expression, +) from xarray.core.indexes import ( Index, Indexes, @@ -9534,140 +9539,6 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: "Dataset.argmin() with a sequence or ... for dim" ) - # Base namespace for eval expressions (modules added lazily in _eval_expression - # to avoid circular imports for xarray). - # We add common builtins back since we use an empty __builtins__ dict. - # Note: builtins.map is used explicitly because 'map' in class scope refers - # to the Dataset.map method defined earlier in this class body. - _EVAL_NAMESPACE_BUILTINS: dict[str, Any] = { - # Numeric/aggregation functions - "abs": abs, - "min": min, - "max": max, - "round": round, - "len": len, - "sum": sum, - "pow": pow, - "any": any, - "all": all, - # Type constructors - "int": int, - "float": float, - "bool": bool, - "str": str, - "list": list, - "tuple": tuple, - "dict": dict, - "set": set, - "slice": slice, - # Iteration helpers - "range": range, - "zip": zip, - "enumerate": enumerate, - "map": builtins.map, - "filter": filter, - } - - # ------------------------------------------------------------------------- - # eval() Implementation Notes (for future maintainers): - # - # This implementation uses native AST-based evaluation instead of pd.eval() - # to support N-dimensional arrays (N > 2). See GitHub issue #11062. - # - # We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~', - # and chained comparisons) for consistency with query(), which still uses - # pd.eval(). We don't migrate query() to this implementation because: - # - query() typically works fine (expressions usually compare 1D coordinates) - # - pd.eval() with numexpr is faster and well-tested for query's use case - # ------------------------------------------------------------------------- - - class _LogicalOperatorTransformer(ast.NodeTransformer): - """Transform operators for consistency with query(). - - query() uses pd.eval() which transforms these operators automatically. - We replicate that behavior here so syntax that works in query() also - works in eval(). - - Transformations: - 1. 'and'/'or'/'not' -> '&'/'|'/'~' - 2. 'a < b < c' -> '(a < b) & (b < c)' - - These constructs fail on arrays in standard Python because they call - __bool__(), which is ambiguous for multi-element arrays. - """ - - def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST: - # Transform: a and b -> a & b, a or b -> a | b - self.generic_visit(node) - op: ast.BitAnd | ast.BitOr - if isinstance(node.op, ast.And): - op = ast.BitAnd() - elif isinstance(node.op, ast.Or): - op = ast.BitOr() - else: - return node - - # BoolOp can have multiple values: a and b and c - # Transform to chained BinOp: (a & b) & c - result = node.values[0] - for value in node.values[1:]: - result = ast.BinOp(left=result, op=op, right=value) - return ast.fix_missing_locations(result) - - def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: - # Transform: not a -> ~a - self.generic_visit(node) - if isinstance(node.op, ast.Not): - return ast.fix_missing_locations( - ast.UnaryOp(op=ast.Invert(), operand=node.operand) - ) - return node - - def visit_Compare(self, node: ast.Compare) -> ast.AST: - # Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5) - # Python's chained comparisons use short-circuit evaluation at runtime, - # which calls __bool__ on intermediate results. This fails for arrays. - # We transform to bitwise AND which works element-wise. - self.generic_visit(node) - - if len(node.ops) == 1: - # Simple comparison, no transformation needed - return node - - # Build individual comparisons and chain with BitAnd - # For: a < b < c < d - # We need: (a < b) & (b < c) & (c < d) - comparisons = [] - left = node.left - for op, comparator in zip(node.ops, node.comparators, strict=True): - comp = ast.Compare(left=left, ops=[op], comparators=[comparator]) - comparisons.append(comp) - left = comparator - - # Chain with BitAnd: (a < b) & (b < c) & ... - result: ast.Compare | ast.BinOp = comparisons[0] - for comp in comparisons[1:]: - result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp) - return ast.fix_missing_locations(result) - - def _validate_eval_expression(self, tree: ast.AST) -> None: - """Validate that an AST doesn't contain patterns we don't support. - - These restrictions emulate pd.eval() behavior for consistency. - """ - for node in ast.walk(tree): - # Block lambda expressions (pd.eval: "Only named functions are supported") - if isinstance(node, ast.Lambda): - raise ValueError( - "Lambda expressions are not allowed in eval(). " - "Use direct operations on data variables instead." - ) - # Block private/dunder attributes (consistent with pd.eval restrictions) - if isinstance(node, ast.Attribute) and node.attr.startswith("_"): - raise ValueError( - f"Access to private attributes is not allowed: '{node.attr}'" - ) - def _eval_expression(self, expr: str) -> DataArray: """Evaluate an expression string using xarray's native operations.""" try: @@ -9676,11 +9547,11 @@ def _eval_expression(self, expr: str) -> DataArray: raise ValueError(f"Invalid expression syntax: {expr}") from e # Transform logical operators for consistency with query(). - # See _LogicalOperatorTransformer docstring for details. - tree = self._LogicalOperatorTransformer().visit(tree) + # See LogicalOperatorTransformer docstring for details. + tree = LogicalOperatorTransformer().visit(tree) ast.fix_missing_locations(tree) - self._validate_eval_expression(tree) + validate_expression(tree) # Build namespace: data variables, coordinates, modules, and safe builtins. # Empty __builtins__ blocks dangerous functions like __import__, exec, open. @@ -9688,7 +9559,7 @@ def _eval_expression(self, expr: str) -> DataArray: # This ensures user data always wins when names collide with builtins. import xarray as xr # Lazy import to avoid circular dependency - namespace: dict[str, Any] = dict(self._EVAL_NAMESPACE_BUILTINS) + namespace: dict[str, Any] = dict(EVAL_BUILTINS) namespace.update({"np": np, "pd": pd, "xr": xr}) namespace.update({str(name): self.coords[name] for name in self.coords}) namespace.update({str(name): self[name] for name in self.data_vars}) diff --git a/xarray/core/eval.py b/xarray/core/eval.py new file mode 100644 index 00000000000..cd851b30649 --- /dev/null +++ b/xarray/core/eval.py @@ -0,0 +1,138 @@ +""" +Expression evaluation for Dataset.eval(). + +This module provides AST-based expression evaluation to support N-dimensional +arrays (N > 2), which pd.eval() doesn't support. See GitHub issue #11062. + +We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~', +and chained comparisons) for consistency with query(), which still uses +pd.eval(). We don't migrate query() to this implementation because: +- query() typically works fine (expressions usually compare 1D coordinates) +- pd.eval() with numexpr is faster and well-tested for query's use case +""" + +from __future__ import annotations + +import ast +import builtins +from typing import Any + +# Base namespace for eval expressions. +# We add common builtins back since we use an empty __builtins__ dict. +EVAL_BUILTINS: dict[str, Any] = { + # Numeric/aggregation functions + "abs": abs, + "min": min, + "max": max, + "round": round, + "len": len, + "sum": sum, + "pow": pow, + "any": any, + "all": all, + # Type constructors + "int": int, + "float": float, + "bool": bool, + "str": str, + "list": list, + "tuple": tuple, + "dict": dict, + "set": set, + "slice": slice, + # Iteration helpers + "range": range, + "zip": zip, + "enumerate": enumerate, + "map": builtins.map, + "filter": filter, +} + + +class LogicalOperatorTransformer(ast.NodeTransformer): + """Transform operators for consistency with query(). + + query() uses pd.eval() which transforms these operators automatically. + We replicate that behavior here so syntax that works in query() also + works in eval(). + + Transformations: + 1. 'and'/'or'/'not' -> '&'/'|'/'~' + 2. 'a < b < c' -> '(a < b) & (b < c)' + + These constructs fail on arrays in standard Python because they call + __bool__(), which is ambiguous for multi-element arrays. + """ + + def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST: + # Transform: a and b -> a & b, a or b -> a | b + self.generic_visit(node) + op: ast.BitAnd | ast.BitOr + if isinstance(node.op, ast.And): + op = ast.BitAnd() + elif isinstance(node.op, ast.Or): + op = ast.BitOr() + else: + return node + + # BoolOp can have multiple values: a and b and c + # Transform to chained BinOp: (a & b) & c + result = node.values[0] + for value in node.values[1:]: + result = ast.BinOp(left=result, op=op, right=value) + return ast.fix_missing_locations(result) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST: + # Transform: not a -> ~a + self.generic_visit(node) + if isinstance(node.op, ast.Not): + return ast.fix_missing_locations( + ast.UnaryOp(op=ast.Invert(), operand=node.operand) + ) + return node + + def visit_Compare(self, node: ast.Compare) -> ast.AST: + # Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5) + # Python's chained comparisons use short-circuit evaluation at runtime, + # which calls __bool__ on intermediate results. This fails for arrays. + # We transform to bitwise AND which works element-wise. + self.generic_visit(node) + + if len(node.ops) == 1: + # Simple comparison, no transformation needed + return node + + # Build individual comparisons and chain with BitAnd + # For: a < b < c < d + # We need: (a < b) & (b < c) & (c < d) + comparisons = [] + left = node.left + for op, comparator in zip(node.ops, node.comparators, strict=True): + comp = ast.Compare(left=left, ops=[op], comparators=[comparator]) + comparisons.append(comp) + left = comparator + + # Chain with BitAnd: (a < b) & (b < c) & ... + result: ast.Compare | ast.BinOp = comparisons[0] + for comp in comparisons[1:]: + result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp) + return ast.fix_missing_locations(result) + + +def validate_expression(tree: ast.AST) -> None: + """Validate that an AST doesn't contain patterns we don't support. + + These restrictions emulate pd.eval() behavior for consistency. + """ + for node in ast.walk(tree): + # Block lambda expressions (pd.eval: "Only named functions are supported") + if isinstance(node, ast.Lambda): + raise ValueError( + "Lambda expressions are not allowed in eval(). " + "Use direct operations on data variables instead." + ) + # Block private/dunder attributes (consistent with pd.eval restrictions) + if isinstance(node, ast.Attribute) and node.attr.startswith("_"): + raise ValueError( + f"Access to private attributes is not allowed: '{node.attr}'" + ) From 32f0a29c8548b1f03ee7ac43e21aaf35fffeeac2 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Fri, 2 Jan 2026 11:31:20 -0800 Subject: [PATCH 5/6] Move eval tests to dedicated test_eval.py module Extract eval tests from test_dataset.py to test_eval.py: - 35 tests covering basic functionality, error messages, edge cases, and dask - Mirrors the implementation structure (core/eval.py <-> tests/test_eval.py) - Reduces test_dataset.py by 574 lines Co-authored-by: Claude --- xarray/tests/test_dataset.py | 574 ---------------------------------- xarray/tests/test_eval.py | 589 +++++++++++++++++++++++++++++++++++ 2 files changed, 589 insertions(+), 574 deletions(-) create mode 100644 xarray/tests/test_eval.py diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ab26ab0f984..9133471f826 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7636,580 +7636,6 @@ def test_query(self, backend, engine, parser) -> None: # pytest tests — new tests should go here, rather than in the class. -def test_eval(ds) -> None: - """Test basic eval functionality.""" - actual = ds.eval("z1 + 5") - expect = ds["z1"] + 5 - assert_identical(expect, actual) - - # Use bitwise operators for element-wise operations on arrays - actual = ds.eval("(z1 > 5) & (z2 > 0)") - expect = (ds["z1"] > 5) & (ds["z2"] > 0) - assert_identical(expect, actual) - - -def test_eval_parser_deprecated(ds) -> None: - """Test that passing parser= raises a FutureWarning.""" - with pytest.warns(FutureWarning, match="parser.*deprecated"): - ds.eval("z1 + 5", parser="pandas") - - -def test_eval_logical_operators(ds) -> None: - """Test that 'and'/'or'/'not' are transformed for query() consistency. - - These operators are transformed to '&'/'|'/'~' to match pd.eval() behavior, - which query() uses. This ensures syntax that works in query() also works in - eval(). - """ - # 'and' transformed to '&' - actual = ds.eval("(z1 > 5) and (z2 > 0)") - expect = (ds["z1"] > 5) & (ds["z2"] > 0) - assert_identical(expect, actual) - - # 'or' transformed to '|' - actual = ds.eval("(z1 > 5) or (z2 > 0)") - expect = (ds["z1"] > 5) | (ds["z2"] > 0) - assert_identical(expect, actual) - - # 'not' transformed to '~' - actual = ds.eval("not (z1 > 5)") - expect = ~(ds["z1"] > 5) - assert_identical(expect, actual) - - -def test_eval_ndimensional() -> None: - """Test that eval works with N-dimensional data where N > 2.""" - # Create a 3D dataset - this previously failed with pd.eval - rng = np.random.default_rng(42) - ds = Dataset( - { - "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), - "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), - } - ) - - # Basic arithmetic - actual = ds.eval("x + y") - expect = ds["x"] + ds["y"] - assert_identical(expect, actual) - - # Assignment - actual = ds.eval("z = x + y") - assert "z" in actual.data_vars - assert_equal(ds["x"] + ds["y"], actual["z"]) - - # Complex expression - actual = ds.eval("x * 2 + y ** 2") - expect = ds["x"] * 2 + ds["y"] ** 2 - assert_identical(expect, actual) - - # Comparison - actual = ds.eval("x > y") - expect = ds["x"] > ds["y"] - assert_identical(expect, actual) - - # Use bitwise operators for element-wise boolean operations - actual = ds.eval("(x > 0.5) & (y < 0.5)") - expect = (ds["x"] > 0.5) & (ds["y"] < 0.5) - assert_identical(expect, actual) - - -def test_eval_chained_comparisons() -> None: - """Test that chained comparisons are transformed for query() consistency. - - Chained comparisons like 'a < b < c' are transformed to '(a < b) & (b < c)' - to match pd.eval() behavior, which query() uses. - """ - ds = Dataset({"x": ("dim", np.arange(10))}) - - # Basic chained comparison: 2 < x < 7 - actual = ds.eval("2 < x < 7") - expect = (ds["x"] > 2) & (ds["x"] < 7) - assert_identical(expect, actual) - - # Mixed operators: 0 <= x < 5 - actual = ds.eval("0 <= x < 5") - expect = (ds["x"] >= 0) & (ds["x"] < 5) - assert_identical(expect, actual) - - # Explicit bitwise operators also work - actual = ds.eval("(x > 2) & (x < 7)") - expect = (ds["x"] > 2) & (ds["x"] < 7) - assert_identical(expect, actual) - - -def test_eval_restricted_syntax() -> None: - """Test that eval blocks certain syntax to emulate pd.eval() behavior.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - - # Private attribute access is not allowed (consistent with pd.eval) - with pytest.raises(ValueError, match="Access to private attributes is not allowed"): - ds.eval("a.__class__") - - with pytest.raises(ValueError, match="Access to private attributes is not allowed"): - ds.eval("a._private") - - # Lambda expressions are not allowed (pd.eval: "Only named functions are supported") - with pytest.raises(ValueError, match="Lambda expressions are not allowed"): - ds.eval("(lambda x: x + 1)(a)") - - # These builtins are not in the namespace - with pytest.raises(NameError): - ds.eval("__import__('os')") - - with pytest.raises(NameError): - ds.eval("open('file.txt')") - - -def test_eval_unsupported_statements() -> None: - """Test that unsupported statement types produce clear errors.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - - # Augmented assignment is not supported - with pytest.raises(ValueError, match="Unsupported statement type"): - ds.eval("a += 1") - - -def test_eval_functions() -> None: - """Test that numpy and other functions work in eval.""" - ds = Dataset({"a": ("x", [0.0, 1.0, 4.0])}) - - # numpy functions via np namespace should work - result = ds.eval("np.sqrt(a)") - assert_equal(result, np.sqrt(ds["a"])) - - result = ds.eval("np.sin(a) + np.cos(a)") - assert_equal(result, np.sin(ds["a"]) + np.cos(ds["a"])) - - # pandas namespace should work - result = ds.eval("pd.isna(a)") - # pd.isna returns ndarray, not DataArray - np.testing.assert_array_equal(result, pd.isna(ds["a"].values)) - - # xarray namespace should work - result = ds.eval("xr.where(a > 1, a, 0)") - import xarray as xr - - assert_equal(result, xr.where(ds["a"] > 1, ds["a"], 0)) - - # Common builtins should work - result = ds.eval("abs(a - 2)") - assert_equal(result, abs(ds["a"] - 2)) - - result = ds.eval("round(float(a.mean()))") - assert result == round(float(ds["a"].mean())) - - result = ds.eval("len(a)") - assert result == 3 - - result = ds.eval("pow(a, 2)") - assert_equal(result, ds["a"] ** 2) - - # Attribute access on DataArrays should work - result = ds.eval("a.values") - assert isinstance(result, np.ndarray) - - # Method calls on DataArrays should work - result = ds.eval("a.mean()") - assert float(result) == np.mean([0.0, 1.0, 4.0]) - - -def test_eval_extended_builtins() -> None: - """Test extended builtins available in eval namespace. - - These builtins are safe (no I/O, no code execution) and commonly needed - for typical xarray operations like slicing, type conversion, and iteration. - """ - ds = Dataset( - {"a": ("x", [1.0, 2.0, 3.0, 4.0, 5.0])}, - coords={"time": pd.date_range("2019-01-01", periods=5)}, - ) - - # slice - essential for .sel() with ranges - result = ds.eval("a.sel(x=slice(1, 3))") - expected = ds["a"].sel(x=slice(1, 3)) - assert_equal(result, expected) - - # str - type constructor - result = ds.eval("str(int(a.mean()))") - assert result == "3" - - # list, tuple - type constructors - result = ds.eval("list(range(3))") - assert result == [0, 1, 2] - - result = ds.eval("tuple(range(3))") - assert result == (0, 1, 2) - - # dict, set - type constructors - result = ds.eval("dict(x=1, y=2)") - assert result == {"x": 1, "y": 2} - - result = ds.eval("set([1, 2, 2, 3])") - assert result == {1, 2, 3} - - # range - iteration - result = ds.eval("list(range(3))") - assert result == [0, 1, 2] - - # zip, enumerate - iteration helpers - result = ds.eval("list(zip([1, 2], [3, 4]))") - assert result == [(1, 3), (2, 4)] - - result = ds.eval("list(enumerate(['a', 'b']))") - assert result == [(0, "a"), (1, "b")] - - # map, filter - functional helpers - result = ds.eval("list(map(abs, [-1, -2, 3]))") - assert result == [1, 2, 3] - - result = ds.eval("list(filter(bool, [0, 1, 0, 2]))") - assert result == [1, 2] - - # any, all - aggregation - result = ds.eval("any([False, True, False])") - assert result is True - - result = ds.eval("all([True, True, True])") - assert result is True - - result = ds.eval("all([True, False, True])") - assert result is False - - -def test_eval_data_variable_priority() -> None: - """Test that data variables take priority over builtin functions. - - Users may have data variables named 'sum', 'abs', 'min', etc. When they - reference these in eval(), they should get their data, not the Python builtins. - The builtins should still be accessible via the np namespace (np.sum, np.abs). - """ - # Create dataset with data variables that shadow builtins - ds = Dataset( - { - "sum": ("x", [10.0, 20.0, 30.0]), # shadows builtin sum - "abs": ("x", [1.0, 2.0, 3.0]), # shadows builtin abs - "min": ("x", [100.0, 200.0, 300.0]), # shadows builtin min - "other": ("x", [5.0, 10.0, 15.0]), - } - ) - - # Data variables should take priority - user data wins - result = ds.eval("sum + other") - expected = ds["sum"] + ds["other"] - assert_equal(result, expected) - - # Should get the data variable, not builtin sum applied to something - result = ds.eval("sum * 2") - expected = ds["sum"] * 2 - assert_equal(result, expected) - - # abs as data variable should work - result = ds.eval("abs + 1") - expected = ds["abs"] + 1 - assert_equal(result, expected) - - # min as data variable should work - result = ds.eval("min - 50") - expected = ds["min"] - 50 - assert_equal(result, expected) - - # np namespace should still provide access to actual functions - result = ds.eval("np.abs(other - 10)") - expected = abs(ds["other"] - 10) - assert_equal(result, expected) - - # np.sum should work even when 'sum' is a data variable - result = ds.eval("np.sum(other)") - expected = np.sum(ds["other"]) - assert result == expected - - -def test_eval_coordinate_priority() -> None: - """Test that coordinates also take priority over builtins.""" - ds = Dataset( - {"data": ("x", [1.0, 2.0, 3.0])}, - coords={"sum": ("x", [10.0, 20.0, 30.0])}, # coordinate named 'sum' - ) - - # Coordinate should be accessible and take priority over builtin - result = ds.eval("data + sum") - expected = ds["data"] + ds.coords["sum"] - assert_equal(result, expected) - - -class TestEvalErrorMessages: - """Test that eval() produces clear error messages for common mistakes.""" - - def test_undefined_variable(self) -> None: - """Test error message when referencing an undefined variable.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(NameError, match="undefined_var"): - ds.eval("undefined_var + a") - - def test_syntax_error(self) -> None: - """Test error message for malformed expressions.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError, match="Invalid"): - ds.eval("a +") - - def test_invalid_assignment(self) -> None: - """Test error message when assignment target is invalid.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - # "1 = a" should fail during parsing - can't assign to a literal - with pytest.raises(ValueError, match="Invalid"): - ds.eval("1 = a") - - def test_dunder_access(self) -> None: - """Test error message when trying to access dunder attributes.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError, match="private attributes"): - ds.eval("a.__class__") - - def test_missing_method(self) -> None: - """Test error message when calling a nonexistent method.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - # This should raise AttributeError from the DataArray - with pytest.raises(AttributeError, match="nonexistent_method"): - ds.eval("a.nonexistent_method()") - - def test_type_error_in_expression(self) -> None: - """Test error message when types are incompatible.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - # Adding string to numeric array should raise TypeError or similar - with pytest.raises((TypeError, np.exceptions.DTypePromotionError)): - ds.eval("a + 'string'") - - -class TestEvalEdgeCases: - """Test edge cases for eval().""" - - def test_empty_expression(self) -> None: - """Test handling of empty expression string.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError): - ds.eval("") - - def test_whitespace_only_expression(self) -> None: - """Test handling of whitespace-only expression.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError): - ds.eval(" ") - - def test_just_variable_name(self) -> None: - """Test that just a variable name returns the variable.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - result = ds.eval("a") - expected = ds["a"] - assert_equal(result, expected) - - def test_unicode_variable_names(self) -> None: - """Test that unicode variable names work in expressions.""" - # Greek letters are valid Python identifiers - ds = Dataset({"α": ("x", [1.0, 2.0, 3.0]), "β": ("x", [4.0, 5.0, 6.0])}) - result = ds.eval("α + β") - expected = ds["α"] + ds["β"] - assert_equal(result, expected) - - def test_long_expression(self) -> None: - """Test that very long expressions work correctly.""" - ds = Dataset({"a": ("x", [1.0, 2.0, 3.0])}) - # Build a long expression: a + a + a + ... (50 times) - long_expr = " + ".join(["a"] * 50) - result = ds.eval(long_expr) - expected = ds["a"] * 50 - assert_equal(result, expected) - - -class TestEvalDask: - """Test Dataset.eval() with dask-backed arrays.""" - - @requires_dask - def test_basic_arithmetic_preserves_dask(self) -> None: - """Test that basic arithmetic with dask arrays returns dask-backed result.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} - ).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("a + b") - - assert isinstance(result, DataArray) - assert is_duck_dask_array(result.data) - - # Verify correctness when computed - expected = ds["a"] + ds["b"] - assert_equal(result, expected) - - @requires_dask - def test_assignment_preserves_dask(self) -> None: - """Test that assignments with dask arrays preserve lazy evaluation.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} - ).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("z = a + b") - - assert isinstance(result, Dataset) - assert "z" in result.data_vars - assert is_duck_dask_array(result["z"].data) - - # Verify correctness when computed - expected = ds["a"] + ds["b"] - assert_equal(result["z"], expected) - - @requires_dask - def test_method_chaining_with_compute(self) -> None: - """Test that method chaining works with dask arrays.""" - ds = Dataset({"a": (("x", "y"), np.arange(20.0).reshape(4, 5))}).chunk( - {"x": 2, "y": 5} - ) - - # Calling .mean() should still be lazy - result = ds.eval("a.mean(dim='x')") - # Calling .compute() should return numpy-backed result - computed = result.compute() - - expected = ds["a"].mean(dim="x").compute() - assert_equal(computed, expected) - - @requires_dask - def test_xr_where_preserves_dask(self) -> None: - """Test that xr.where() with dask arrays preserves lazy evaluation.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset({"a": ("x", np.arange(-5, 5, dtype=float))}).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("xr.where(a > 0, a, 0)") - - assert isinstance(result, DataArray) - assert is_duck_dask_array(result.data) - - # Verify correctness when computed - expected = xr.where(ds["a"] > 0, ds["a"], 0) - assert_equal(result, expected) - - @requires_dask - def test_complex_expression_preserves_dask(self) -> None: - """Test that complex expressions preserve dask backing.""" - from xarray.core.utils import is_duck_dask_array - - rng = np.random.default_rng(42) - ds = Dataset( - { - "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), - "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), - } - ).chunk({"time": 1, "lat": 2, "lon": 5}) - - with raise_if_dask_computes(): - result = ds.eval("x * 2 + y ** 2") - - assert is_duck_dask_array(result.data) - - # Verify correctness when computed - expected = ds["x"] * 2 + ds["y"] ** 2 - assert_equal(result, expected) - - @requires_dask - def test_mixed_dask_and_numpy(self) -> None: - """Test expressions with mixed dask and numpy arrays.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - { - "dask_var": ("x", np.arange(10.0)), - "numpy_var": ("x", np.linspace(0, 1, 10)), - } - ) - # Only chunk one variable - ds["dask_var"] = ds["dask_var"].chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("dask_var + numpy_var") - - # Result should be dask-backed when any input is dask - assert is_duck_dask_array(result.data) - - # Verify correctness - expected = ds["dask_var"] + ds["numpy_var"] - assert_equal(result, expected) - - @requires_dask - def test_np_functions_preserve_dask(self) -> None: - """Test that numpy functions via np namespace preserve dask.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset({"a": ("x", np.arange(1.0, 11.0))}).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("np.sqrt(a)") - - assert is_duck_dask_array(result.data) - - # Verify correctness - expected = np.sqrt(ds["a"]) - assert_equal(result, expected) - - @requires_dask - def test_comparison_preserves_dask(self) -> None: - """Test that comparison operations preserve dask backing.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} - ).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("a > b") - - assert is_duck_dask_array(result.data) - - # Verify correctness - expected = ds["a"] > ds["b"] - assert_equal(result, expected) - - @requires_dask - def test_boolean_operators_preserve_dask(self) -> None: - """Test that bitwise boolean operators preserve dask.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} - ).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("(a > 3) & (b < 7)") - - assert is_duck_dask_array(result.data) - - # Verify correctness - expected = (ds["a"] > 3) & (ds["b"] < 7) - assert_equal(result, expected) - - @requires_dask - def test_chained_comparisons_preserve_dask(self) -> None: - """Test that chained comparisons preserve dask backing.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset({"x": ("dim", np.arange(10.0))}).chunk({"dim": 5}) - - with raise_if_dask_computes(): - result = ds.eval("2 < x < 7") - - assert is_duck_dask_array(result.data) - - # Verify correctness - expected = (ds["x"] > 2) & (ds["x"] < 7) - assert_equal(result, expected) - - @pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2]))) def test_isin(test_elements, backend) -> None: expected = Dataset( diff --git a/xarray/tests/test_eval.py b/xarray/tests/test_eval.py new file mode 100644 index 00000000000..7a0bac6000b --- /dev/null +++ b/xarray/tests/test_eval.py @@ -0,0 +1,589 @@ +"""Tests for Dataset.eval() functionality.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +import xarray as xr +from xarray import DataArray, Dataset +from xarray.tests import ( + assert_equal, + assert_identical, + raise_if_dask_computes, + requires_dask, +) + + +def test_eval(ds) -> None: + """Test basic eval functionality.""" + actual = ds.eval("z1 + 5") + expect = ds["z1"] + 5 + assert_identical(expect, actual) + + # Use bitwise operators for element-wise operations on arrays + actual = ds.eval("(z1 > 5) & (z2 > 0)") + expect = (ds["z1"] > 5) & (ds["z2"] > 0) + assert_identical(expect, actual) + + +def test_eval_parser_deprecated(ds) -> None: + """Test that passing parser= raises a FutureWarning.""" + with pytest.warns(FutureWarning, match="parser.*deprecated"): + ds.eval("z1 + 5", parser="pandas") + + +def test_eval_logical_operators(ds) -> None: + """Test that 'and'/'or'/'not' are transformed for query() consistency. + + These operators are transformed to '&'/'|'/'~' to match pd.eval() behavior, + which query() uses. This ensures syntax that works in query() also works in + eval(). + """ + # 'and' transformed to '&' + actual = ds.eval("(z1 > 5) and (z2 > 0)") + expect = (ds["z1"] > 5) & (ds["z2"] > 0) + assert_identical(expect, actual) + + # 'or' transformed to '|' + actual = ds.eval("(z1 > 5) or (z2 > 0)") + expect = (ds["z1"] > 5) | (ds["z2"] > 0) + assert_identical(expect, actual) + + # 'not' transformed to '~' + actual = ds.eval("not (z1 > 5)") + expect = ~(ds["z1"] > 5) + assert_identical(expect, actual) + + +def test_eval_ndimensional() -> None: + """Test that eval works with N-dimensional data where N > 2.""" + # Create a 3D dataset - this previously failed with pd.eval + rng = np.random.default_rng(42) + ds = Dataset( + { + "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), + "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), + } + ) + + # Basic arithmetic + actual = ds.eval("x + y") + expect = ds["x"] + ds["y"] + assert_identical(expect, actual) + + # Assignment + actual = ds.eval("z = x + y") + assert "z" in actual.data_vars + assert_equal(ds["x"] + ds["y"], actual["z"]) + + # Complex expression + actual = ds.eval("x * 2 + y ** 2") + expect = ds["x"] * 2 + ds["y"] ** 2 + assert_identical(expect, actual) + + # Comparison + actual = ds.eval("x > y") + expect = ds["x"] > ds["y"] + assert_identical(expect, actual) + + # Use bitwise operators for element-wise boolean operations + actual = ds.eval("(x > 0.5) & (y < 0.5)") + expect = (ds["x"] > 0.5) & (ds["y"] < 0.5) + assert_identical(expect, actual) + + +def test_eval_chained_comparisons() -> None: + """Test that chained comparisons are transformed for query() consistency. + + Chained comparisons like 'a < b < c' are transformed to '(a < b) & (b < c)' + to match pd.eval() behavior, which query() uses. + """ + ds = Dataset({"x": ("dim", np.arange(10))}) + + # Basic chained comparison: 2 < x < 7 + actual = ds.eval("2 < x < 7") + expect = (ds["x"] > 2) & (ds["x"] < 7) + assert_identical(expect, actual) + + # Mixed operators: 0 <= x < 5 + actual = ds.eval("0 <= x < 5") + expect = (ds["x"] >= 0) & (ds["x"] < 5) + assert_identical(expect, actual) + + # Explicit bitwise operators also work + actual = ds.eval("(x > 2) & (x < 7)") + expect = (ds["x"] > 2) & (ds["x"] < 7) + assert_identical(expect, actual) + + +def test_eval_restricted_syntax() -> None: + """Test that eval blocks certain syntax to emulate pd.eval() behavior.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + + # Private attribute access is not allowed (consistent with pd.eval) + with pytest.raises(ValueError, match="Access to private attributes is not allowed"): + ds.eval("a.__class__") + + with pytest.raises(ValueError, match="Access to private attributes is not allowed"): + ds.eval("a._private") + + # Lambda expressions are not allowed (pd.eval: "Only named functions are supported") + with pytest.raises(ValueError, match="Lambda expressions are not allowed"): + ds.eval("(lambda x: x + 1)(a)") + + # These builtins are not in the namespace + with pytest.raises(NameError): + ds.eval("__import__('os')") + + with pytest.raises(NameError): + ds.eval("open('file.txt')") + + +def test_eval_unsupported_statements() -> None: + """Test that unsupported statement types produce clear errors.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + + # Augmented assignment is not supported + with pytest.raises(ValueError, match="Unsupported statement type"): + ds.eval("a += 1") + + +def test_eval_functions() -> None: + """Test that numpy and other functions work in eval.""" + ds = Dataset({"a": ("x", [0.0, 1.0, 4.0])}) + + # numpy functions via np namespace should work + result = ds.eval("np.sqrt(a)") + assert_equal(result, np.sqrt(ds["a"])) + + result = ds.eval("np.sin(a) + np.cos(a)") + assert_equal(result, np.sin(ds["a"]) + np.cos(ds["a"])) + + # pandas namespace should work + result = ds.eval("pd.isna(a)") + # pd.isna returns ndarray, not DataArray + np.testing.assert_array_equal(result, pd.isna(ds["a"].values)) + + # xarray namespace should work + result = ds.eval("xr.where(a > 1, a, 0)") + + assert_equal(result, xr.where(ds["a"] > 1, ds["a"], 0)) + + # Common builtins should work + result = ds.eval("abs(a - 2)") + assert_equal(result, abs(ds["a"] - 2)) + + result = ds.eval("round(float(a.mean()))") + assert result == round(float(ds["a"].mean())) + + result = ds.eval("len(a)") + assert result == 3 + + result = ds.eval("pow(a, 2)") + assert_equal(result, ds["a"] ** 2) + + # Attribute access on DataArrays should work + result = ds.eval("a.values") + assert isinstance(result, np.ndarray) + + # Method calls on DataArrays should work + result = ds.eval("a.mean()") + assert float(result) == np.mean([0.0, 1.0, 4.0]) + + +def test_eval_extended_builtins() -> None: + """Test extended builtins available in eval namespace. + + These builtins are safe (no I/O, no code execution) and commonly needed + for typical xarray operations like slicing, type conversion, and iteration. + """ + ds = Dataset( + {"a": ("x", [1.0, 2.0, 3.0, 4.0, 5.0])}, + coords={"time": pd.date_range("2019-01-01", periods=5)}, + ) + + # slice - essential for .sel() with ranges + result = ds.eval("a.sel(x=slice(1, 3))") + expected = ds["a"].sel(x=slice(1, 3)) + assert_equal(result, expected) + + # str - type constructor + result = ds.eval("str(int(a.mean()))") + assert result == "3" + + # list, tuple - type constructors + result = ds.eval("list(range(3))") + assert result == [0, 1, 2] + + result = ds.eval("tuple(range(3))") + assert result == (0, 1, 2) + + # dict, set - type constructors + result = ds.eval("dict(x=1, y=2)") + assert result == {"x": 1, "y": 2} + + result = ds.eval("set([1, 2, 2, 3])") + assert result == {1, 2, 3} + + # range - iteration + result = ds.eval("list(range(3))") + assert result == [0, 1, 2] + + # zip, enumerate - iteration helpers + result = ds.eval("list(zip([1, 2], [3, 4]))") + assert result == [(1, 3), (2, 4)] + + result = ds.eval("list(enumerate(['a', 'b']))") + assert result == [(0, "a"), (1, "b")] + + # map, filter - functional helpers + result = ds.eval("list(map(abs, [-1, -2, 3]))") + assert result == [1, 2, 3] + + result = ds.eval("list(filter(bool, [0, 1, 0, 2]))") + assert result == [1, 2] + + # any, all - aggregation + result = ds.eval("any([False, True, False])") + assert result is True + + result = ds.eval("all([True, True, True])") + assert result is True + + result = ds.eval("all([True, False, True])") + assert result is False + + +def test_eval_data_variable_priority() -> None: + """Test that data variables take priority over builtin functions. + + Users may have data variables named 'sum', 'abs', 'min', etc. When they + reference these in eval(), they should get their data, not the Python builtins. + The builtins should still be accessible via the np namespace (np.sum, np.abs). + """ + # Create dataset with data variables that shadow builtins + ds = Dataset( + { + "sum": ("x", [10.0, 20.0, 30.0]), # shadows builtin sum + "abs": ("x", [1.0, 2.0, 3.0]), # shadows builtin abs + "min": ("x", [100.0, 200.0, 300.0]), # shadows builtin min + "other": ("x", [5.0, 10.0, 15.0]), + } + ) + + # Data variables should take priority - user data wins + result = ds.eval("sum + other") + expected = ds["sum"] + ds["other"] + assert_equal(result, expected) + + # Should get the data variable, not builtin sum applied to something + result = ds.eval("sum * 2") + expected = ds["sum"] * 2 + assert_equal(result, expected) + + # abs as data variable should work + result = ds.eval("abs + 1") + expected = ds["abs"] + 1 + assert_equal(result, expected) + + # min as data variable should work + result = ds.eval("min - 50") + expected = ds["min"] - 50 + assert_equal(result, expected) + + # np namespace should still provide access to actual functions + result = ds.eval("np.abs(other - 10)") + expected = abs(ds["other"] - 10) + assert_equal(result, expected) + + # np.sum should work even when 'sum' is a data variable + result = ds.eval("np.sum(other)") + expected = np.sum(ds["other"]) + assert result == expected + + +def test_eval_coordinate_priority() -> None: + """Test that coordinates also take priority over builtins.""" + ds = Dataset( + {"data": ("x", [1.0, 2.0, 3.0])}, + coords={"sum": ("x", [10.0, 20.0, 30.0])}, # coordinate named 'sum' + ) + + # Coordinate should be accessible and take priority over builtin + result = ds.eval("data + sum") + expected = ds["data"] + ds.coords["sum"] + assert_equal(result, expected) + + +class TestEvalErrorMessages: + """Test that eval() produces clear error messages for common mistakes.""" + + def test_undefined_variable(self) -> None: + """Test error message when referencing an undefined variable.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(NameError, match="undefined_var"): + ds.eval("undefined_var + a") + + def test_syntax_error(self) -> None: + """Test error message for malformed expressions.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match="Invalid"): + ds.eval("a +") + + def test_invalid_assignment(self) -> None: + """Test error message when assignment target is invalid.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # "1 = a" should fail during parsing - can't assign to a literal + with pytest.raises(ValueError, match="Invalid"): + ds.eval("1 = a") + + def test_dunder_access(self) -> None: + """Test error message when trying to access dunder attributes.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match="private attributes"): + ds.eval("a.__class__") + + def test_missing_method(self) -> None: + """Test error message when calling a nonexistent method.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # This should raise AttributeError from the DataArray + with pytest.raises(AttributeError, match="nonexistent_method"): + ds.eval("a.nonexistent_method()") + + def test_type_error_in_expression(self) -> None: + """Test error message when types are incompatible.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # Adding string to numeric array should raise TypeError or similar + with pytest.raises((TypeError, np.exceptions.DTypePromotionError)): + ds.eval("a + 'string'") + + +class TestEvalEdgeCases: + """Test edge cases for eval().""" + + def test_empty_expression(self) -> None: + """Test handling of empty expression string.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError): + ds.eval("") + + def test_whitespace_only_expression(self) -> None: + """Test handling of whitespace-only expression.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError): + ds.eval(" ") + + def test_just_variable_name(self) -> None: + """Test that just a variable name returns the variable.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + result = ds.eval("a") + expected = ds["a"] + assert_equal(result, expected) + + def test_unicode_variable_names(self) -> None: + """Test that unicode variable names work in expressions.""" + # Greek letters are valid Python identifiers + ds = Dataset({"α": ("x", [1.0, 2.0, 3.0]), "β": ("x", [4.0, 5.0, 6.0])}) + result = ds.eval("α + β") + expected = ds["α"] + ds["β"] + assert_equal(result, expected) + + def test_long_expression(self) -> None: + """Test that very long expressions work correctly.""" + ds = Dataset({"a": ("x", [1.0, 2.0, 3.0])}) + # Build a long expression: a + a + a + ... (50 times) + long_expr = " + ".join(["a"] * 50) + result = ds.eval(long_expr) + expected = ds["a"] * 50 + assert_equal(result, expected) + + +class TestEvalDask: + """Test Dataset.eval() with dask-backed arrays.""" + + @requires_dask + def test_basic_arithmetic_preserves_dask(self) -> None: + """Test that basic arithmetic with dask arrays returns dask-backed result.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("a + b") + + assert isinstance(result, DataArray) + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = ds["a"] + ds["b"] + assert_equal(result, expected) + + @requires_dask + def test_assignment_preserves_dask(self) -> None: + """Test that assignments with dask arrays preserve lazy evaluation.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("z = a + b") + + assert isinstance(result, Dataset) + assert "z" in result.data_vars + assert is_duck_dask_array(result["z"].data) + + # Verify correctness when computed + expected = ds["a"] + ds["b"] + assert_equal(result["z"], expected) + + @requires_dask + def test_method_chaining_with_compute(self) -> None: + """Test that method chaining works with dask arrays.""" + ds = Dataset({"a": (("x", "y"), np.arange(20.0).reshape(4, 5))}).chunk( + {"x": 2, "y": 5} + ) + + # Calling .mean() should still be lazy + result = ds.eval("a.mean(dim='x')") + # Calling .compute() should return numpy-backed result + computed = result.compute() + + expected = ds["a"].mean(dim="x").compute() + assert_equal(computed, expected) + + @requires_dask + def test_xr_where_preserves_dask(self) -> None: + """Test that xr.where() with dask arrays preserves lazy evaluation.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"a": ("x", np.arange(-5, 5, dtype=float))}).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("xr.where(a > 0, a, 0)") + + assert isinstance(result, DataArray) + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = xr.where(ds["a"] > 0, ds["a"], 0) + assert_equal(result, expected) + + @requires_dask + def test_complex_expression_preserves_dask(self) -> None: + """Test that complex expressions preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + rng = np.random.default_rng(42) + ds = Dataset( + { + "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), + "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), + } + ).chunk({"time": 1, "lat": 2, "lon": 5}) + + with raise_if_dask_computes(): + result = ds.eval("x * 2 + y ** 2") + + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = ds["x"] * 2 + ds["y"] ** 2 + assert_equal(result, expected) + + @requires_dask + def test_mixed_dask_and_numpy(self) -> None: + """Test expressions with mixed dask and numpy arrays.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + { + "dask_var": ("x", np.arange(10.0)), + "numpy_var": ("x", np.linspace(0, 1, 10)), + } + ) + # Only chunk one variable + ds["dask_var"] = ds["dask_var"].chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("dask_var + numpy_var") + + # Result should be dask-backed when any input is dask + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = ds["dask_var"] + ds["numpy_var"] + assert_equal(result, expected) + + @requires_dask + def test_np_functions_preserve_dask(self) -> None: + """Test that numpy functions via np namespace preserve dask.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"a": ("x", np.arange(1.0, 11.0))}).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("np.sqrt(a)") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = np.sqrt(ds["a"]) + assert_equal(result, expected) + + @requires_dask + def test_comparison_preserves_dask(self) -> None: + """Test that comparison operations preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("a > b") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = ds["a"] > ds["b"] + assert_equal(result, expected) + + @requires_dask + def test_boolean_operators_preserve_dask(self) -> None: + """Test that bitwise boolean operators preserve dask.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("(a > 3) & (b < 7)") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = (ds["a"] > 3) & (ds["b"] < 7) + assert_equal(result, expected) + + @requires_dask + def test_chained_comparisons_preserve_dask(self) -> None: + """Test that chained comparisons preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"x": ("dim", np.arange(10.0))}).chunk({"dim": 5}) + + with raise_if_dask_computes(): + result = ds.eval("2 < x < 7") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = (ds["x"] > 2) & (ds["x"] < 7) + assert_equal(result, expected) From 46bf9ef46fa0a9d12429dc039214c5a543a99875 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Sun, 4 Jan 2026 19:20:36 -0800 Subject: [PATCH 6/6] Refactor eval tests: convert classes to standalone functions Address review feedback: - Convert TestEvalErrorMessages class to test_eval_error_* functions - Convert TestEvalEdgeCases class to test_eval_* functions - Convert TestEvalDask class to test_eval_dask_* functions This follows xarray's preference for standalone test functions over classes. Co-authored-by: Claude --- xarray/tests/test_eval.py | 500 ++++++++++++++++++++------------------ 1 file changed, 259 insertions(+), 241 deletions(-) diff --git a/xarray/tests/test_eval.py b/xarray/tests/test_eval.py index 7a0bac6000b..cfb046362a3 100644 --- a/xarray/tests/test_eval.py +++ b/xarray/tests/test_eval.py @@ -317,273 +317,291 @@ def test_eval_coordinate_priority() -> None: assert_equal(result, expected) -class TestEvalErrorMessages: - """Test that eval() produces clear error messages for common mistakes.""" - - def test_undefined_variable(self) -> None: - """Test error message when referencing an undefined variable.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(NameError, match="undefined_var"): - ds.eval("undefined_var + a") - - def test_syntax_error(self) -> None: - """Test error message for malformed expressions.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError, match="Invalid"): - ds.eval("a +") - - def test_invalid_assignment(self) -> None: - """Test error message when assignment target is invalid.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - # "1 = a" should fail during parsing - can't assign to a literal - with pytest.raises(ValueError, match="Invalid"): - ds.eval("1 = a") - - def test_dunder_access(self) -> None: - """Test error message when trying to access dunder attributes.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError, match="private attributes"): - ds.eval("a.__class__") - - def test_missing_method(self) -> None: - """Test error message when calling a nonexistent method.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - # This should raise AttributeError from the DataArray - with pytest.raises(AttributeError, match="nonexistent_method"): - ds.eval("a.nonexistent_method()") - - def test_type_error_in_expression(self) -> None: - """Test error message when types are incompatible.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - # Adding string to numeric array should raise TypeError or similar - with pytest.raises((TypeError, np.exceptions.DTypePromotionError)): - ds.eval("a + 'string'") - - -class TestEvalEdgeCases: - """Test edge cases for eval().""" - - def test_empty_expression(self) -> None: - """Test handling of empty expression string.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError): - ds.eval("") - - def test_whitespace_only_expression(self) -> None: - """Test handling of whitespace-only expression.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - with pytest.raises(ValueError): - ds.eval(" ") - - def test_just_variable_name(self) -> None: - """Test that just a variable name returns the variable.""" - ds = Dataset({"a": ("x", [1, 2, 3])}) - result = ds.eval("a") - expected = ds["a"] - assert_equal(result, expected) - - def test_unicode_variable_names(self) -> None: - """Test that unicode variable names work in expressions.""" - # Greek letters are valid Python identifiers - ds = Dataset({"α": ("x", [1.0, 2.0, 3.0]), "β": ("x", [4.0, 5.0, 6.0])}) - result = ds.eval("α + β") - expected = ds["α"] + ds["β"] - assert_equal(result, expected) - - def test_long_expression(self) -> None: - """Test that very long expressions work correctly.""" - ds = Dataset({"a": ("x", [1.0, 2.0, 3.0])}) - # Build a long expression: a + a + a + ... (50 times) - long_expr = " + ".join(["a"] * 50) - result = ds.eval(long_expr) - expected = ds["a"] * 50 - assert_equal(result, expected) - - -class TestEvalDask: - """Test Dataset.eval() with dask-backed arrays.""" - - @requires_dask - def test_basic_arithmetic_preserves_dask(self) -> None: - """Test that basic arithmetic with dask arrays returns dask-backed result.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} - ).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("a + b") - - assert isinstance(result, DataArray) - assert is_duck_dask_array(result.data) - - # Verify correctness when computed - expected = ds["a"] + ds["b"] - assert_equal(result, expected) - - @requires_dask - def test_assignment_preserves_dask(self) -> None: - """Test that assignments with dask arrays preserve lazy evaluation.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} - ).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("z = a + b") - - assert isinstance(result, Dataset) - assert "z" in result.data_vars - assert is_duck_dask_array(result["z"].data) - - # Verify correctness when computed - expected = ds["a"] + ds["b"] - assert_equal(result["z"], expected) - - @requires_dask - def test_method_chaining_with_compute(self) -> None: - """Test that method chaining works with dask arrays.""" - ds = Dataset({"a": (("x", "y"), np.arange(20.0).reshape(4, 5))}).chunk( - {"x": 2, "y": 5} - ) - - # Calling .mean() should still be lazy - result = ds.eval("a.mean(dim='x')") - # Calling .compute() should return numpy-backed result - computed = result.compute() - - expected = ds["a"].mean(dim="x").compute() - assert_equal(computed, expected) - - @requires_dask - def test_xr_where_preserves_dask(self) -> None: - """Test that xr.where() with dask arrays preserves lazy evaluation.""" - from xarray.core.utils import is_duck_dask_array - - ds = Dataset({"a": ("x", np.arange(-5, 5, dtype=float))}).chunk({"x": 5}) - - with raise_if_dask_computes(): - result = ds.eval("xr.where(a > 0, a, 0)") - - assert isinstance(result, DataArray) - assert is_duck_dask_array(result.data) - - # Verify correctness when computed - expected = xr.where(ds["a"] > 0, ds["a"], 0) - assert_equal(result, expected) - - @requires_dask - def test_complex_expression_preserves_dask(self) -> None: - """Test that complex expressions preserve dask backing.""" - from xarray.core.utils import is_duck_dask_array +# Error message tests - rng = np.random.default_rng(42) - ds = Dataset( - { - "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), - "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), - } - ).chunk({"time": 1, "lat": 2, "lon": 5}) - with raise_if_dask_computes(): - result = ds.eval("x * 2 + y ** 2") +def test_eval_error_undefined_variable() -> None: + """Test error message when referencing an undefined variable.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(NameError, match="undefined_var"): + ds.eval("undefined_var + a") + + +def test_eval_error_syntax() -> None: + """Test error message for malformed expressions.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match="Invalid"): + ds.eval("a +") + + +def test_eval_error_invalid_assignment() -> None: + """Test error message when assignment target is invalid.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # "1 = a" should fail during parsing - can't assign to a literal + with pytest.raises(ValueError, match="Invalid"): + ds.eval("1 = a") + + +def test_eval_error_dunder_access() -> None: + """Test error message when trying to access dunder attributes.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError, match="private attributes"): + ds.eval("a.__class__") + + +def test_eval_error_missing_method() -> None: + """Test error message when calling a nonexistent method.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # This should raise AttributeError from the DataArray + with pytest.raises(AttributeError, match="nonexistent_method"): + ds.eval("a.nonexistent_method()") + + +def test_eval_error_type_mismatch() -> None: + """Test error message when types are incompatible.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + # Adding string to numeric array should raise TypeError or similar + with pytest.raises((TypeError, np.exceptions.DTypePromotionError)): + ds.eval("a + 'string'") + + +# Edge case tests + + +def test_eval_empty_expression() -> None: + """Test handling of empty expression string.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError): + ds.eval("") + + +def test_eval_whitespace_only_expression() -> None: + """Test handling of whitespace-only expression.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + with pytest.raises(ValueError): + ds.eval(" ") + + +def test_eval_just_variable_name() -> None: + """Test that just a variable name returns the variable.""" + ds = Dataset({"a": ("x", [1, 2, 3])}) + result = ds.eval("a") + expected = ds["a"] + assert_equal(result, expected) - assert is_duck_dask_array(result.data) - # Verify correctness when computed - expected = ds["x"] * 2 + ds["y"] ** 2 - assert_equal(result, expected) +def test_eval_unicode_variable_names() -> None: + """Test that unicode variable names work in expressions.""" + # Greek letters are valid Python identifiers + ds = Dataset({"α": ("x", [1.0, 2.0, 3.0]), "β": ("x", [4.0, 5.0, 6.0])}) + result = ds.eval("α + β") + expected = ds["α"] + ds["β"] + assert_equal(result, expected) - @requires_dask - def test_mixed_dask_and_numpy(self) -> None: - """Test expressions with mixed dask and numpy arrays.""" - from xarray.core.utils import is_duck_dask_array - ds = Dataset( - { - "dask_var": ("x", np.arange(10.0)), - "numpy_var": ("x", np.linspace(0, 1, 10)), - } - ) - # Only chunk one variable - ds["dask_var"] = ds["dask_var"].chunk({"x": 5}) +def test_eval_long_expression() -> None: + """Test that very long expressions work correctly.""" + ds = Dataset({"a": ("x", [1.0, 2.0, 3.0])}) + # Build a long expression: a + a + a + ... (50 times) + long_expr = " + ".join(["a"] * 50) + result = ds.eval(long_expr) + expected = ds["a"] * 50 + assert_equal(result, expected) - with raise_if_dask_computes(): - result = ds.eval("dask_var + numpy_var") - # Result should be dask-backed when any input is dask - assert is_duck_dask_array(result.data) +# Dask tests - # Verify correctness - expected = ds["dask_var"] + ds["numpy_var"] - assert_equal(result, expected) - @requires_dask - def test_np_functions_preserve_dask(self) -> None: - """Test that numpy functions via np namespace preserve dask.""" - from xarray.core.utils import is_duck_dask_array +@requires_dask +def test_eval_dask_basic_arithmetic() -> None: + """Test that basic arithmetic with dask arrays returns dask-backed result.""" + from xarray.core.utils import is_duck_dask_array - ds = Dataset({"a": ("x", np.arange(1.0, 11.0))}).chunk({"x": 5}) + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} + ).chunk({"x": 5}) - with raise_if_dask_computes(): - result = ds.eval("np.sqrt(a)") + with raise_if_dask_computes(): + result = ds.eval("a + b") - assert is_duck_dask_array(result.data) + assert isinstance(result, DataArray) + assert is_duck_dask_array(result.data) - # Verify correctness - expected = np.sqrt(ds["a"]) - assert_equal(result, expected) + # Verify correctness when computed + expected = ds["a"] + ds["b"] + assert_equal(result, expected) - @requires_dask - def test_comparison_preserves_dask(self) -> None: - """Test that comparison operations preserve dask backing.""" - from xarray.core.utils import is_duck_dask_array - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} - ).chunk({"x": 5}) +@requires_dask +def test_eval_dask_assignment() -> None: + """Test that assignments with dask arrays preserve lazy evaluation.""" + from xarray.core.utils import is_duck_dask_array - with raise_if_dask_computes(): - result = ds.eval("a > b") + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.linspace(0, 1, 10))} + ).chunk({"x": 5}) - assert is_duck_dask_array(result.data) + with raise_if_dask_computes(): + result = ds.eval("z = a + b") - # Verify correctness - expected = ds["a"] > ds["b"] - assert_equal(result, expected) + assert isinstance(result, Dataset) + assert "z" in result.data_vars + assert is_duck_dask_array(result["z"].data) - @requires_dask - def test_boolean_operators_preserve_dask(self) -> None: - """Test that bitwise boolean operators preserve dask.""" - from xarray.core.utils import is_duck_dask_array + # Verify correctness when computed + expected = ds["a"] + ds["b"] + assert_equal(result["z"], expected) - ds = Dataset( - {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} - ).chunk({"x": 5}) - with raise_if_dask_computes(): - result = ds.eval("(a > 3) & (b < 7)") +@requires_dask +def test_eval_dask_method_chaining() -> None: + """Test that method chaining works with dask arrays.""" + ds = Dataset({"a": (("x", "y"), np.arange(20.0).reshape(4, 5))}).chunk( + {"x": 2, "y": 5} + ) - assert is_duck_dask_array(result.data) + # Calling .mean() should still be lazy + result = ds.eval("a.mean(dim='x')") + # Calling .compute() should return numpy-backed result + computed = result.compute() - # Verify correctness - expected = (ds["a"] > 3) & (ds["b"] < 7) - assert_equal(result, expected) + expected = ds["a"].mean(dim="x").compute() + assert_equal(computed, expected) - @requires_dask - def test_chained_comparisons_preserve_dask(self) -> None: - """Test that chained comparisons preserve dask backing.""" - from xarray.core.utils import is_duck_dask_array - ds = Dataset({"x": ("dim", np.arange(10.0))}).chunk({"dim": 5}) +@requires_dask +def test_eval_dask_xr_where() -> None: + """Test that xr.where() with dask arrays preserves lazy evaluation.""" + from xarray.core.utils import is_duck_dask_array - with raise_if_dask_computes(): - result = ds.eval("2 < x < 7") + ds = Dataset({"a": ("x", np.arange(-5, 5, dtype=float))}).chunk({"x": 5}) - assert is_duck_dask_array(result.data) + with raise_if_dask_computes(): + result = ds.eval("xr.where(a > 0, a, 0)") - # Verify correctness - expected = (ds["x"] > 2) & (ds["x"] < 7) - assert_equal(result, expected) + assert isinstance(result, DataArray) + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = xr.where(ds["a"] > 0, ds["a"], 0) + assert_equal(result, expected) + + +@requires_dask +def test_eval_dask_complex_expression() -> None: + """Test that complex expressions preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + rng = np.random.default_rng(42) + ds = Dataset( + { + "x": (["time", "lat", "lon"], rng.random((3, 4, 5))), + "y": (["time", "lat", "lon"], rng.random((3, 4, 5))), + } + ).chunk({"time": 1, "lat": 2, "lon": 5}) + + with raise_if_dask_computes(): + result = ds.eval("x * 2 + y ** 2") + + assert is_duck_dask_array(result.data) + + # Verify correctness when computed + expected = ds["x"] * 2 + ds["y"] ** 2 + assert_equal(result, expected) + + +@requires_dask +def test_eval_dask_mixed_backends() -> None: + """Test expressions with mixed dask and numpy arrays.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + { + "dask_var": ("x", np.arange(10.0)), + "numpy_var": ("x", np.linspace(0, 1, 10)), + } + ) + # Only chunk one variable + ds["dask_var"] = ds["dask_var"].chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("dask_var + numpy_var") + + # Result should be dask-backed when any input is dask + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = ds["dask_var"] + ds["numpy_var"] + assert_equal(result, expected) + + +@requires_dask +def test_eval_dask_np_functions() -> None: + """Test that numpy functions via np namespace preserve dask.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"a": ("x", np.arange(1.0, 11.0))}).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("np.sqrt(a)") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = np.sqrt(ds["a"]) + assert_equal(result, expected) + + +@requires_dask +def test_eval_dask_comparison() -> None: + """Test that comparison operations preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("a > b") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = ds["a"] > ds["b"] + assert_equal(result, expected) + + +@requires_dask +def test_eval_dask_boolean_operators() -> None: + """Test that bitwise boolean operators preserve dask.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset( + {"a": ("x", np.arange(10.0)), "b": ("x", np.arange(10.0)[::-1])} + ).chunk({"x": 5}) + + with raise_if_dask_computes(): + result = ds.eval("(a > 3) & (b < 7)") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = (ds["a"] > 3) & (ds["b"] < 7) + assert_equal(result, expected) + + +@requires_dask +def test_eval_dask_chained_comparisons() -> None: + """Test that chained comparisons preserve dask backing.""" + from xarray.core.utils import is_duck_dask_array + + ds = Dataset({"x": ("dim", np.arange(10.0))}).chunk({"dim": 5}) + + with raise_if_dask_computes(): + result = ds.eval("2 < x < 7") + + assert is_duck_dask_array(result.data) + + # Verify correctness + expected = (ds["x"] > 2) & (ds["x"] < 7) + assert_equal(result, expected)