Skip to content

Commit fc030bf

Browse files
committed
feat: update interface
1 parent 53b6513 commit fc030bf

5 files changed

Lines changed: 48 additions & 42 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,4 @@ venvPath = ".venv"
6161
reportAny = false
6262
reportExplicitAny = false
6363
reportUnreachable = false
64+
reportImportCycles = false

src/duron/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from duron.context import get_context
1+
from duron.context import Context, get_context
22
from duron.fn import durable
3-
from duron.task import task
43

5-
__all__ = ["durable", "task", "get_context"]
4+
__all__ = ["durable", "get_context", "Context"]

src/duron/fn.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,26 @@
33
from dataclasses import dataclass
44
from typing import (
55
TYPE_CHECKING,
6+
Any,
7+
Concatenate,
68
Generic,
79
ParamSpec,
810
TypeVar,
11+
cast,
912
)
1013

1114
from duron.codec import DefaultCodec
15+
from duron.task import Task, TaskGuard
1216

1317
if TYPE_CHECKING:
14-
from collections.abc import Callable
18+
from collections.abc import Callable, Coroutine
1519

1620
from duron.codec import Codec
21+
from duron.context import Context
22+
from duron.log import LogStorage
1723

24+
_TOffset = TypeVar("_TOffset")
25+
_TLease = TypeVar("_TLease")
1826

1927
_T_co = TypeVar("_T_co", covariant=True)
2028
_P = ParamSpec("_P")
@@ -23,21 +31,29 @@
2331
@dataclass(slots=True)
2432
class DurableFn(Generic[_P, _T_co]):
2533
codec: Codec
26-
fn: Callable[_P, _T_co]
34+
fn: Callable[Concatenate[Context, _P], Coroutine[Any, Any, _T_co]]
2735

28-
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T_co:
29-
return self.fn(*args, **kwargs)
36+
def __call__(
37+
self,
38+
log: LogStorage[_TOffset, _TLease],
39+
) -> TaskGuard[_P, _T_co]:
40+
return TaskGuard(Task(self, cast("LogStorage[object, object]", log)))
3041

3142

3243
def durable(
3344
*, codec: Codec | None = None
34-
) -> Callable[[Callable[_P, _T_co]], DurableFn[_P, _T_co]]:
45+
) -> Callable[
46+
[Callable[Concatenate[Context, _P], Coroutine[Any, Any, _T_co]]],
47+
DurableFn[_P, _T_co],
48+
]:
3549
"""
3650
Mark a function as durable, meaning its execution can be recorded and
3751
replayed.
3852
"""
3953

40-
def decorate(fn: Callable[_P, _T_co]) -> DurableFn[_P, _T_co]:
54+
def decorate(
55+
fn: Callable[Concatenate[Context, _P], Coroutine[Any, Any, _T_co]],
56+
) -> DurableFn[_P, _T_co]:
4157
return DurableFn(codec=codec or DefaultCodec(), fn=fn)
4258

4359
return decorate

src/duron/task.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,13 @@
3838
)
3939
from duron.ops import Op
4040

41-
_TOffset = TypeVar("_TOffset")
42-
_TLease = TypeVar("_TLease")
4341

4442
_T = TypeVar("_T")
4543
_P = ParamSpec("_P")
4644

4745
_CURRENT_VERSION = 0
4846

4947

50-
def task(
51-
task_co: DurableFn[_P, Coroutine[Any, Any, _T]],
52-
log: LogStorage[_TOffset, _TLease],
53-
) -> TaskGuard[_P, _T]:
54-
return TaskGuard(Task(task_co, log))
55-
56-
5748
@final
5849
class TaskGuard(Generic[_P, _T]):
5950
def __init__(self, task: Task[_P, _T]) -> None:
@@ -74,9 +65,8 @@ async def __aexit__(
7465
class Task(Generic[_P, _T]):
7566
def __init__(
7667
self,
77-
task_fn: DurableFn[_P, Coroutine[Any, Any, _T]],
78-
log: LogStorage[_TOffset, _TLease],
79-
codec: Codec | None = None,
68+
task_fn: DurableFn[_P, _T],
69+
log: LogStorage[object, object],
8070
) -> None:
8171
self._task_fn = task_fn
8272
self._log = log
@@ -91,7 +81,7 @@ async def start(self, *args: _P.args, **kwargs: _P.kwargs) -> None:
9181
}
9282
self._run = _TaskRun(
9383
_task_prelude(self._task_fn, lambda: init),
94-
cast("LogStorage[object, object]", self._log),
84+
self._log,
9585
codec,
9686
)
9787
await self._run.resume()
@@ -103,7 +93,7 @@ def cb() -> TaskInitParams:
10393
task = _task_prelude(self._task_fn, cb)
10494
self._run = _TaskRun(
10595
task,
106-
cast("LogStorage[object, object]", self._log),
96+
self._log,
10797
self._task_fn.codec,
10898
)
10999
await self._run.resume()
@@ -121,7 +111,7 @@ class TaskInitParams(TypedDict):
121111

122112

123113
async def _task_prelude(
124-
task_fn: DurableFn[..., Coroutine[Any, Any, object]],
114+
task_fn: DurableFn[..., object],
125115
init: Callable[[], TaskInitParams],
126116
) -> object:
127117
ctx = get_context()
@@ -131,7 +121,7 @@ async def _task_prelude(
131121
codec = task_fn.codec
132122
args = (codec.decode_json(arg) for arg in init_params["args"])
133123
kwargs = {k: codec.decode_json(v) for k, v in init_params["kwargs"].items()}
134-
return await task_fn.fn(*args, **kwargs)
124+
return await task_fn.fn(get_context(), *args, **kwargs)
135125

136126

137127
@final

tests/test_task.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
import random
77
import uuid
88
from dataclasses import dataclass
9+
from typing import TYPE_CHECKING
910

1011
import pytest
1112

12-
from duron import durable, get_context, task
13+
from duron import durable
1314
from duron.log.storage import MemoryLogStorage
1415

16+
if TYPE_CHECKING:
17+
from duron.context import Context
18+
1519

1620
@pytest.mark.asyncio
1721
async def test_task():
@@ -21,8 +25,7 @@ async def u() -> str:
2125
return str(uuid.uuid4())
2226

2327
@durable()
24-
async def activity(i: str) -> str:
25-
ctx = get_context()
28+
async def activity(ctx: Context, i: str) -> str:
2629
x = await asyncio.gather(
2730
ctx.run(u),
2831
ctx.run(u),
@@ -44,18 +47,18 @@ async def activity(i: str) -> str:
4447
}
4548

4649
log = MemoryLogStorage()
47-
async with task(activity, log) as t:
50+
async with activity(log) as t:
4851
await t.start("test")
4952
a = await t.wait()
5053
assert set(e["id"] for e in await log.entries()) == IDS
5154

52-
async with task(activity, log) as t:
55+
async with activity(log) as t:
5356
await t.start("test")
5457
b = await t.wait()
5558
assert a == b
5659

5760
log2 = MemoryLogStorage((await log.entries())[:-2])
58-
async with task(activity, log2) as t:
61+
async with activity(log2) as t:
5962
await t.start("test")
6063
c = await t.wait()
6164
assert a == c
@@ -65,8 +68,7 @@ async def activity(i: str) -> str:
6568
@pytest.mark.asyncio
6669
async def test_task_error():
6770
@durable()
68-
async def activity():
69-
ctx = get_context()
71+
async def activity(ctx: Context):
7072
_ = await ctx.run(lambda: asyncio.sleep(0.1))
7173

7274
async def error():
@@ -76,11 +78,11 @@ async def error():
7678

7779
log = MemoryLogStorage()
7880
with pytest.raises(check=lambda v: "test error" in str(v)):
79-
async with task(activity, log) as t:
81+
async with activity(log) as t:
8082
await t.start()
8183
await t.wait()
8284
with pytest.raises(check=lambda v: "test error" in str(v)):
83-
async with task(activity, log) as t:
85+
async with activity(log) as t:
8486
await t.start()
8587
await t.wait()
8688

@@ -90,18 +92,17 @@ async def test_resume():
9092
sleep = 9999
9193

9294
@durable()
93-
async def activity(s: str) -> str:
94-
ctx = get_context()
95+
async def activity(ctx: Context, s: str) -> str:
9596
_ = await ctx.run(lambda: asyncio.sleep(sleep))
9697
return s
9798

9899
log = MemoryLogStorage()
99-
async with task(activity, log) as t:
100+
async with activity(log) as t:
100101
await t.start("hello")
101102
with pytest.raises(asyncio.TimeoutError):
102103
_ = await asyncio.wait_for(t.wait(), 0.1)
103104

104-
async with task(activity, log) as t:
105+
async with activity(log) as t:
105106
sleep = 0
106107
await t.resume()
107108
x = await t.wait()
@@ -127,13 +128,12 @@ def decode_json(self, encoded: object) -> object:
127128
@pytest.mark.asyncio
128129
async def test_serialize():
129130
@durable(codec=PickleCodec())
130-
async def activity() -> CustomPoint:
131-
ctx = get_context()
131+
async def activity(ctx: Context) -> CustomPoint:
132132
pt = await ctx.run(lambda: CustomPoint(x=1, y=2))
133133
return CustomPoint(x=pt.x + 5, y=pt.y + 10)
134134

135135
log = MemoryLogStorage()
136-
async with task(activity, log) as t:
136+
async with activity(log) as t:
137137
await t.start()
138138
a = await t.wait()
139139
assert type(a) is CustomPoint

0 commit comments

Comments
 (0)