Skip to content

Commit 0438825

Browse files
authored
Merge pull request #44 from microsoft/list_comprehension
List comprehension
2 parents 9269be2 + 7c1fc6f commit 0438825

File tree

9 files changed

+591
-0
lines changed

9 files changed

+591
-0
lines changed

flowquery-py/src/parsing/data_structures/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from .associative_array import AssociativeArray
44
from .json_array import JSONArray
55
from .key_value_pair import KeyValuePair
6+
from .list_comprehension import ListComprehension
67
from .lookup import Lookup
78
from .range_lookup import RangeLookup
89

910
__all__ = [
1011
"AssociativeArray",
1112
"JSONArray",
1213
"KeyValuePair",
14+
"ListComprehension",
1315
"Lookup",
1416
"RangeLookup",
1517
]
flowquery-py/src/parsing/data_structures/list_comprehension.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Represents a Cypher-style list comprehension in the AST.
2+
3+
List comprehensions allow mapping and filtering arrays inline using the syntax:
4+
[variable IN list | expression]
5+
[variable IN list WHERE condition | expression]
6+
[variable IN list WHERE condition]
7+
[variable IN list]
8+
9+
Example:
10+
[n IN [1, 2, 3] WHERE n > 1 | n * 2] => [4, 6]
11+
"""
12+
13+
from typing import Any, List, Optional
14+
15+
from ..ast_node import ASTNode
16+
from ..expressions.expression import Expression
17+
from ..functions.value_holder import ValueHolder
18+
from ..operations.where import Where
19+
20+
21+
class ListComprehension(ASTNode):
    """AST node for a Cypher-style list comprehension.

    Children layout:
    - Child 0: Reference (iteration variable)
    - Child 1: Expression (source array)
    - Child 2 (optional): Where (filter condition) or Expression (mapping)
    - Child 3 (optional): Expression (mapping, when Where is child 2)
    """

    def __init__(self) -> None:
        super().__init__()
        # Holds the element currently being visited by ``value()``; the
        # iteration variable is pointed at this holder during evaluation.
        self._value_holder = ValueHolder()

    @property
    def reference(self) -> ASTNode:
        """The iteration variable reference (child 0)."""
        return self.first_child()

    @property
    def array(self) -> ASTNode:
        """The source array node, unwrapped from its Expression wrapper (child 1)."""
        return self.get_children()[1].first_child()

    @property
    def _return(self) -> Optional[Expression]:
        """The mapping expression, or None when no mapping was specified.

        Only the last child can be the mapping, and only when there are more
        than two children and that last child is an Expression that is not
        the Where filter.
        """
        children = self.get_children()
        if len(children) <= 2:
            return None
        last = children[-1]
        if isinstance(last, Where) or not isinstance(last, Expression):
            return None
        return last

    @property
    def where(self) -> Optional[Where]:
        """The first Where child, or None when there is no filter."""
        return next(
            (child for child in self.get_children() if isinstance(child, Where)),
            None,
        )

    def value(self) -> List[Any]:
        """Evaluate the list comprehension.

        Iterates over the source array, keeps the elements accepted by the
        optional filter, and maps each kept element through the optional
        mapping expression (elements are kept as-is when there is none).

        Returns:
            The resulting filtered/mapped list.

        Raises:
            ValueError: If the source expression does not evaluate to a list.
        """
        ref = self.reference
        if hasattr(ref, "referred"):
            ref.referred = self._value_holder

        source = self.array.value()
        # isinstance() already rejects None, so no separate None check is needed.
        if not isinstance(source, list):
            raise ValueError("Expected array for list comprehension")

        # Both lookups scan the child list; they do not depend on the current
        # element, so resolve them once instead of once (or twice) per item.
        condition = self.where
        mapping = self._return
        holder = self._value_holder

        result: List[Any] = []
        for item in source:
            # Publish the current element before evaluating filter/mapping.
            holder.holder = item
            if condition is not None and not condition.value():
                continue
            result.append(item if mapping is None else mapping.value())
        return result

    def __str__(self) -> str:
        return "ListComprehension"

flowquery-py/src/parsing/parser.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .data_structures.associative_array import AssociativeArray
2424
from .data_structures.json_array import JSONArray
2525
from .data_structures.key_value_pair import KeyValuePair
26+
from .data_structures.list_comprehension import ListComprehension
2627
from .data_structures.lookup import Lookup
2728
from .data_structures.range_lookup import RangeLookup
2829
from .expressions.expression import Expression
@@ -877,6 +878,13 @@ def _parse_operand(self, expression: Expression) -> bool:
877878
lookup = self._parse_lookup(sub)
878879
expression.add_node(lookup)
879880
return True
881+
elif self.token.is_opening_bracket() and self._looks_like_list_comprehension():
882+
list_comp = self._parse_list_comprehension()
883+
if list_comp is None:
884+
raise ValueError("Expected list comprehension")
885+
lookup = self._parse_lookup(list_comp)
886+
expression.add_node(lookup)
887+
return True
880888
elif self.token.is_opening_brace() or self.token.is_opening_bracket():
881889
json = self._parse_json()
882890
if json is None:
@@ -1290,6 +1298,84 @@ def _parse_function_parameters(self) -> Iterator[ASTNode]:
12901298
break
12911299
self.set_next_token()
12921300

1301+
def _looks_like_list_comprehension(self) -> bool:
    """Look ahead from an opening bracket without consuming any tokens.

    Decides whether the upcoming tokens start a list comprehension
    (e.g. ``[n IN list | n.name]``) or a plain JSON array literal
    (e.g. ``[1, 2, 3]``).

    The heuristic is: ``[`` identifier ``IN`` -> list comprehension.
    ``self._token_index`` is reset to its starting value before returning,
    whatever the outcome.
    """
    rewind_to = self._token_index

    self.set_next_token()  # step past '['
    self._skip_whitespace_and_comments()

    matched = False
    if self.token.is_identifier_or_keyword():
        self.set_next_token()  # step past the candidate iteration variable
        self._skip_whitespace_and_comments()
        matched = self.token.is_in()

    # Pure lookahead: put the cursor back where it was.
    self._token_index = rewind_to
    return matched
1321+
1322+
def _parse_list_comprehension(self) -> Optional[ListComprehension]:
    """Parse a list comprehension expression.

    Syntax: ``[variable IN list [WHERE condition] [| expression]]``

    The iteration variable is registered in ``self._state.variables`` while
    the comprehension is being parsed. Any variable of the same name that
    was already registered is put back afterwards, so the comprehension
    variable only shadows it instead of removing it.

    Returns:
        The parsed ListComprehension, or None if the current token is not
        an opening bracket.

    Raises:
        ValueError: If the tokens do not form a valid list comprehension.
    """
    if not self.token.is_opening_bracket():
        return None

    list_comp = ListComprehension()
    self.set_next_token()  # skip '['
    self._skip_whitespace_and_comments()

    # Parse iteration variable
    if not self.token.is_identifier_or_keyword():
        raise ValueError("Expected identifier")
    reference = Reference(self.token.value or "")
    identifier = reference.identifier

    # Remember an outer binding with the same name so it can be restored;
    # unconditionally deleting the entry afterwards would drop the outer
    # variable (and raise KeyError for nested comprehensions that reuse
    # the name, because the inner one would already have removed it).
    shadowed = identifier in self._state.variables
    outer_binding = self._state.variables[identifier] if shadowed else None
    self._state.variables[identifier] = reference

    try:
        list_comp.add_child(reference)
        self.set_next_token()
        self._expect_and_skip_whitespace_and_comments()

        # Parse IN keyword
        if not self.token.is_in():
            raise ValueError("Expected IN")
        self.set_next_token()
        self._expect_and_skip_whitespace_and_comments()

        # Parse source array expression
        array_expr = self._parse_expression()
        if array_expr is None:
            raise ValueError("Expected expression")
        list_comp.add_child(array_expr)

        # Optional WHERE clause
        self._skip_whitespace_and_comments()
        where = self._parse_where()
        if where is not None:
            list_comp.add_child(where)

        # Optional | mapping expression
        self._skip_whitespace_and_comments()
        if self.token.is_pipe():
            self.set_next_token()
            self._skip_whitespace_and_comments()
            return_expr = self._parse_expression()
            if return_expr is None:
                raise ValueError("Expected expression after |")
            list_comp.add_child(return_expr)

        self._skip_whitespace_and_comments()
        if not self.token.is_closing_bracket():
            raise ValueError("Expected closing bracket")
        self.set_next_token()
    finally:
        # Leave the variable table as it was before the comprehension,
        # including when parsing fails part-way through.
        if shadowed:
            self._state.variables[identifier] = outer_binding
        else:
            # pop() with a default cannot mask an in-flight parse error.
            self._state.variables.pop(identifier, None)

    return list_comp
1378+
12931379
def _parse_json(self) -> Optional[ASTNode]:
12941380
if self.token.is_opening_brace():
12951381
return self._parse_associative_array()

flowquery-py/tests/compute/test_runner.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,83 @@ async def test_range_function(self):
406406
assert len(results) == 1
407407
assert results[0] == {"range": [1, 2, 3]}
408408

409+
@pytest.mark.asyncio
async def test_list_comprehension_with_mapping(self):
    """A mapping-only comprehension transforms every element."""
    query = "RETURN [n IN [1, 2, 3] | n * 2] AS doubled"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"doubled": [2, 4, 6]}
417+
418+
@pytest.mark.asyncio
async def test_list_comprehension_with_where_filter(self):
    """A WHERE-only comprehension keeps just the matching elements."""
    query = "RETURN [n IN [1, 2, 3, 4, 5] WHERE n > 2] AS filtered"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"filtered": [3, 4, 5]}
426+
427+
@pytest.mark.asyncio
async def test_list_comprehension_with_where_and_mapping(self):
    """WHERE filters first, then the mapping is applied to what remains."""
    query = "RETURN [n IN [1, 2, 3, 4] WHERE n > 1 | n ^ 2] AS result"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"result": [4, 9, 16]}
435+
436+
@pytest.mark.asyncio
async def test_list_comprehension_identity(self):
    """Without WHERE or a mapping the source list comes back unchanged."""
    query = "RETURN [n IN [10, 20, 30]] AS result"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"result": [10, 20, 30]}
444+
445+
@pytest.mark.asyncio
async def test_list_comprehension_with_variable_reference(self):
    """The source list may be a variable introduced by an earlier WITH."""
    query = "WITH [1, 2, 3] AS nums RETURN [n IN nums | n + 10] AS result"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"result": [11, 12, 13]}
453+
454+
@pytest.mark.asyncio
async def test_list_comprehension_with_property_access(self):
    """The mapping may read a property of each element."""
    query = (
        'WITH [{name: "Alice", age: 30}, {name: "Bob", age: 25}] AS people '
        'RETURN [p IN people | p.name] AS names'
    )
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"names": ["Alice", "Bob"]}
465+
466+
@pytest.mark.asyncio
async def test_list_comprehension_with_function_source(self):
    """The source list may be produced by a function call."""
    query = "RETURN [n IN range(1, 5) WHERE n > 3 | n * 10] AS result"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"result": [40, 50]}
474+
475+
@pytest.mark.asyncio
async def test_list_comprehension_with_size(self):
    """A comprehension can be passed as an argument to size()."""
    query = "RETURN size([n IN [1, 2, 3, 4, 5] WHERE n > 2]) AS count"
    runner = Runner(query)
    await runner.run()
    rows = runner.results
    assert len(rows) == 1
    assert rows[0] == {"count": 3}
485+
409486
@pytest.mark.asyncio
410487
async def test_range_function_with_unwind_and_case(self):
411488
"""Test range function with unwind and case."""

flowquery-py/tests/parsing/test_parser.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,3 +1235,39 @@ def test_order_by_expression_with_limit(self):
12351235
"return x order by toLower(x) asc limit 2"
12361236
)
12371237
assert ast is not None
1238+
1239+
def test_list_comprehension_with_mapping(self):
    """A mapping-only comprehension shows up as a ListComprehension node."""
    tree = Parser().parse("RETURN [n IN [1, 2, 3] | n * 2] AS doubled")
    rendered = tree.print()
    assert "ListComprehension" in rendered
1244+
1245+
def test_list_comprehension_with_where_and_mapping(self):
    """WHERE plus mapping prints both a ListComprehension and a Where node."""
    tree = Parser().parse("RETURN [n IN [1, 2, 3] WHERE n > 1 | n * 2] AS result")
    rendered = tree.print()
    for node_name in ("ListComprehension", "Where"):
        assert node_name in rendered
1252+
1253+
def test_list_comprehension_with_where_only(self):
    """WHERE without mapping prints both a ListComprehension and a Where node."""
    tree = Parser().parse("RETURN [n IN [1, 2, 3, 4] WHERE n > 2] AS filtered")
    rendered = tree.print()
    for node_name in ("ListComprehension", "Where"):
        assert node_name in rendered
1260+
1261+
def test_list_comprehension_identity(self):
    """A bare ``[n IN list]`` is still parsed as a ListComprehension."""
    tree = Parser().parse("RETURN [n IN [1, 2, 3]] AS result")
    rendered = tree.print()
    assert "ListComprehension" in rendered
1266+
1267+
def test_regular_json_array_still_parses(self):
    """A plain array literal must not be mistaken for a list comprehension."""
    tree = Parser().parse("RETURN [1, 2, 3] AS arr")
    rendered = tree.print()
    assert "JSONArray" in rendered
    assert "ListComprehension" not in rendered

0 commit comments

Comments
 (0)