Skip to content

Commit 6ea9218

Browse files
authored
Merge pull request #27 from microsoft/fixups
Fixups
2 parents a1235cd + 26b0218 commit 6ea9218

14 files changed

Lines changed: 236 additions & 85 deletions

File tree

flowquery-py/src/graph/node_reference.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,28 @@
22

33
from typing import Any, Optional
44

5+
from ..parsing.ast_node import ASTNode
56
from .node import Node
67

78

89
class NodeReference(Node):
910
"""Represents a reference to an existing node variable."""
1011

11-
def __init__(self, base: Node, reference: Node) -> None:
12+
def __init__(self, base: Node, reference: ASTNode) -> None:
1213
super().__init__(base.identifier, base.label)
13-
self._reference: Node = reference
14+
self._reference: ASTNode = reference
1415
# Copy properties from base
1516
self._properties = base._properties
1617
self._outgoing = base.outgoing
1718
self._incoming = base.incoming
1819

1920
@property
20-
def reference(self) -> Node:
21+
def reference(self) -> ASTNode:
2122
return self._reference
2223

2324
# Keep referred as alias for backward compatibility
2425
@property
25-
def referred(self) -> Node:
26+
def referred(self) -> ASTNode:
2627
return self._reference
2728

2829
def value(self) -> Optional[Any]:

flowquery-py/src/graph/relationship.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def set_data(self, data: Optional['RelationshipData']) -> None:
123123
def get_data(self) -> Optional['RelationshipData']:
124124
return self._data
125125

126-
def set_value(self, relationship: 'Relationship') -> None:
126+
def set_value(self, relationship: 'Relationship', traversal_id: str = "") -> None:
127127
"""Set value by pushing match to collector."""
128-
self._matches.push(relationship)
128+
self._matches.push(relationship, traversal_id)
129129
self._value = self._matches.value()
130130

131131
def value(self) -> Optional[Union[RelationshipMatchRecord, List[RelationshipMatchRecord]]]:
@@ -139,11 +139,13 @@ def set_end_node(self, node: 'Node') -> None:
139139
"""Set the end node for the current match."""
140140
self._matches.end_node = node
141141

142+
def _left_id_or_right_id(self) -> str:
143+
return "left_id" if self._direction == "left" else "right_id"
144+
142145
async def find(self, left_id: str, hop: int = 0) -> None:
143146
"""Find relationships starting from the given node ID."""
144147
# Save original source node
145148
original = self._source
146-
is_left = self._direction == "left"
147149
if hop > 0:
148150
# For hops greater than 0, the source becomes the target of the previous hop
149151
self._source = self._target
@@ -158,30 +160,26 @@ async def find(self, left_id: str, hop: int = 0) -> None:
158160
# No relationship match is pushed since no edge is traversed
159161
await self._target.find(left_id, hop)
160162

161-
def find_match(id_: str, h: int) -> bool:
162-
if self._data is None:
163-
return False
164-
if is_left:
165-
return self._data.find_reverse(id_, h)
166-
return self._data.find(id_, h)
167-
follow_id = 'left_id' if is_left else 'right_id'
168-
while self._data and find_match(left_id, hop):
163+
while self._data and self._data.find(left_id, hop, self._direction):
169164
data = self._data.current(hop)
170-
if data and self._hops and hop + 1 >= self._hops.min:
171-
self.set_value(self)
165+
if data is None:
166+
continue
167+
id = data[self._left_id_or_right_id()]
168+
if hop + 1 >= self._hops.min:
169+
self.set_value(self, left_id)
172170
if not self._matches_properties(hop):
173171
continue
174-
if self._target and follow_id in data:
175-
await self._target.find(data[follow_id], hop)
176-
if self._matches.is_circular():
177-
raise ValueError("Circular relationship detected")
178-
if self._hops and hop + 1 < self._hops.max:
179-
await self.find(data[follow_id], hop + 1)
172+
if self._target:
173+
await self._target.find(id, hop)
174+
if hop + 1 < self._hops.max:
175+
if self._matches.is_circular(id):
176+
self._matches.pop()
177+
continue
178+
await self.find(id, hop + 1)
180179
self._matches.pop()
181-
elif data and self._hops:
180+
else:
182181
# Below minimum hops: traverse the edge without yielding a match
183-
if follow_id in data:
184-
await self.find(data[follow_id], hop + 1)
182+
await self.find(id, hop + 1)
185183

186184
# Restore original source node
187185
self._source = original

flowquery-py/src/graph/relationship_data.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,10 @@ def __init__(self, records: Optional[List[Dict[str, Any]]] = None):
1919
self._build_index("left_id")
2020
self._build_index("right_id")
2121

22-
def find(self, left_id: str, hop: int = 0) -> bool:
23-
"""Find a relationship by start node ID."""
24-
return self._find(left_id, hop, "left_id")
25-
26-
def find_reverse(self, right_id: str, hop: int = 0) -> bool:
27-
"""Find a relationship by end node ID (reverse direction)."""
28-
return self._find(right_id, hop, "right_id")
22+
def find(self, id: str, hop: int = 0, direction: str = "right") -> bool:
    """Delegate to ``_find`` using the index key implied by ``direction``.

    Args:
        id: Node ID to match against the chosen key.
        hop: Hop index forwarded unchanged to ``_find``.
        direction: ``"left"`` matches ``id`` against ``right_id``;
            any other value matches it against ``left_id``.

    Returns:
        Whatever ``_find`` returns for the chosen key.
    """
    if direction == "left":
        # Traversing leftwards: the known node sits on the right side.
        return self._find(id, hop, "right_id")
    return self._find(id, hop, "left_id")
2926

3027
def properties(self) -> Optional[Dict[str, Any]]:
3128
"""Get properties of current relationship, excluding left_id and right_id."""

flowquery-py/src/graph/relationship_match_collector.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self) -> None:
2424
self._matches: List[RelationshipMatchRecord] = []
2525
self._node_ids: List[str] = []
2626

27-
def push(self, relationship: 'Relationship') -> RelationshipMatchRecord:
27+
def push(self, relationship: 'Relationship', traversal_id: str = "") -> RelationshipMatchRecord:
2828
"""Push a new match onto the collector."""
2929
start_node_value = relationship.source.value() if relationship.source else None
3030
rel_data = relationship.get_data()
@@ -36,8 +36,7 @@ def push(self, relationship: 'Relationship') -> RelationshipMatchRecord:
3636
"properties": rel_props,
3737
}
3838
self._matches.append(match)
39-
if isinstance(start_node_value, dict):
40-
self._node_ids.append(start_node_value.get("id", ""))
39+
self._node_ids.append(traversal_id)
4140
return match
4241

4342
@property
@@ -76,7 +75,6 @@ def matches(self) -> List[RelationshipMatchRecord]:
7675
"""Get all matches."""
7776
return self._matches
7877

79-
def is_circular(self) -> bool:
80-
"""Check if the collected relationships form a circular pattern."""
81-
seen = set(self._node_ids)
82-
return len(seen) < len(self._node_ids)
78+
def is_circular(self, next_id: str = "") -> bool:
    """Report whether ``next_id`` is already among the recorded traversal ids.

    Args:
        next_id: Node id the traversal is about to move to.

    Returns:
        ``True`` if ``next_id`` is present in ``self._node_ids``,
        otherwise ``False``.
    """
    recorded_ids = self._node_ids
    if next_id in recorded_ids:
        return True
    return False

flowquery-py/src/parsing/operations/group_by.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""GroupBy implementation for aggregate operations."""
22

3+
import json
34
from typing import Any, Dict, Generator, List, Optional
45

56
from ..ast_node import ASTNode
@@ -8,6 +9,15 @@
89
from .projection import Projection
910

1011

12+
def _make_hashable(value: Any) -> Any:
13+
"""Convert a value to a hashable form for use as a dict key."""
14+
if isinstance(value, dict):
15+
return json.dumps(value, sort_keys=True, default=str)
16+
if isinstance(value, list):
17+
return json.dumps(value, sort_keys=True, default=str)
18+
return value
19+
20+
1121
class GroupByNode:
1222
"""Represents a node in the group-by tree."""
1323

@@ -60,10 +70,11 @@ def _map(self) -> None:
6070
node = self._current
6171
for mapper in self.mappers:
6272
value = mapper.value()
63-
child = node.children.get(value)
73+
key = _make_hashable(value)
74+
child = node.children.get(key)
6475
if child is None:
6576
child = GroupByNode(value)
66-
node.children[value] = child
77+
node.children[key] = child
6778
node = child
6879
self._current = node
6980

flowquery-py/src/parsing/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def _parse_node(self) -> Optional[Node]:
499499
inner = ref_child.referred
500500
if isinstance(inner, Node):
501501
reference = inner
502-
if reference is None or not isinstance(reference, Node):
502+
if reference is None or (not isinstance(reference, Node) and not isinstance(reference, Unwind)):
503503
raise ValueError(f"Undefined node reference: {identifier}")
504504
node = NodeReference(node, reference)
505505
elif identifier is not None:

flowquery-py/tests/compute/test_runner.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,8 +1279,8 @@ async def test_circular_graph_pattern(self):
12791279
assert len(results) == 2
12801280

12811281
@pytest.mark.asyncio
1282-
async def test_circular_graph_pattern_with_variable_length_should_throw_error(self):
1283-
"""Test circular graph pattern with variable length should throw error."""
1282+
async def test_circular_graph_pattern_with_variable_length_should_not_revisit_nodes(self):
1283+
"""Test circular graph pattern with variable length should not revisit nodes."""
12841284
await Runner(
12851285
"""
12861286
CREATE VIRTUAL (:CircularVarPerson) AS {
@@ -1309,8 +1309,10 @@ async def test_circular_graph_pattern_with_variable_length_should_throw_error(se
13091309
RETURN p AS pattern
13101310
"""
13111311
)
1312-
with pytest.raises(ValueError, match="Circular relationship detected"):
1313-
await match.run()
1312+
await match.run()
1313+
results = match.results
1314+
# Circular graph 1↔2: cycles are skipped, only acyclic paths are returned
1315+
assert len(results) == 6
13141316

13151317
@pytest.mark.asyncio
13161318
async def test_multi_hop_match_with_min_hops_constraint_1(self):
@@ -2212,4 +2214,43 @@ async def test_where_with_contains_combined_with_and(self):
22122214
await runner.run()
22132215
results = runner.results
22142216
assert len(results) == 1
2215-
assert results[0]["fruit"] == "pineapple"
2217+
assert results[0]["fruit"] == "pineapple"
2218+
2219+
@pytest.mark.asyncio
2220+
async def test_collected_nodes_and_re_matching(self):
2221+
"""Test that collected nodes can be unwound and used as node references in subsequent MATCH."""
2222+
await Runner("""
2223+
CREATE VIRTUAL (:Person) AS {
2224+
unwind [
2225+
{id: 1, name: 'Person 1'},
2226+
{id: 2, name: 'Person 2'},
2227+
{id: 3, name: 'Person 3'},
2228+
{id: 4, name: 'Person 4'}
2229+
] as record
2230+
RETURN record.id as id, record.name as name
2231+
}
2232+
""").run()
2233+
await Runner("""
2234+
CREATE VIRTUAL (:Person)-[:KNOWS]-(:Person) AS {
2235+
unwind [
2236+
{left_id: 1, right_id: 2},
2237+
{left_id: 2, right_id: 3},
2238+
{left_id: 3, right_id: 4}
2239+
] as record
2240+
RETURN record.left_id as left_id, record.right_id as right_id
2241+
}
2242+
""").run()
2243+
runner = Runner("""
2244+
MATCH (a:Person)-[:KNOWS*0..3]->(b:Person)
2245+
WITH collect(a) AS persons, b
2246+
UNWIND persons AS p
2247+
match (p)-[:KNOWS]->(:Person)
2248+
return p.name AS name
2249+
""")
2250+
await runner.run()
2251+
results = runner.results
2252+
assert len(results) == 9
2253+
names = [r["name"] for r in results]
2254+
assert "Person 1" in names
2255+
assert "Person 2" in names
2256+
assert "Person 3" in names

src/graph/node_reference.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import ASTNode from "../parsing/ast_node";
12
import Node from "./node";
23

34
class NodeReference extends Node {
4-
private _reference: Node | null = null;
5-
constructor(base: Node, reference: Node) {
5+
private _reference: ASTNode | null = null;
6+
constructor(base: Node, reference: ASTNode) {
67
super();
78
this._identifier = base.identifier;
89
this._label = base.label;
@@ -11,7 +12,7 @@ class NodeReference extends Node {
1112
this._incoming = base.incoming;
1213
this._reference = reference;
1314
}
14-
public get reference(): Node | null {
15+
public get reference(): ASTNode | null {
1516
return this._reference;
1617
}
1718
public async next(): Promise<void> {

src/graph/relationship.ts

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ class Relationship extends ASTNode {
7171
public get hops(): Hops | null {
7272
return this._hops;
7373
}
74-
public setValue(relationship: Relationship): void {
75-
const match: RelationshipMatchRecord = this._matches.push(relationship);
74+
public setValue(relationship: Relationship, traversalId: string = ""): void {
75+
const match: RelationshipMatchRecord = this._matches.push(relationship, traversalId);
7676
this._value = this._matches.value();
7777
}
7878
public set source(node: Node | null) {
@@ -108,10 +108,12 @@ class Relationship extends ASTNode {
108108
public setEndNode(node: Node): void {
109109
this._matches.endNode = node;
110110
}
111+
public _left_id_or_right_id(): string {
    // Name of the record field that `find` reads to get the next node id:
    // "left_id" for a left-directed relationship, "right_id" otherwise.
    if (this._direction === "left") {
        return "left_id";
    }
    return "right_id";
}
111114
public async find(left_id: string, hop: number = 0): Promise<void> {
112115
// Save original source node
113116
const original = this._source;
114-
const isLeft = this._direction === "left";
115117
if (hop > 0) {
116118
// For hops greater than 0, the source becomes the target of the previous hop
117119
this._source = this._target;
@@ -127,28 +129,26 @@ class Relationship extends ASTNode {
127129
await this._target.find(left_id, hop);
128130
}
129131
}
130-
const findMatch = isLeft
131-
? (id: string, h: number) => this._data!.findReverse(id, h)
132-
: (id: string, h: number) => this._data!.find(id, h);
133-
const followId = isLeft ? "left_id" : "right_id";
134-
while (findMatch(left_id, hop)) {
132+
while (this._data!.find(left_id, hop, this._direction)) {
135133
const data: RelationshipRecord = this._data?.current(hop) as RelationshipRecord;
134+
const id = data[this._left_id_or_right_id()];
136135
if (hop + 1 >= this.hops!.min) {
137-
this.setValue(this);
136+
this.setValue(this, left_id);
138137
if (!this._matchesProperties(hop)) {
139138
continue;
140139
}
141-
await this._target?.find(data[followId], hop);
142-
if (this._matches.isCircular()) {
143-
throw new Error("Circular relationship detected");
144-
}
140+
await this._target?.find(id, hop);
145141
if (hop + 1 < this.hops!.max) {
146-
await this.find(data[followId], hop + 1);
142+
if (this._matches.isCircular(id)) {
143+
this._matches.pop();
144+
continue;
145+
}
146+
await this.find(id, hop + 1);
147147
}
148148
this._matches.pop();
149149
} else {
150150
// Below minimum hops: traverse the edge without yielding a match
151-
await this.find(data[followId], hop + 1);
151+
await this.find(id, hop + 1);
152152
}
153153
}
154154
// Restore original source node

src/graph/relationship_data.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ class RelationshipData extends Data {
88
super._buildIndex("left_id");
99
super._buildIndex("right_id");
1010
}
11-
public find(left_id: string, hop: number = 0): boolean {
12-
return super._find(left_id, hop, "left_id");
13-
}
14-
public findReverse(right_id: string, hop: number = 0): boolean {
15-
return super._find(right_id, hop, "right_id");
11+
public find(id: string, hop: number = 0, direction: "left" | "right" = "right"): boolean {
    // A "left" traversal matches the known node against right_id;
    // a "right" traversal matches it against left_id.
    const key = direction === "left" ? "right_id" : "left_id";
    return super._find(id, hop, key);
}
1714
/*
1815
** Get the properties of the current relationship record

0 commit comments

Comments
 (0)