Skip to content

Commit 7e3d9b2

Browse files
authored
Merge pull request #63 from microsoft/graph_pattern_traversal_improvements
Graph pattern traversal improvements
2 parents 432a65e + 7e3e785 commit 7e3d9b2

17 files changed

Lines changed: 232 additions & 181 deletions

File tree

flowquery-py/src/graph/node.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union
5+
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Optional
66

77
from ..parsing.ast_node import ASTNode
88
from ..parsing.expressions.expression import Expression
@@ -29,7 +29,6 @@ def __init__(
2929
self._incoming: Optional['Relationship'] = None
3030
self._outgoing: Optional['Relationship'] = None
3131
self._data: Optional['NodeData'] = None
32-
self._todo_next: Optional[Callable[[], Union[None, Awaitable[None]]]] = None
3332

3433
@property
3534
def identifier(self) -> Optional[str]:
@@ -112,7 +111,7 @@ def incoming(self, relationship: Optional['Relationship']) -> None:
112111
def set_data(self, data: Optional['NodeData']) -> None:
113112
self._data = data
114113

115-
async def next(self) -> None:
114+
async def next(self) -> AsyncIterator[None]:
116115
if self._data:
117116
self._data.reset()
118117
while self._data.next():
@@ -122,10 +121,12 @@ async def next(self) -> None:
122121
if not self._matches_properties():
123122
continue
124123
if self._outgoing and self._value:
125-
await self._outgoing.find(self._value['id'])
126-
await self.run_todo_next()
124+
async for _ in self._outgoing.find(self._value['id']):
125+
yield
126+
else:
127+
yield
127128

128-
async def find(self, id_: str, hop: int = 0) -> None:
129+
async def find(self, id_: str, hop: int = 0) -> AsyncIterator[None]:
129130
if self._data:
130131
self._data.reset()
131132
while self._data.find(id_, hop):
@@ -137,19 +138,7 @@ async def find(self, id_: str, hop: int = 0) -> None:
137138
if self._incoming:
138139
self._incoming.set_end_node(self)
139140
if self._outgoing and self._value:
140-
await self._outgoing.find(self._value['id'], hop)
141-
await self.run_todo_next()
142-
143-
@property
144-
def todo_next(self) -> Optional[Callable[[], Union[None, Awaitable[None]]]]:
145-
return self._todo_next
146-
147-
@todo_next.setter
148-
def todo_next(self, func: Optional[Callable[[], Union[None, Awaitable[None]]]]) -> None:
149-
self._todo_next = func
150-
151-
async def run_todo_next(self) -> None:
152-
if self._todo_next:
153-
result = self._todo_next()
154-
if result is not None:
155-
await result
141+
async for _ in self._outgoing.find(self._value['id'], hop):
142+
yield
143+
else:
144+
yield

flowquery-py/src/graph/node_reference.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Optional
3+
from typing import Any, AsyncIterator, Optional
44

55
from ..parsing.ast_node import ASTNode
66
from .node import Node
@@ -30,22 +30,26 @@ def referred(self) -> ASTNode:
3030
def value(self) -> Optional[Any]:
3131
return self._reference.value() if self._reference else None
3232

33-
async def next(self) -> None:
33+
async def next(self) -> AsyncIterator[None]:
3434
"""Process next using the referenced node's value."""
3535
ref_value = self._reference.value()
3636
if ref_value is None:
3737
return
3838
self.set_value(dict(ref_value))
3939
if self._outgoing and self._value:
40-
await self._outgoing.find(self._value['id'])
41-
await self.run_todo_next()
40+
async for _ in self._outgoing.find(self._value['id']):
41+
yield
42+
else:
43+
yield
4244

43-
async def find(self, id_: str, hop: int = 0) -> None:
45+
async def find(self, id_: str, hop: int = 0) -> AsyncIterator[None]:
4446
"""Find by ID, only matching if it equals the referenced node's ID."""
4547
referenced = self._reference.value()
4648
if referenced is None or id_ != referenced.get('id'):
4749
return
4850
self.set_value(dict(referenced))
4951
if self._outgoing and self._value:
50-
await self._outgoing.find(self._value['id'], hop)
51-
await self.run_todo_next()
52+
async for _ in self._outgoing.find(self._value['id'], hop):
53+
yield
54+
else:
55+
yield

flowquery-py/src/graph/pattern.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,5 @@ async def initialize(self) -> None:
112112
async def traverse(self) -> None:
113113
first = self.first_node()
114114
if first and isinstance(first, Node):
115-
await first.next()
115+
async for _ in first.next():
116+
pass

flowquery-py/src/graph/pattern_expression.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,9 @@ async def evaluate(self) -> None:
5151
Sets _evaluation to True if the pattern is matched, False otherwise.
5252
"""
5353
self._evaluation = False
54-
55-
async def set_evaluation_true() -> None:
54+
async for _ in self.start_node.next():
5655
self._evaluation = True
5756

58-
self.end_node.todo_next = set_evaluation_true
59-
await self.start_node.next()
60-
6157
def value(self) -> Any:
6258
"""Returns the result of the pattern evaluation."""
6359
return self._evaluation

flowquery-py/src/graph/patterns.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Collection of graph patterns for FlowQuery."""
22

3-
from typing import Awaitable, Callable, List, Optional
3+
from typing import AsyncIterator, List, Optional
44

55
from .pattern import Pattern
66

@@ -10,33 +10,25 @@ class Patterns:
1010

1111
def __init__(self, patterns: Optional[List[Pattern]] = None) -> None:
1212
self._patterns = patterns or []
13-
self._to_do_next: Optional[Callable[[], Awaitable[None]]] = None
1413

1514
@property
1615
def patterns(self) -> List[Pattern]:
1716
return self._patterns
1817

19-
@property
20-
def to_do_next(self) -> Optional[Callable[[], Awaitable[None]]]:
21-
return self._to_do_next
22-
23-
@to_do_next.setter
24-
def to_do_next(self, func: Optional[Callable[[], Awaitable[None]]]) -> None:
25-
self._to_do_next = func
26-
if self._patterns:
27-
self._patterns[-1].end_node.todo_next = func
28-
2918
async def initialize(self) -> None:
30-
previous: Optional[Pattern] = None
3119
for pattern in self._patterns:
32-
await pattern.fetch_data() # Ensure data is loaded
33-
if previous is not None:
34-
# Chain the patterns together
35-
async def next_pattern_start(p: Pattern = pattern) -> None:
36-
await p.start_node.next()
37-
previous.end_node.todo_next = next_pattern_start
38-
previous = pattern
39-
40-
async def traverse(self) -> None:
41-
if self._patterns:
42-
await self._patterns[0].start_node.next()
20+
await pattern.fetch_data()
21+
22+
async def traverse(self) -> AsyncIterator[None]:
23+
if not self._patterns:
24+
return
25+
async for _ in self._chain_patterns(0):
26+
yield
27+
28+
async def _chain_patterns(self, index: int) -> AsyncIterator[None]:
29+
async for _ in self._patterns[index].start_node.next():
30+
if index + 1 < len(self._patterns):
31+
async for _ in self._chain_patterns(index + 1):
32+
yield
33+
else:
34+
yield

flowquery-py/src/graph/relationship.py

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5+
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union
66

77
from ..parsing.ast_node import ASTNode
88
from .hops import Hops
@@ -151,44 +151,49 @@ def set_end_node(self, node: 'Node') -> None:
151151
def _left_id_or_right_id(self) -> str:
152152
return "left_id" if self._direction == "left" else "right_id"
153153

154-
async def find(self, left_id: str, hop: int = 0) -> None:
154+
async def find(self, left_id: str, hop: int = 0) -> AsyncIterator[None]:
155155
"""Find relationships starting from the given node ID."""
156156
# Save original source node
157157
original = self._source
158158
if hop > 0:
159159
# For hops greater than 0, the source becomes the target of the previous hop
160160
self._source = self._target
161-
if hop == 0:
162-
if self._data:
163-
self._data.reset()
164-
165-
# Handle zero-hop case: when min is 0 on a variable-length relationship,
166-
# match source node as target (no traversal)
167-
if self._hops and self._hops.multi() and self._hops.min == 0 and self._target:
168-
# For zero-hop, target finds the same node as source (left_id)
169-
# No relationship match is pushed since no edge is traversed
170-
await self._target.find(left_id, hop)
171-
172-
while self._data and self._data.find(left_id, hop, self._direction):
173-
data = self._data.current(hop)
174-
if data is None:
175-
continue
176-
id = data[self._left_id_or_right_id()]
177-
if hop + 1 >= self._hops.min:
178-
self.set_value(self, left_id)
179-
if not self._matches_properties(hop):
161+
try:
162+
if hop == 0:
163+
if self._data:
164+
self._data.reset()
165+
166+
# Handle zero-hop case: when min is 0 on a variable-length relationship,
167+
# match source node as target (no traversal)
168+
if self._hops and self._hops.multi() and self._hops.min == 0 and self._target:
169+
# For zero-hop, target finds the same node as source (left_id)
170+
# No relationship match is pushed since no edge is traversed
171+
async for _ in self._target.find(left_id, hop):
172+
yield
173+
174+
while self._data and self._data.find(left_id, hop, self._direction):
175+
data = self._data.current(hop)
176+
if data is None:
180177
continue
181-
if self._target:
182-
await self._target.find(id, hop)
183-
if hop + 1 < self._hops.max:
184-
if self._matches.is_circular(id):
185-
self._matches.pop()
178+
id = data[self._left_id_or_right_id()]
179+
if hop + 1 >= self._hops.min:
180+
self.set_value(self, left_id)
181+
if not self._matches_properties(hop):
186182
continue
187-
await self.find(id, hop + 1)
188-
self._matches.pop()
189-
else:
190-
# Below minimum hops: traverse the edge without yielding a match
191-
await self.find(id, hop + 1)
192-
193-
# Restore original source node
194-
self._source = original
183+
if self._target:
184+
async for _ in self._target.find(id, hop):
185+
yield
186+
if hop + 1 < self._hops.max:
187+
if self._matches.is_circular(id):
188+
self._matches.pop()
189+
continue
190+
async for _ in self.find(id, hop + 1):
191+
yield
192+
self._matches.pop()
193+
else:
194+
# Below minimum hops: traverse the edge without yielding a match
195+
async for _ in self.find(id, hop + 1):
196+
yield
197+
finally:
198+
# Restore original source node
199+
self._source = original

flowquery-py/src/parsing/data_structures/lookup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,13 @@ def value(self) -> Any:
4444
# Try dict-like access first, then fall back to attribute access for objects
4545
try:
4646
return obj[key]
47-
except (TypeError, KeyError):
47+
except KeyError:
4848
# For objects with attributes (like dataclasses), use getattr
49+
if hasattr(obj, key):
50+
return getattr(obj, key)
51+
# Return None for missing keys, matching JavaScript obj[key] behavior
52+
return None
53+
except TypeError:
4954
if hasattr(obj, key):
5055
return getattr(obj, key)
5156
raise

flowquery-py/src/parsing/operations/match.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,11 @@ async def run(self) -> None:
4141
await self._patterns.initialize()
4242
matched = False
4343

44-
async def to_do_next() -> None:
45-
nonlocal matched
44+
async for _ in self._patterns.traverse():
4645
matched = True
4746
if self.next:
4847
await self.next.run()
4948

50-
self._patterns.to_do_next = to_do_next
51-
await self._patterns.traverse()
52-
5349
# For OPTIONAL MATCH: if nothing matched, continue with None values
5450
if not matched and self._optional:
5551
for pattern in self._patterns.patterns:

flowquery-py/tests/compute/test_runner.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,42 @@ async def test_return_with_expression_alias_which_starts_with_keyword(self):
11131113
assert len(results) == 1
11141114
assert results[0] == {"return1": 1, "notes": ["hello", "world"]}
11151115

1116+
@pytest.mark.asyncio
1117+
async def test_lookup_missing_property_returns_null(self):
1118+
"""Test that accessing a missing property returns null instead of raising KeyError."""
1119+
runner = Runner('RETURN {a: 1}.b as result')
1120+
await runner.run()
1121+
results = runner.results
1122+
assert len(results) == 1
1123+
assert results[0] == {"result": None}
1124+
1125+
@pytest.mark.asyncio
1126+
async def test_lookup_missing_property_bracket_notation_returns_null(self):
1127+
"""Test that bracket notation on a missing property returns null."""
1128+
runner = Runner('RETURN {a: 1}["b"] as result')
1129+
await runner.run()
1130+
results = runner.results
1131+
assert len(results) == 1
1132+
assert results[0] == {"result": None}
1133+
1134+
@pytest.mark.asyncio
1135+
async def test_lookup_missing_property_with_coalesce(self):
1136+
"""Test coalesce with a missing property lookup."""
1137+
runner = Runner('RETURN coalesce({a: 1}.b, "default") as result')
1138+
await runner.run()
1139+
results = runner.results
1140+
assert len(results) == 1
1141+
assert results[0] == {"result": "default"}
1142+
1143+
@pytest.mark.asyncio
1144+
async def test_lookup_on_null_returns_null(self):
1145+
"""Test that lookup on null returns null."""
1146+
runner = Runner('WITH null as obj RETURN obj.x as result')
1147+
await runner.run()
1148+
results = runner.results
1149+
assert len(results) == 1
1150+
assert results[0] == {"result": None}
1151+
11161152
@pytest.mark.asyncio
11171153
async def test_return_with_where_clause(self):
11181154
"""Test return with where clause."""

0 commit comments

Comments
 (0)