Skip to content

Commit 6a94d20

Browse files
Fix tests
1 parent 2db977e commit 6a94d20

4 files changed

Lines changed: 45 additions & 11 deletions

File tree

aws_lambda_powertools/event_handler/openapi/merge.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,17 @@ def _file_has_resolver(file_path: Path, resolver_name: str) -> bool:
6767
return False
6868

6969
for node in ast.walk(tree):
70-
targets = []
70+
targets: list[ast.expr] = []
71+
value: ast.expr | None = None
7172
if isinstance(node, ast.Assign):
7273
targets = node.targets
74+
value = node.value
7375
elif isinstance(node, ast.AnnAssign):
7476
targets = [node.target]
77+
value = node.value
7578
for target in targets:
7679
if isinstance(target, ast.Name) and target.id == resolver_name:
77-
if _is_resolver_call(node.value):
80+
if value is not None and _is_resolver_call(value):
7881
return True
7982
return False
8083

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from pydantic import BaseModel
4+
5+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
6+
7+
app: APIGatewayRestResolver = APIGatewayRestResolver(enable_validation=True)
8+
9+
10+
class Product(BaseModel):
11+
id: int
12+
name: str
13+
price: float
14+
15+
16+
@app.get("/products")
17+
def get_products() -> list[Product]:
18+
return [
19+
Product(id=1, name="Widget", price=9.99),
20+
]
21+
22+
23+
def handler(event, context):
24+
return app.resolve(event, context)

tests/functional/event_handler/_pydantic/test_openapi_merge.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,19 @@ def test_openapi_merge_schema_is_cached():
367367

368368
# AND paths should not be duplicated
369369
assert len([p for p in schema1["paths"] if p == "/users"]) == 1
370+
371+
372+
def test_openapi_merge_discover_type_annotated_resolver():
373+
# GIVEN an OpenAPIMerge instance
374+
merge = OpenAPIMerge(title="Typed API", version="1.0.0")
375+
376+
# WHEN discovering a handler with a type-annotated resolver (app: Resolver = Resolver())
377+
merge.discover(
378+
path=MERGE_HANDLERS_PATH,
379+
pattern="**/typed_handler.py",
380+
resolver_name="app",
381+
)
382+
383+
# THEN it should find the resolver and include its routes in the schema
384+
schema = merge.get_openapi_schema()
385+
assert "/products" in schema["paths"]

tests/unit/event_handler/openapi/test_openapi_merge.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@ def test_file_has_resolver_found(tmp_path: Path):
5757
assert _file_has_resolver(handler_file, "app") is True
5858

5959

60-
def test_file_has_resolver_found_with_type_annotation(tmp_path: Path):
61-
handler_file = tmp_path / "handler.py"
62-
handler_file.write_text("""
63-
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
64-
app: APIGatewayRestResolver = APIGatewayRestResolver()
65-
""")
66-
assert _file_has_resolver(handler_file, "app") is True
67-
68-
6960
def test_is_excluded_with_directory_pattern():
7061
root = Path("/project")
7162
assert _is_excluded(Path("/project/tests/handler.py"), root, ["**/tests/**"]) is True

0 commit comments

Comments
 (0)