Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 93 additions & 15 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import ast
import asyncio
import builtins
import copy
import datetime
import io
Expand Down Expand Up @@ -51,6 +53,11 @@
from xarray.core.dataset_utils import _get_virtual_variable, _LocIndexer
from xarray.core.dataset_variables import DataVariables
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.eval import (
EVAL_BUILTINS,
LogicalOperatorTransformer,
validate_expression,
)
from xarray.core.indexes import (
Index,
Indexes,
Expand All @@ -72,7 +79,6 @@
Self,
T_ChunkDim,
T_ChunksFreq,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
)
Expand Down Expand Up @@ -9533,19 +9539,48 @@ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
"Dataset.argmin() with a sequence or ... for dim"
)

def _eval_expression(self, expr: str) -> DataArray:
    """Evaluate a single expression string against this dataset's variables.

    The expression is parsed, its logical operators rewritten to bitwise
    ones (see ``LogicalOperatorTransformer``), validated, and evaluated in
    a restricted namespace built from this dataset.
    """
    try:
        parsed = ast.parse(expr, mode="eval")
    except SyntaxError as err:
        raise ValueError(f"Invalid expression syntax: {expr}") from err

    # Rewrite and/or/not into &/|/~ so array-valued operands work
    # element-wise, for consistency with query().
    parsed = LogicalOperatorTransformer().visit(parsed)
    ast.fix_missing_locations(parsed)

    validate_expression(parsed)

    # Lazy import to avoid a circular dependency at module load time.
    import xarray as xr

    # Evaluation namespace, built lowest priority first: safe builtins,
    # then the np/pd/xr modules, then coordinates, then data variables —
    # so user data always wins when names collide. The empty __builtins__
    # in globals blocks dangerous names like __import__, exec, open.
    names: dict[str, Any] = {**EVAL_BUILTINS, "np": np, "pd": pd, "xr": xr}
    for key in self.coords:
        names[str(key)] = self.coords[key]
    for key in self.data_vars:
        names[str(key)] = self[key]

    compiled = compile(parsed, "<xarray.eval>", "eval")
    return builtins.eval(compiled, {"__builtins__": {}}, names)

def eval(
self,
statement: str,
*,
parser: QueryParserOptions = "pandas",
) -> Self | T_DataArray:
parser: QueryParserOptions | Default = _default,
) -> Self | DataArray:
"""
Calculate an expression supplied as a string in the context of the dataset.

This is currently experimental; the API may change particularly around
assignments, which currently return a ``Dataset`` with the additional variable.
Currently only the ``python`` engine is supported, which has the same
performance as executing in python.

Logical operators (``and``, ``or``, ``not``) are automatically transformed
to bitwise operators (``&``, ``|``, ``~``) which work element-wise on arrays.

Parameters
----------
Expand All @@ -9555,7 +9590,11 @@ def eval(
Returns
-------
result : Dataset or DataArray, depending on whether ``statement`` contains an
assignment.
assignment.

Warning
-------
Like ``pd.eval()``, this method should not be used with untrusted input.

Examples
--------
Expand Down Expand Up @@ -9584,16 +9623,55 @@ def eval(
b (x) float64 40B 0.0 0.25 0.5 0.75 1.0
c (x) float64 40B 0.0 1.25 2.5 3.75 5.0
"""
if parser is not _default:
emit_user_level_warning(
"The 'parser' argument to Dataset.eval() is deprecated and will be "
"removed in a future version. Logical operators (and/or/not) are now "
"always transformed to bitwise operators (&/|/~) for array compatibility.",
FutureWarning,
)

return pd.eval( # type: ignore[return-value]
statement,
resolvers=[self],
target=self,
parser=parser,
# Because numexpr returns a numpy array, using that engine results in
# different behavior. We'd be very open to a contribution handling this.
engine="python",
)
statement = statement.strip()

# Check for assignment: "target = expr"
# Must handle compound operators like ==, !=, <=, >=
# Use ast to detect assignment properly
try:
tree = ast.parse(statement, mode="exec")
except SyntaxError as e:
raise ValueError(f"Invalid statement syntax: {statement}") from e

if len(tree.body) != 1:
raise ValueError("Only single statements are supported")

stmt = tree.body[0]

if isinstance(stmt, ast.Assign):
# Assignment: "c = a + b"
if len(stmt.targets) != 1:
raise ValueError("Only single assignment targets are supported")
target = stmt.targets[0]
if not isinstance(target, ast.Name):
raise ValueError(
f"Assignment target must be a simple name, got {type(target).__name__}"
)
target_name = target.id

# Get the expression source
expr_source = ast.unparse(stmt.value)
result: DataArray = self._eval_expression(expr_source)
return self.assign({target_name: result})

elif isinstance(stmt, ast.Expr):
# Expression: "a + b"
expr_source = ast.unparse(stmt.value)
return self._eval_expression(expr_source)

else:
raise ValueError(
f"Unsupported statement type: {type(stmt).__name__}. "
f"Only expressions and assignments are supported."
)

def query(
self,
Expand Down
138 changes: 138 additions & 0 deletions xarray/core/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Expression evaluation for Dataset.eval().

This module provides AST-based expression evaluation to support N-dimensional
arrays (N > 2), which pd.eval() doesn't support. See GitHub issue #11062.

We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~',
and chained comparisons) for consistency with query(), which still uses
pd.eval(). We don't migrate query() to this implementation because:
- query() typically works fine (expressions usually compare 1D coordinates)
- pd.eval() with numexpr is faster and well-tested for query's use case
"""

from __future__ import annotations

import ast
import builtins
from typing import Any

# Base namespace for eval expressions.
# We add common builtins back since we use an empty __builtins__ dict;
# anything not listed here (e.g. __import__, open, eval, exec) is not
# available inside evaluated expressions.
# All entries are plain builtin references; `map` was previously spelled
# `builtins.map`, inconsistent with the other 22 entries (the two are the
# same object, so behavior is unchanged).
EVAL_BUILTINS: dict[str, Any] = {
    # Numeric/aggregation functions
    "abs": abs,
    "min": min,
    "max": max,
    "round": round,
    "len": len,
    "sum": sum,
    "pow": pow,
    "any": any,
    "all": all,
    # Type constructors
    "int": int,
    "float": float,
    "bool": bool,
    "str": str,
    "list": list,
    "tuple": tuple,
    "dict": dict,
    "set": set,
    "slice": slice,
    # Iteration helpers
    "range": range,
    "zip": zip,
    "enumerate": enumerate,
    "map": map,
    "filter": filter,
}


class LogicalOperatorTransformer(ast.NodeTransformer):
"""Transform operators for consistency with query().

query() uses pd.eval() which transforms these operators automatically.
We replicate that behavior here so syntax that works in query() also
works in eval().

Transformations:
1. 'and'/'or'/'not' -> '&'/'|'/'~'
2. 'a < b < c' -> '(a < b) & (b < c)'

These constructs fail on arrays in standard Python because they call
__bool__(), which is ambiguous for multi-element arrays.
"""

def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not snake_case? I'm surprised we don't have a linter that catches this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a required convention from Python's ast.NodeTransformer. The visitor methods must be named visit_<NodeType> where <NodeType> matches the AST node class name exactly (e.g., BoolOp, UnaryOp, Compare). Using snake_case like visit_bool_op would break the visitor pattern - Python's ast module wouldn't find the methods.

[This is Claude Code on behalf of max-sixty]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the implementation, the only documentation I managed to understand:
https://github.com/python/cpython/blob/d0e9f4445a0d9039e1a2367ecee376b4b3ba7593/Lib/ast.py#L502-L506

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right — I think it needs to be visit_Foo, we can't change that. which the link seems to support?

        method = 'visit_' + node.__class__.__name__

(this is Max himself!)

# Transform: a and b -> a & b, a or b -> a | b
self.generic_visit(node)
op: ast.BitAnd | ast.BitOr
if isinstance(node.op, ast.And):
op = ast.BitAnd()
elif isinstance(node.op, ast.Or):
op = ast.BitOr()
else:
return node

# BoolOp can have multiple values: a and b and c
# Transform to chained BinOp: (a & b) & c
result = node.values[0]
for value in node.values[1:]:
result = ast.BinOp(left=result, op=op, right=value)
return ast.fix_missing_locations(result)

def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
# Transform: not a -> ~a
self.generic_visit(node)
if isinstance(node.op, ast.Not):
return ast.fix_missing_locations(
ast.UnaryOp(op=ast.Invert(), operand=node.operand)
)
return node

def visit_Compare(self, node: ast.Compare) -> ast.AST:
# Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5)
# Python's chained comparisons use short-circuit evaluation at runtime,
# which calls __bool__ on intermediate results. This fails for arrays.
# We transform to bitwise AND which works element-wise.
self.generic_visit(node)

if len(node.ops) == 1:
# Simple comparison, no transformation needed
return node

# Build individual comparisons and chain with BitAnd
# For: a < b < c < d
# We need: (a < b) & (b < c) & (c < d)
comparisons = []
left = node.left
for op, comparator in zip(node.ops, node.comparators, strict=True):
comp = ast.Compare(left=left, ops=[op], comparators=[comparator])
comparisons.append(comp)
left = comparator

# Chain with BitAnd: (a < b) & (b < c) & ...
result: ast.Compare | ast.BinOp = comparisons[0]
for comp in comparisons[1:]:
result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp)
return ast.fix_missing_locations(result)


def validate_expression(tree: ast.AST) -> None:
"""Validate that an AST doesn't contain patterns we don't support.

These restrictions emulate pd.eval() behavior for consistency.
"""
for node in ast.walk(tree):
# Block lambda expressions (pd.eval: "Only named functions are supported")
if isinstance(node, ast.Lambda):
raise ValueError(
"Lambda expressions are not allowed in eval(). "
"Use direct operations on data variables instead."
)
# Block private/dunder attributes (consistent with pd.eval restrictions)
if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
raise ValueError(
f"Access to private attributes is not allowed: '{node.attr}'"
)
17 changes: 0 additions & 17 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7636,23 +7636,6 @@ def test_query(self, backend, engine, parser) -> None:
# pytest tests — new tests should go here, rather than in the class.


@pytest.mark.parametrize("parser", ["pandas", "python"])
def test_eval(ds, parser) -> None:
"""Currently much more minimal testing that `query` above, and much of the setup
isn't used. But the risks are fairly low — `query` shares much of the code, and
the method is currently experimental."""

actual = ds.eval("z1 + 5", parser=parser)
expect = ds["z1"] + 5
assert_identical(expect, actual)

# check pandas query syntax is supported
if parser == "pandas":
actual = ds.eval("(z1 > 5) and (z2 > 0)", parser=parser)
expect = (ds["z1"] > 5) & (ds["z2"] > 0)
assert_identical(expect, actual)


@pytest.mark.parametrize("test_elements", ([1, 2], np.array([1, 2]), DataArray([1, 2])))
def test_isin(test_elements, backend) -> None:
expected = Dataset(
Expand Down
Loading
Loading