Skip to content

Commit 6799de8

Browse files
authored
Merge pull request #35 from microsoft/bug_fixes_and_improvements
Bug fixes and improvements
2 parents 2a757f8 + 49f8861 commit 6799de8

17 files changed

Lines changed: 674 additions & 21 deletions

File tree

flowquery-py/src/parsing/expressions/operator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,23 +163,23 @@ def value(self) -> int:
163163

164164
class Is(Operator):
165165
def __init__(self) -> None:
166-
super().__init__(-1, True)
166+
super().__init__(0, True)
167167

168168
def value(self) -> int:
169169
return 1 if self.lhs.value() == self.rhs.value() else 0
170170

171171

172172
class IsNot(Operator):
173173
def __init__(self) -> None:
174-
super().__init__(-1, True)
174+
super().__init__(0, True)
175175

176176
def value(self) -> int:
177177
return 1 if self.lhs.value() != self.rhs.value() else 0
178178

179179

180180
class In(Operator):
181181
def __init__(self) -> None:
182-
super().__init__(-1, True)
182+
super().__init__(0, True)
183183

184184
def value(self) -> int:
185185
lst = self.rhs.value()
@@ -190,7 +190,7 @@ def value(self) -> int:
190190

191191
class NotIn(Operator):
192192
def __init__(self) -> None:
193-
super().__init__(-1, True)
193+
super().__init__(0, True)
194194

195195
def value(self) -> int:
196196
lst = self.rhs.value()

flowquery-py/src/parsing/operations/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .load import Load
1111
from .match import Match
1212
from .operation import Operation
13+
from .order_by import OrderBy, SortField
1314
from .projection import Projection
1415
from .return_op import Return
1516
from .union import Union
@@ -36,4 +37,6 @@
3637
"CreateRelationship",
3738
"Union",
3839
"UnionAll",
40+
"OrderBy",
41+
"SortField",
3942
]

flowquery-py/src/parsing/operations/aggregated_return.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,7 @@ async def run(self) -> None:
1919
def results(self) -> List[Dict[str, Any]]:
2020
if self._where is not None:
2121
self._group_by.where = self._where
22-
return list(self._group_by.generate_results())
22+
results = list(self._group_by.generate_results())
23+
if self._order_by is not None:
24+
results = self._order_by.sort(results)
25+
return results

flowquery-py/src/parsing/operations/limit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ def __init__(self, limit: int):
1515
def is_limit_reached(self) -> bool:
1616
return self._count >= self._limit
1717

18+
@property
19+
def limit_value(self) -> int:
20+
return self._limit
21+
1822
def increment(self) -> None:
1923
self._count += 1
2024

flowquery-py/src/parsing/operations/order_by.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Represents an ORDER BY operation that sorts results."""
2+
3+
from typing import Any, Dict, List
4+
5+
from .operation import Operation
6+
7+
8+
class SortField:
9+
"""A single sort specification: field name and direction."""
10+
11+
def __init__(self, field: str, direction: str = "asc"):
12+
self.field = field
13+
self.direction = direction
14+
15+
16+
class OrderBy(Operation):
17+
"""Represents an ORDER BY operation that sorts results.
18+
19+
Can be attached to a RETURN operation (sorting its results),
20+
or used as a standalone accumulating operation after a non-aggregate WITH.
21+
22+
Example:
23+
RETURN x ORDER BY x DESC
24+
"""
25+
26+
def __init__(self, fields: List[SortField]):
27+
super().__init__()
28+
self._fields = fields
29+
self._results: List[Dict[str, Any]] = []
30+
31+
@property
32+
def fields(self) -> List[SortField]:
33+
return self._fields
34+
35+
def sort(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
36+
"""Sorts an array of records according to the sort fields."""
37+
import functools
38+
39+
def compare(a: Dict[str, Any], b: Dict[str, Any]) -> int:
40+
for sf in self._fields:
41+
a_val = a.get(sf.field)
42+
b_val = b.get(sf.field)
43+
cmp = 0
44+
if a_val is None and b_val is None:
45+
cmp = 0
46+
elif a_val is None:
47+
cmp = -1
48+
elif b_val is None:
49+
cmp = 1
50+
elif a_val < b_val:
51+
cmp = -1
52+
elif a_val > b_val:
53+
cmp = 1
54+
if cmp != 0:
55+
return -cmp if sf.direction == "desc" else cmp
56+
return 0
57+
58+
return sorted(records, key=functools.cmp_to_key(compare))
59+
60+
async def run(self) -> None:
61+
"""When used as a standalone operation, passes through to next."""
62+
if self.next:
63+
await self.next.run()
64+
65+
async def initialize(self) -> None:
66+
self._results = []
67+
if self.next:
68+
await self.next.initialize()
69+
70+
@property
71+
def results(self) -> List[Dict[str, Any]]:
72+
return self._results

flowquery-py/src/parsing/operations/return_op.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from ..ast_node import ASTNode
77
from .limit import Limit
8+
from .order_by import OrderBy
89
from .projection import Projection
910

1011
if TYPE_CHECKING:
@@ -26,6 +27,7 @@ def __init__(self, expressions: List[ASTNode]) -> None:
2627
self._where: Optional['Where'] = None
2728
self._results: List[Dict[str, Any]] = []
2829
self._limit: Optional[Limit] = None
30+
self._order_by: Optional[OrderBy] = None
2931

3032
@property
3133
def where(self) -> Any:
@@ -45,10 +47,20 @@ def limit(self) -> Optional[Limit]:
4547
def limit(self, limit: Limit) -> None:
4648
self._limit = limit
4749

50+
@property
51+
def order_by(self) -> Optional[OrderBy]:
52+
return self._order_by
53+
54+
@order_by.setter
55+
def order_by(self, order_by: OrderBy) -> None:
56+
self._order_by = order_by
57+
4858
async def run(self) -> None:
4959
if not self.where:
5060
return
51-
if self._limit is not None and self._limit.is_limit_reached:
61+
# When ORDER BY is present, skip limit during accumulation;
62+
# limit will be applied after sorting in results property
63+
if self._order_by is None and self._limit is not None and self._limit.is_limit_reached:
5264
return
5365
record: Dict[str, Any] = {}
5466
for expression, alias in self.expressions():
@@ -57,12 +69,17 @@ async def run(self) -> None:
5769
value = copy.deepcopy(raw) if isinstance(raw, (dict, list)) else raw
5870
record[alias] = value
5971
self._results.append(record)
60-
if self._limit is not None:
72+
if self._order_by is None and self._limit is not None:
6173
self._limit.increment()
6274

6375
async def initialize(self) -> None:
6476
self._results = []
6577

6678
@property
6779
def results(self) -> List[Dict[str, Any]]:
68-
return self._results
80+
result = self._results
81+
if self._order_by is not None:
82+
result = self._order_by.sort(result)
83+
if self._order_by is not None and self._limit is not None:
84+
result = result[:self._limit.limit_value]
85+
return result

flowquery-py/src/parsing/parser.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from .operations.load import Load
6262
from .operations.match import Match
6363
from .operations.operation import Operation
64+
from .operations.order_by import OrderBy, SortField
6465
from .operations.return_op import Return
6566
from .operations.union import Union
6667
from .operations.union_all import UnionAll
@@ -146,6 +147,14 @@ def _parse_tokenized(self, is_sub_query: bool = False) -> ASTNode:
146147
operation.add_sibling(where)
147148
operation = where
148149

150+
order_by = self._parse_order_by()
151+
if order_by is not None:
152+
if isinstance(operation, Return):
153+
operation.order_by = order_by
154+
else:
155+
operation.add_sibling(order_by)
156+
operation = order_by
157+
149158
limit = self._parse_limit()
150159
if limit is not None:
151160
if isinstance(operation, Return):
@@ -694,6 +703,41 @@ def _parse_limit(self) -> Optional[Limit]:
694703
self.set_next_token()
695704
return limit
696705

706+
def _parse_order_by(self) -> Optional[OrderBy]:
707+
self._skip_whitespace_and_comments()
708+
if not self.token.is_order():
709+
return None
710+
self._expect_previous_token_to_be_whitespace_or_comment()
711+
self.set_next_token()
712+
self._expect_and_skip_whitespace_and_comments()
713+
if not self.token.is_by():
714+
raise ValueError("Expected BY after ORDER")
715+
self.set_next_token()
716+
self._expect_and_skip_whitespace_and_comments()
717+
fields: list[SortField] = []
718+
while True:
719+
if not self.token.is_identifier_or_keyword():
720+
raise ValueError("Expected field name in ORDER BY")
721+
field = self.token.value
722+
self.set_next_token()
723+
self._skip_whitespace_and_comments()
724+
direction = "asc"
725+
if self.token.is_asc():
726+
direction = "asc"
727+
self.set_next_token()
728+
self._skip_whitespace_and_comments()
729+
elif self.token.is_desc():
730+
direction = "desc"
731+
self.set_next_token()
732+
self._skip_whitespace_and_comments()
733+
fields.append(SortField(field, direction))
734+
if self.token.is_comma():
735+
self.set_next_token()
736+
self._skip_whitespace_and_comments()
737+
else:
738+
break
739+
return OrderBy(fields)
740+
697741
def _parse_expressions(
698742
self, alias_option: AliasOption = AliasOption.NOT_ALLOWED
699743
) -> Iterator[Expression]:

flowquery-py/src/tokenization/token.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,34 @@ def ALL() -> Token:
630630
def is_all(self) -> bool:
631631
return self._type == TokenType.KEYWORD and self._value == Keyword.ALL.value
632632

633+
@staticmethod
634+
def ORDER() -> Token:
635+
return Token(TokenType.KEYWORD, Keyword.ORDER.value)
636+
637+
def is_order(self) -> bool:
638+
return self._type == TokenType.KEYWORD and self._value == Keyword.ORDER.value
639+
640+
@staticmethod
641+
def BY() -> Token:
642+
return Token(TokenType.KEYWORD, Keyword.BY.value)
643+
644+
def is_by(self) -> bool:
645+
return self._type == TokenType.KEYWORD and self._value == Keyword.BY.value
646+
647+
@staticmethod
648+
def ASC() -> Token:
649+
return Token(TokenType.KEYWORD, Keyword.ASC.value)
650+
651+
def is_asc(self) -> bool:
652+
return self._type == TokenType.KEYWORD and self._value == Keyword.ASC.value
653+
654+
@staticmethod
655+
def DESC() -> Token:
656+
return Token(TokenType.KEYWORD, Keyword.DESC.value)
657+
658+
def is_desc(self) -> bool:
659+
return self._type == TokenType.KEYWORD and self._value == Keyword.DESC.value
660+
633661
# End of file token
634662

635663
@staticmethod

0 commit comments

Comments
 (0)