Dataset.eval works with >2 dims
#11064
Open: max-sixty wants to merge 7 commits into pydata:main from max-sixty:eval
+838
−32
Commits (7):
79fc511 Replace pandas.eval with native implementation (max-sixty)
ca87f1b Merge branch 'main' into eval (max-sixty)
cac8e5a Fix mypy errors in eval tests (max-sixty)
67f27c2 Remove security framing, frame restrictions as pd.eval() compatibility (max-sixty)
64b2601 Move eval implementation to dedicated module (max-sixty)
32f0a29 Move eval tests to dedicated test_eval.py module (max-sixty)
46bf9ef Refactor eval tests: convert classes to standalone functions (max-sixty)
@@ -0,0 +1,138 @@

```python
"""
Expression evaluation for Dataset.eval().

This module provides AST-based expression evaluation to support N-dimensional
arrays (N > 2), which pd.eval() doesn't support. See GitHub issue #11062.

We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~',
and chained comparisons) for consistency with query(), which still uses
pd.eval(). We don't migrate query() to this implementation because:
- query() typically works fine (expressions usually compare 1D coordinates)
- pd.eval() with numexpr is faster and well-tested for query's use case
"""

from __future__ import annotations

import ast
import builtins
from typing import Any

# Base namespace for eval expressions.
# We add common builtins back since we use an empty __builtins__ dict.
EVAL_BUILTINS: dict[str, Any] = {
    # Numeric/aggregation functions
    "abs": abs,
    "min": min,
    "max": max,
    "round": round,
    "len": len,
    "sum": sum,
    "pow": pow,
    "any": any,
    "all": all,
    # Type constructors
    "int": int,
    "float": float,
    "bool": bool,
    "str": str,
    "list": list,
    "tuple": tuple,
    "dict": dict,
    "set": set,
    "slice": slice,
    # Iteration helpers
    "range": range,
    "zip": zip,
    "enumerate": enumerate,
    "map": builtins.map,
    "filter": filter,
}


class LogicalOperatorTransformer(ast.NodeTransformer):
    """Transform operators for consistency with query().

    query() uses pd.eval() which transforms these operators automatically.
    We replicate that behavior here so syntax that works in query() also
    works in eval().

    Transformations:
    1. 'and'/'or'/'not' -> '&'/'|'/'~'
    2. 'a < b < c' -> '(a < b) & (b < c)'

    These constructs fail on arrays in standard Python because they call
    __bool__(), which is ambiguous for multi-element arrays.
    """

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        # Transform: a and b -> a & b, a or b -> a | b
        self.generic_visit(node)
        op: ast.BitAnd | ast.BitOr
        if isinstance(node.op, ast.And):
            op = ast.BitAnd()
        elif isinstance(node.op, ast.Or):
            op = ast.BitOr()
        else:
            return node

        # BoolOp can have multiple values: a and b and c
        # Transform to chained BinOp: (a & b) & c
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=op, right=value)
        return ast.fix_missing_locations(result)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        # Transform: not a -> ~a
        self.generic_visit(node)
        if isinstance(node.op, ast.Not):
            return ast.fix_missing_locations(
                ast.UnaryOp(op=ast.Invert(), operand=node.operand)
            )
        return node

    def visit_Compare(self, node: ast.Compare) -> ast.AST:
        # Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5)
        # Python's chained comparisons use short-circuit evaluation at runtime,
        # which calls __bool__ on intermediate results. This fails for arrays.
        # We transform to bitwise AND which works element-wise.
        self.generic_visit(node)

        if len(node.ops) == 1:
            # Simple comparison, no transformation needed
            return node

        # Build individual comparisons and chain with BitAnd
        # For: a < b < c < d
        # We need: (a < b) & (b < c) & (c < d)
        comparisons = []
        left = node.left
        for op, comparator in zip(node.ops, node.comparators, strict=True):
            comp = ast.Compare(left=left, ops=[op], comparators=[comparator])
            comparisons.append(comp)
            left = comparator

        # Chain with BitAnd: (a < b) & (b < c) & ...
        result: ast.Compare | ast.BinOp = comparisons[0]
        for comp in comparisons[1:]:
            result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp)
        return ast.fix_missing_locations(result)


def validate_expression(tree: ast.AST) -> None:
    """Validate that an AST doesn't contain patterns we don't support.

    These restrictions emulate pd.eval() behavior for consistency.
    """
    for node in ast.walk(tree):
        # Block lambda expressions (pd.eval: "Only named functions are supported")
        if isinstance(node, ast.Lambda):
            raise ValueError(
                "Lambda expressions are not allowed in eval(). "
                "Use direct operations on data variables instead."
            )
        # Block private/dunder attributes (consistent with pd.eval restrictions)
        if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
            raise ValueError(
                f"Access to private attributes is not allowed: '{node.attr}'"
            )
```
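To make the effect of the operator rewriting concrete, here is a minimal, self-contained sketch of the rewrite-then-evaluate flow. It is not the PR's code: `LogicalOps` is a condensed stand-in for `LogicalOperatorTransformer`, `Vec` is a hypothetical toy container that behaves element-wise and refuses `bool()` the way multi-element arrays do, and `rewrite`/`evaluate` are illustrative helper names. The `eval` call with an empty `__builtins__` dict is an assumption based on the comment above `EVAL_BUILTINS`; the diff shown here does not include the actual call site.

```python
import ast


class Vec:
    """Toy element-wise container standing in for an N-dimensional array."""

    def __init__(self, *items):
        self.items = list(items)

    def __lt__(self, other):
        return Vec(*(item < other for item in self.items))

    def __gt__(self, other):
        return Vec(*(item > other for item in self.items))

    def __and__(self, other):
        return Vec(*(a & b for a, b in zip(self.items, other.items)))

    def __bool__(self):
        # Mirrors the "ambiguous truth value" error raised by real arrays.
        raise ValueError("truth value of a multi-element Vec is ambiguous")


class LogicalOps(ast.NodeTransformer):
    """Condensed stand-in for the LogicalOperatorTransformer in the diff."""

    def visit_BoolOp(self, node):
        self.generic_visit(node)
        op = ast.BitAnd() if isinstance(node.op, ast.And) else ast.BitOr()
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=op, right=value)
        return result

    def visit_UnaryOp(self, node):
        self.generic_visit(node)
        if isinstance(node.op, ast.Not):
            return ast.UnaryOp(op=ast.Invert(), operand=node.operand)
        return node

    def visit_Compare(self, node):
        self.generic_visit(node)
        # Pair each operator with its left and right operand, then AND the parts.
        lefts = [node.left, *node.comparators[:-1]]
        result = None
        for left, op, right in zip(lefts, node.ops, node.comparators):
            comp = ast.Compare(left=left, ops=[op], comparators=[right])
            if result is None:
                result = comp
            else:
                result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp)
        return result


def rewrite(expr):
    """Parse `expr` and return the transformed, compilable tree."""
    tree = LogicalOps().visit(ast.parse(expr, mode="eval"))
    return ast.fix_missing_locations(tree)


def evaluate(expr, namespace):
    """Evaluate the rewritten expression with an empty __builtins__ dict."""
    code = compile(rewrite(expr), "<eval>", "eval")
    return eval(code, {"__builtins__": {}}, namespace)


x = Vec(1, 2, 3, 7)
print(ast.unparse(rewrite("1 < x < 5")))     # (1 < x) & (x < 5)
print(evaluate("1 < x < 5", {"x": x}).items)  # [False, True, True, False]
```

Passing the untransformed string `"1 < x < 5"` straight to `eval` with the same `x` raises `ValueError`, because the chained comparison calls `bool()` on the intermediate `1 < x` result.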
Why not snake_case? I'm surprised we don't have a linter that catches this.
This is a required convention from Python's `ast.NodeTransformer`. The visitor methods must be named `visit_<NodeType>`, where `<NodeType>` matches the AST node class name exactly (e.g., `BoolOp`, `UnaryOp`, `Compare`). Using snake_case like `visit_bool_op` would break the visitor pattern, because Python's ast module wouldn't find the methods.
[This is Claude Code on behalf of max-sixty]
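The naming point can be checked with a small experiment. The sketch below uses only the standard library; `CamelCaseVisitor` and `SnakeCaseVisitor` are hypothetical names chosen for illustration and are not part of the PR.

```python
import ast


class CamelCaseVisitor(ast.NodeVisitor):
    """Uses the visit_<NodeType> name that ast dispatches on."""

    def __init__(self):
        self.seen = []

    def visit_BoolOp(self, node):
        self.seen.append("BoolOp")
        self.generic_visit(node)


class SnakeCaseVisitor(ast.NodeVisitor):
    """Same logic, but the method name does not match the node class name."""

    def __init__(self):
        self.seen = []

    def visit_bool_op(self, node):  # never called by ast
        self.seen.append("BoolOp")
        self.generic_visit(node)


tree = ast.parse("a and b", mode="eval")

camel = CamelCaseVisitor()
camel.visit(tree)
snake = SnakeCaseVisitor()
snake.visit(tree)

print(camel.seen)  # ['BoolOp']
print(snake.seen)  # []
```

The snake_case method is silently skipped rather than raising an error, so a rename to satisfy a linter would quietly disable the transformation.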
Here's the implementation, the only documentation I managed to understand:
https://github.com/python/cpython/blob/d0e9f4445a0d9039e1a2367ecee376b4b3ba7593/Lib/ast.py#L502-L506
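For readers who do not want to follow the link, the lookup can be sketched in a few lines. This is a hand-rolled imitation of how the standard library's `ast.NodeVisitor.visit` resolves a method (build the name `visit_` plus the node's class name, fall back to `generic_visit`); it assumes the linked lines are that method. `MiniVisitor` and `NameCollector` are hypothetical names for illustration.

```python
import ast


class MiniVisitor:
    """Hand-rolled dispatcher using the same name lookup as ast.NodeVisitor."""

    def visit(self, node):
        # The method name is derived from the node's class name verbatim,
        # which is why visitor methods are spelled visit_BoolOp, visit_Name, ...
        method = "visit_" + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        # Fallback: recurse into child nodes when no specific method exists.
        for child in ast.iter_child_nodes(node):
            self.visit(child)


class NameCollector(MiniVisitor):
    def __init__(self):
        self.names = []

    def visit_Name(self, node):
        self.names.append(node.id)


collector = NameCollector()
collector.visit(ast.parse("a and not b", mode="eval"))
print(collector.names)  # ['a', 'b']
```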
Right, I think it needs to be `visit_Foo`; we can't change that, which the link seems to support? (This is Max himself!)