66import random
77import uuid
88from dataclasses import dataclass
9+ from typing import TYPE_CHECKING
910
1011import pytest
1112
12- from duron import durable , get_context , task
13+ from duron import durable
1314from duron .log .storage import MemoryLogStorage
1415
16+ if TYPE_CHECKING :
17+ from duron .context import Context
18+
1519
1620@pytest .mark .asyncio
1721async def test_task ():
@@ -21,8 +25,7 @@ async def u() -> str:
2125 return str (uuid .uuid4 ())
2226
2327 @durable ()
24- async def activity (i : str ) -> str :
25- ctx = get_context ()
28+ async def activity (ctx : Context , i : str ) -> str :
2629 x = await asyncio .gather (
2730 ctx .run (u ),
2831 ctx .run (u ),
@@ -44,18 +47,18 @@ async def activity(i: str) -> str:
4447 }
4548
4649 log = MemoryLogStorage ()
47- async with task ( activity , log ) as t :
50+ async with activity ( log ) as t :
4851 await t .start ("test" )
4952 a = await t .wait ()
5053 assert set (e ["id" ] for e in await log .entries ()) == IDS
5154
52- async with task ( activity , log ) as t :
55+ async with activity ( log ) as t :
5356 await t .start ("test" )
5457 b = await t .wait ()
5558 assert a == b
5659
5760 log2 = MemoryLogStorage ((await log .entries ())[:- 2 ])
58- async with task ( activity , log2 ) as t :
61+ async with activity ( log2 ) as t :
5962 await t .start ("test" )
6063 c = await t .wait ()
6164 assert a == c
@@ -65,8 +68,7 @@ async def activity(i: str) -> str:
6568@pytest .mark .asyncio
6669async def test_task_error ():
6770 @durable ()
68- async def activity ():
69- ctx = get_context ()
71+ async def activity (ctx : Context ):
7072 _ = await ctx .run (lambda : asyncio .sleep (0.1 ))
7173
7274 async def error ():
@@ -76,11 +78,11 @@ async def error():
7678
7779 log = MemoryLogStorage ()
7880 with pytest .raises (check = lambda v : "test error" in str (v )):
79- async with task ( activity , log ) as t :
81+ async with activity ( log ) as t :
8082 await t .start ()
8183 await t .wait ()
8284 with pytest .raises (check = lambda v : "test error" in str (v )):
83- async with task ( activity , log ) as t :
85+ async with activity ( log ) as t :
8486 await t .start ()
8587 await t .wait ()
8688
@@ -90,18 +92,17 @@ async def test_resume():
9092 sleep = 9999
9193
9294 @durable ()
93- async def activity (s : str ) -> str :
94- ctx = get_context ()
95+ async def activity (ctx : Context , s : str ) -> str :
9596 _ = await ctx .run (lambda : asyncio .sleep (sleep ))
9697 return s
9798
9899 log = MemoryLogStorage ()
99- async with task ( activity , log ) as t :
100+ async with activity ( log ) as t :
100101 await t .start ("hello" )
101102 with pytest .raises (asyncio .TimeoutError ):
102103 _ = await asyncio .wait_for (t .wait (), 0.1 )
103104
104- async with task ( activity , log ) as t :
105+ async with activity ( log ) as t :
105106 sleep = 0
106107 await t .resume ()
107108 x = await t .wait ()
@@ -127,13 +128,12 @@ def decode_json(self, encoded: object) -> object:
127128@pytest .mark .asyncio
128129async def test_serialize ():
129130 @durable (codec = PickleCodec ())
130- async def activity () -> CustomPoint :
131- ctx = get_context ()
131+ async def activity (ctx : Context ) -> CustomPoint :
132132 pt = await ctx .run (lambda : CustomPoint (x = 1 , y = 2 ))
133133 return CustomPoint (x = pt .x + 5 , y = pt .y + 10 )
134134
135135 log = MemoryLogStorage ()
136- async with task ( activity , log ) as t :
136+ async with activity ( log ) as t :
137137 await t .start ()
138138 a = await t .wait ()
139139 assert type (a ) is CustomPoint
0 commit comments