Skip to content

Commit 4fefc46

Browse files
committed
feat: use type annotations for stateful
1 parent 3b50b13 commit 4fefc46

8 files changed

Lines changed: 101 additions & 163 deletions

File tree

examples/hello_world.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ async def generate_lucky_number() -> int:
3030
return random.randint(1, 100)
3131

3232

33-
@duron.effect(stateful=True, initial=lambda: 0, reducer=int.__add__)
33+
@duron.effect
3434
async def count_up(count: int, target: int) -> AsyncGenerator[int, int]:
3535
await asyncio.sleep(0.5)
3636
while count < target:
37-
count = yield 10
37+
count = yield (count + 10)
3838
logger.info("⚡ Current count: %s", count)
3939
await asyncio.sleep(0.05)
4040

@@ -44,7 +44,7 @@ async def greeting_flow(ctx: duron.Context, name: str) -> str:
4444
message, lucky_number = await asyncio.gather(
4545
ctx.run(work, name), ctx.run(generate_lucky_number)
4646
)
47-
_ = await ctx.run(count_up, lucky_number)
47+
_ = await ctx.run(count_up, 0, lucky_number)
4848
return f"{message} Your lucky number is {lucky_number}."
4949

5050

src/duron/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
from duron._core.stream import StreamClosed as StreamClosed
99
from duron._core.stream import StreamWriter as StreamWriter
1010
from duron._decorator.durable import durable as durable
11+
from duron._decorator.effect import Reducer as Reducer
1112
from duron._decorator.effect import effect as effect
1213
from duron.typing import Provided as Provided

src/duron/_core/context.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import functools
66
import inspect
77
from random import Random
8-
from typing import TYPE_CHECKING, cast
9-
from typing_extensions import Any, ParamSpec, TypeVar, final, overload
8+
from typing import TYPE_CHECKING, Concatenate, cast, get_args, get_origin
9+
from typing_extensions import Any, AsyncGenerator, ParamSpec, TypeVar, final, overload
1010

1111
from duron._core.ops import (
1212
Barrier,
@@ -18,8 +18,9 @@
1818
)
1919
from duron._core.signal import create_signal
2020
from duron._core.stream import create_stream, run_stateful
21-
from duron._decorator.effect import EffectFn, StatefulFn
21+
from duron._decorator.effect import Reducer
2222
from duron.typing import inspect_function
23+
from duron.typing._hint import UnspecifiedType
2324

2425
if TYPE_CHECKING:
2526
from collections.abc import Awaitable, Callable, Coroutine
@@ -46,28 +47,25 @@ def __init__(self, loop: EventLoop, seed: str) -> None:
4647
@overload
4748
async def run(
4849
self,
49-
fn: Callable[_P, Coroutine[Any, Any, _T]] | EffectFn[_P, _T],
50+
fn: Callable[_P, Coroutine[Any, Any, _T]],
5051
/,
5152
*args: _P.args,
5253
**kwargs: _P.kwargs,
5354
) -> _T: ...
5455
@overload
5556
async def run(
5657
self,
57-
fn: Callable[_P, _T] | StatefulFn[_P, _T, Any],
58+
fn: Callable[Concatenate[_T, _P], AsyncGenerator[_S, _T]],
5859
/,
60+
state: _S,
5961
*args: _P.args,
6062
**kwargs: _P.kwargs,
6163
) -> _T: ...
64+
@overload
6265
async def run(
63-
self,
64-
fn: Callable[_P, Coroutine[Any, Any, _T] | _T]
65-
| EffectFn[_P, _T]
66-
| StatefulFn[_P, _T, Any],
67-
/,
68-
*args: _P.args,
69-
**kwargs: _P.kwargs,
70-
) -> _T:
66+
self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs
67+
) -> _T: ...
68+
async def run(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Any:
7169
"""
7270
Run a function within the context.
7371
@@ -81,18 +79,13 @@ async def run(
8179
msg = "Context time can only be used in the context loop"
8280
raise RuntimeError(msg)
8381

84-
if isinstance(fn, StatefulFn):
85-
async with self.stream(
86-
cast("StatefulFn[_P, _T, Any]", fn), *args, **kwargs
87-
) as (stream, result):
82+
if inspect.isasyncgenfunction(fn):
83+
async with self.stream(fn, *args, **kwargs) as (stream, result):
8884
await stream.discard()
8985
return await result
9086

9187
callable_: Callable[[], Coroutine[Any, Any, object]]
92-
if isinstance(fn, EffectFn):
93-
hint = fn.type_hint
94-
callable_ = functools.partial(fn.fn, *args, **kwargs)
95-
elif inspect.iscoroutinefunction(fn):
88+
if inspect.iscoroutinefunction(fn):
9689
hint = inspect_function(fn)
9790
callable_ = functools.partial(fn, *args, **kwargs)
9891
else:
@@ -103,7 +96,7 @@ async def wrapper() -> object: # noqa: RUF029
10396
hint = inspect_function(fn)
10497
callable_ = wrapper
10598

106-
op: asyncio.Future[_T] = create_op(
99+
op: asyncio.Future[object] = create_op(
107100
self._loop,
108101
FnCall(
109102
callable=callable_,
@@ -115,7 +108,12 @@ async def wrapper() -> object: # noqa: RUF029
115108
return await op
116109

117110
def stream(
118-
self, fn: StatefulFn[_P, _T, _S], /, *args: _P.args, **kwargs: _P.kwargs
111+
self,
112+
fn: Callable[Concatenate[_T, _P], AsyncGenerator[_S, _T]],
113+
/,
114+
initial: _T,
115+
*args: _P.args,
116+
**kwargs: _P.kwargs,
119117
) -> AbstractAsyncContextManager[tuple[Stream[_S], Awaitable[_T]]]:
120118
"""Stream stateful function partial results.
121119
@@ -133,8 +131,17 @@ def stream(
133131
if asyncio.get_running_loop() is not self._loop:
134132
msg = "Context time can only be used in the context loop"
135133
raise RuntimeError(msg)
134+
135+
type_hint = inspect_function(fn)
136+
action_type: TypeHint[_S] = UnspecifiedType
137+
if get_origin(ret := type_hint.return_type) is AsyncGenerator:
138+
action_type, _ = get_args(ret)
139+
140+
state_name = type_hint.parameters[0]
141+
annotations = type_hint.parameter_annotations.get(state_name, ())
142+
reducer = _find_reducer(tuple(annotations))
136143
return run_stateful(
137-
self._loop, fn.action_type, fn.initial(), fn.reducer, fn.fn, *args, **kwargs
144+
self._loop, action_type, reducer, fn, initial, *args, **kwargs
138145
)
139146

140147
async def create_stream(
@@ -288,3 +295,20 @@ async def complete_future(
288295
self._loop,
289296
FutureComplete(future_id=future_id, value=result, exception=exception),
290297
)
298+
299+
300+
def _find_reducer(annotations: tuple[Any, ...]) -> Callable[[_S, _T], _S]:
301+
for annotation in annotations:
302+
if not isinstance(annotation, Reducer):
303+
continue
304+
hint = inspect_function(annotation.reducer)
305+
if len(hint.parameters) != 2:
306+
msg = "Reducer function must have exactly two parameters"
307+
raise TypeError(msg)
308+
return cast("Callable[[_S, _T], _S]", annotation.reducer)
309+
310+
return cast("Callable[[_S, _T], _S]", _default_reducer)
311+
312+
313+
def _default_reducer(_old: object, new: object) -> object:
314+
return new

src/duron/_core/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from duron.tracing._tracer import current_tracer, span
4141
from duron.typing import JSONValue
4242
from duron.typing._hint import UnspecifiedType
43+
from duron.typing._inspect import inspect_function
4344

4445
if TYPE_CHECKING:
4546
from collections.abc import Callable, Coroutine
@@ -770,7 +771,7 @@ async def _prelude_fn(
770771
ctx = Context(loop, init_params["nonce"])
771772

772773
codec = fn.codec
773-
type_info = fn.type_hints
774+
type_info = inspect_function(fn.fn)
774775
args = tuple(
775776
codec.decode_json(
776777
arg,

src/duron/_core/stream.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,18 +354,18 @@ async def __aexit__(
354354
async def run_stateful(
355355
loop: EventLoop,
356356
dtype: TypeHint[Any],
357-
initial: _T,
358357
reducer: Callable[[_T, _U], _T],
359358
fn: Callable[Concatenate[_T, _P], AsyncGenerator[_U, _T]],
360359
/,
360+
initial: _T,
361361
*args: _P.args,
362362
**kwargs: _P.kwargs,
363363
) -> AsyncGenerator[tuple[Stream[_U], Awaitable[_T]], None]:
364364
assert asyncio.get_running_loop() is loop
365365

366366
name = cast("str", getattr(fn, "__name__", repr(fn)))
367367
stream: _StatefulStream[_U, _T] = _StatefulStream(
368-
initial, reducer, fn, *args, **kwargs
368+
reducer, fn, initial, *args, **kwargs
369369
)
370370
sink: StreamWriter[_U] = OpWriter(
371371
await create_op(
@@ -410,10 +410,10 @@ class _StatefulStream(_BufferStream[_U], Generic[_U, _T]):
410410

411411
def __init__(
412412
self,
413-
initial: _T,
414413
reducer: Callable[[_T, _U], _T],
415414
fn: Callable[Concatenate[_T, _P], AsyncGenerator[_U, _T]],
416415
/,
416+
initial: _T,
417417
*args: _P.args,
418418
**kwargs: _P.kwargs,
419419
) -> None:

src/duron/_decorator/durable.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040
self.codec = codec
4141
self.fn = fn
4242
self.inject = sorted(inject)
43-
self.type_hints = inspect_function(fn)
4443

4544

4645
@overload

src/duron/_decorator/effect.py

Lines changed: 27 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,137 +1,58 @@
11
from __future__ import annotations
22

3-
import functools
4-
from collections.abc import AsyncGenerator
5-
from typing import TYPE_CHECKING, Concatenate, Generic, Literal, get_args, get_origin
6-
from typing_extensions import Any, ParamSpec, TypeVar, final, overload
7-
8-
from duron.typing import UnspecifiedType, inspect_function
3+
from typing import TYPE_CHECKING
4+
from typing_extensions import NamedTuple, ParamSpec, TypeVar, overload
95

106
if TYPE_CHECKING:
11-
from collections.abc import Callable, Coroutine
12-
13-
from duron.typing import TypeHint
7+
from collections.abc import Callable
148

159

16-
_T = TypeVar("_T")
17-
_S = TypeVar("_S")
1810
_T_co = TypeVar("_T_co", covariant=True)
1911
_P = ParamSpec("_P")
2012

2113

22-
@final
23-
class EffectFn(Generic[_P, _T_co]):
24-
def __init__(self, fn: Callable[_P, Coroutine[Any, Any, _T_co]]) -> None:
25-
self.fn = fn
26-
self.type_hint = inspect_function(fn)
27-
functools.update_wrapper(self, fn)
28-
29-
def __call__(
30-
self, *args: _P.args, **kwargs: _P.kwargs
31-
) -> Coroutine[Any, Any, _T_co]:
32-
return self.fn(*args, **kwargs)
33-
34-
35-
@final
36-
class StatefulFn(Generic[_P, _S, _T]):
37-
def __init__(
38-
self,
39-
fn: Callable[Concatenate[_S, _P], AsyncGenerator[_T, _S]],
40-
initial: Callable[[], _S],
41-
reducer: Callable[[_S, _T], _S],
42-
) -> None:
43-
self.fn = fn
44-
self.type_hint = inspect_function(fn)
45-
self.initial = initial
46-
self.reducer = reducer
47-
48-
action_type: TypeHint[_T] = UnspecifiedType
49-
if get_origin(ret := self.type_hint.return_type) is AsyncGenerator:
50-
action_type, _ = get_args(ret)
51-
self.action_type = action_type
52-
functools.update_wrapper(self, fn)
53-
54-
def __call__(
55-
self, state: _S, *args: _P.args, **kwargs: _P.kwargs
56-
) -> AsyncGenerator[_T, _S]:
57-
return self.fn(state, *args, **kwargs)
14+
class Reducer(NamedTuple):
15+
"""Annotation to mark a parameter as a reducer."""
16+
17+
reducer: Callable[[object, object], object]
5818

5919

6020
@overload
61-
def effect(fn: Callable[_P, Coroutine[Any, Any, _T_co]], /) -> EffectFn[_P, _T_co]: ...
62-
@overload
63-
def effect(
64-
*, stateful: Literal[False] = ...
65-
) -> Callable[[Callable[_P, Coroutine[Any, Any, _T_co]]], EffectFn[_P, _T_co]]: ...
21+
def effect(fn: Callable[_P, _T_co], /) -> Callable[_P, _T_co]: ...
6622
@overload
23+
def effect() -> Callable[[Callable[_P, _T_co]], Callable[_P, _T_co]]: ...
6724
def effect(
68-
*,
69-
stateful: Literal[True],
70-
initial: Callable[[], _S],
71-
reducer: Callable[[_S, _T], _S],
72-
) -> Callable[
73-
[Callable[Concatenate[_S, _P], AsyncGenerator[_T, _S]]], StatefulFn[_P, _S, _T]
74-
]: ...
75-
def effect(
76-
fn: Callable[_P, Coroutine[Any, Any, _T_co]] | None = None,
77-
/,
78-
*,
79-
# stateful parameters
80-
stateful: bool = False,
81-
initial: Callable[[], _S] | None = None,
82-
reducer: Callable[[_S, _T], _S] | None = None,
83-
) -> (
84-
EffectFn[_P, _T_co]
85-
| Callable[[Callable[_P, Coroutine[Any, Any, _T_co]]], EffectFn[_P, _T_co]]
86-
| Callable[
87-
[Callable[Concatenate[_S, _P], AsyncGenerator[_T, _S]]], StatefulFn[_P, _S, _T]
88-
]
89-
):
25+
fn: Callable[_P, _T_co] | None = None, /
26+
) -> Callable[_P, _T_co] | Callable[[Callable[_P, _T_co]], Callable[_P, _T_co]]:
9027
"""Decorator to mark async functions as effects.
9128
9229
Effects are operations that interact with the outside world.
9330
94-
Args:
95-
stateful: Whether the effect is stateful.
96-
initial: Factory function for initial state (required with `stateful=True`)
97-
reducer: Function to reduce actions into state (required with `stateful=True`)
98-
9931
Example:
100-
Basic example:
10132
```python
10233
@duron.effect
103-
async def fetch_data(url: str) -> dict:
104-
return await http_client.get(url)
105-
```
34+
async def send_email(to: str, subject: str, body: str) -> None:
35+
# Send an email
36+
...
10637
107-
Stateful example:
108-
```python
109-
@duron.effect(stateful=True, initial=lambda: 0, reducer=int.__add__)
110-
async def count_items(state: int, items: list) -> AsyncGenerator[int, int]:
111-
# restore based on `state`
112-
for item in items:
113-
yield 1
38+
39+
@duron.effect
40+
async def counter(
41+
state: Annotated[int, duron.Reducer(lambda s, a: s + a)], increment: int
42+
) -> AsyncGenerator[int, int]:
43+
state += increment
44+
yield state
11445
```
11546
47+
11648
Returns:
11749
Function wrapper that can be invoked with [ctx.run()][duron.Context.run]
118-
119-
Raises:
120-
ValueError: If stateful is True but initial or reducer is not provided.
12150
"""
122-
if fn is not None:
123-
return EffectFn(fn)
12451

125-
if stateful:
126-
if not initial or not reducer:
127-
msg = "initial and reducer must be provided for stateful ops"
128-
raise ValueError(msg)
129-
130-
def decorate_stateful(
131-
fn: Callable[Concatenate[_S, _P], AsyncGenerator[_T, _S]],
132-
) -> StatefulFn[_P, _S, _T]:
133-
return StatefulFn(fn, initial, reducer)
52+
if fn is not None:
53+
return fn
13454

135-
return decorate_stateful
55+
def decorate(fn: Callable[_P, _T_co]) -> Callable[_P, _T_co]:
56+
return fn
13657

137-
return EffectFn
58+
return decorate

0 commit comments

Comments
 (0)