|
5 | 5 | import concurrent.futures |
6 | 6 | import contextvars |
7 | 7 | import functools |
8 | | -import inspect |
9 | 8 | import threading |
10 | 9 | import time |
11 | 10 | import warnings |
|
15 | 14 | from types import TracebackType |
16 | 15 | from typing import ( |
17 | 16 | Any, |
18 | | - AsyncIterator, |
19 | 17 | Callable, |
20 | 18 | Coroutine, |
21 | 19 | Deque, |
22 | | - Generic, |
23 | | - Iterator, |
24 | | - List, |
25 | | - MutableSet, |
26 | 20 | Optional, |
27 | 21 | Set, |
28 | 22 | Type, |
|
31 | 25 | cast, |
32 | 26 | ) |
33 | 27 |
|
34 | | -from .concurrent import is_threaded_callable |
35 | | -from .utils.inspect import ensure_coroutine |
36 | | - |
37 | 28 | _T = TypeVar("_T") |
38 | 29 |
|
39 | | -# TODO try to use new ParamSpec feature in Python 3.10 |
40 | | - |
41 | | -_TResult = TypeVar("_TResult") |
42 | | -_TCallable = TypeVar("_TCallable", bound=Callable[..., Any]) |
43 | | - |
44 | | - |
45 | | -class AsyncEventResultIteratorBase(Generic[_TCallable, _TResult]): |
46 | | - def __init__(self) -> None: |
47 | | - self._lock = threading.RLock() |
48 | | - |
49 | | - self._listeners: MutableSet[weakref.ref[Any]] = set() |
50 | | - self._loop = asyncio.get_event_loop() |
51 | | - |
52 | | - def add(self, callback: _TCallable) -> None: |
53 | | - def remove_listener(ref: Any) -> None: |
54 | | - with self._lock: |
55 | | - self._listeners.remove(ref) |
56 | | - |
57 | | - with self._lock: |
58 | | - if inspect.ismethod(callback): |
59 | | - self._listeners.add(weakref.WeakMethod(callback, remove_listener)) |
60 | | - else: |
61 | | - self._listeners.add(weakref.ref(callback, remove_listener)) |
62 | | - |
63 | | - def remove(self, callback: _TCallable) -> None: |
64 | | - with self._lock: |
65 | | - try: |
66 | | - if inspect.ismethod(callback): |
67 | | - self._listeners.remove(weakref.WeakMethod(callback)) |
68 | | - else: |
69 | | - self._listeners.remove(weakref.ref(callback)) |
70 | | - except KeyError: |
71 | | - pass |
72 | | - |
73 | | - def __contains__(self, obj: Any) -> bool: |
74 | | - if inspect.ismethod(obj): |
75 | | - return weakref.WeakMethod(obj) in self._listeners |
76 | | - |
77 | | - return weakref.ref(obj) in self._listeners |
78 | | - |
79 | | - def __len__(self) -> int: |
80 | | - return len(self._listeners) |
81 | | - |
82 | | - def __iter__(self) -> Iterator[_TCallable]: |
83 | | - for r in self._listeners: |
84 | | - c = r() |
85 | | - if c is not None: |
86 | | - yield c |
87 | | - |
88 | | - async def __aiter__(self) -> AsyncIterator[_TCallable]: |
89 | | - for r in self.__iter__(): |
90 | | - yield r |
91 | | - |
92 | | - async def _notify( |
93 | | - self, |
94 | | - *args: Any, |
95 | | - callback_filter: Optional[Callable[[_TCallable], bool]] = None, |
96 | | - **kwargs: Any, |
97 | | - ) -> AsyncIterator[_TResult]: |
98 | | - for method in filter( |
99 | | - lambda x: callback_filter(x) if callback_filter is not None else True, |
100 | | - set(self), |
101 | | - ): |
102 | | - result = method(*args, **kwargs) |
103 | | - if inspect.isawaitable(result): |
104 | | - result = await result |
105 | | - |
106 | | - yield result |
107 | | - |
108 | | - |
109 | | -class AsyncEventIterator(AsyncEventResultIteratorBase[_TCallable, _TResult]): |
110 | | - def __call__(self, *args: Any, **kwargs: Any) -> AsyncIterator[_TResult]: |
111 | | - return self._notify(*args, **kwargs) |
112 | | - |
113 | | - |
114 | | -class AsyncEvent(AsyncEventResultIteratorBase[_TCallable, _TResult]): |
115 | | - async def __call__(self, *args: Any, **kwargs: Any) -> List[_TResult]: |
116 | | - return [a async for a in self._notify(*args, **kwargs)] |
117 | | - |
118 | | - |
119 | | -_TEvent = TypeVar("_TEvent") |
120 | | - |
121 | | - |
122 | | -class AsyncEventDescriptorBase(Generic[_TCallable, _TResult, _TEvent]): |
123 | | - def __init__( |
124 | | - self, |
125 | | - _func: _TCallable, |
126 | | - factory: Callable[..., _TEvent], |
127 | | - *factory_args: Any, |
128 | | - **factory_kwargs: Any, |
129 | | - ) -> None: |
130 | | - self._func = _func |
131 | | - self.__factory = factory |
132 | | - self.__factory_args = factory_args |
133 | | - self.__factory_kwargs = factory_kwargs |
134 | | - self._owner: Optional[Any] = None |
135 | | - self._owner_name: Optional[str] = None |
136 | | - |
137 | | - def __set_name__(self, owner: Any, name: str) -> None: |
138 | | - self._owner = owner |
139 | | - self._owner_name = name |
140 | | - |
141 | | - def __get__(self, obj: Any, objtype: Type[Any]) -> _TEvent: |
142 | | - if obj is None: |
143 | | - return self # type: ignore |
144 | | - |
145 | | - name = f"__async_event_{self._func.__name__}__" |
146 | | - if not hasattr(obj, name): |
147 | | - setattr( |
148 | | - obj, |
149 | | - name, |
150 | | - self.__factory(*self.__factory_args, **self.__factory_kwargs), |
151 | | - ) |
152 | | - |
153 | | - return cast("_TEvent", getattr(obj, name)) |
154 | | - |
155 | | - |
156 | | -class async_event_iterator( # noqa: N801 |
157 | | - AsyncEventDescriptorBase[_TCallable, Any, AsyncEventIterator[_TCallable, Any]] |
158 | | -): |
159 | | - def __init__(self, _func: _TCallable) -> None: |
160 | | - super().__init__(_func, AsyncEventIterator[_TCallable, Any]) |
161 | | - |
162 | | - |
163 | | -class async_event(AsyncEventDescriptorBase[_TCallable, Any, AsyncEvent[_TCallable, Any]]): # noqa: N801 |
164 | | - def __init__(self, _func: _TCallable) -> None: |
165 | | - super().__init__(_func, AsyncEvent[_TCallable, Any]) |
166 | | - |
167 | | - |
168 | | -class AsyncTaskingEventResultIteratorBase(AsyncEventResultIteratorBase[_TCallable, _TResult]): |
169 | | - def __init__(self, *, task_name_prefix: Optional[str] = None) -> None: |
170 | | - super().__init__() |
171 | | - self._task_name_prefix = task_name_prefix or type(self).__qualname__ |
172 | | - |
173 | | - async def _notify( # type: ignore |
174 | | - self, |
175 | | - *args: Any, |
176 | | - result_callback: Optional[Callable[[Optional[_TResult], Optional[BaseException]], Any]] = None, |
177 | | - return_exceptions: Optional[bool] = True, |
178 | | - callback_filter: Optional[Callable[[_TCallable], bool]] = None, |
179 | | - threaded: Optional[bool] = True, |
180 | | - **kwargs: Any, |
181 | | - ) -> AsyncIterator[Union[_TResult, BaseException]]: |
182 | | - def _done(f: asyncio.Future[_TResult]) -> None: |
183 | | - if result_callback is not None: |
184 | | - try: |
185 | | - result_callback(f.result(), f.exception()) |
186 | | - except (SystemExit, KeyboardInterrupt): |
187 | | - raise |
188 | | - except BaseException as e: |
189 | | - result_callback(None, e) |
190 | | - |
191 | | - awaitables: List[asyncio.Future[_TResult]] = [] |
192 | | - for method in filter( |
193 | | - lambda x: callback_filter(x) if callback_filter is not None else True, |
194 | | - set(self), |
195 | | - ): |
196 | | - if method is not None: |
197 | | - if threaded and is_threaded_callable(method): |
198 | | - future = run_coroutine_in_thread(ensure_coroutine(method), *args, **kwargs) |
199 | | - else: |
200 | | - future = create_sub_task(ensure_coroutine(method)(*args, **kwargs)) |
201 | | - awaitables.append(future) |
202 | | - |
203 | | - if result_callback is not None: |
204 | | - future.add_done_callback(_done) |
205 | | - |
206 | | - for a in asyncio.as_completed(awaitables): |
207 | | - try: |
208 | | - yield await a |
209 | | - |
210 | | - except (SystemExit, KeyboardInterrupt): |
211 | | - raise |
212 | | - except BaseException as e: |
213 | | - if return_exceptions: |
214 | | - yield e |
215 | | - else: |
216 | | - raise |
217 | | - |
218 | | - |
219 | | -class AsyncTaskingEventIterator(AsyncTaskingEventResultIteratorBase[_TCallable, _TResult]): |
220 | | - def __call__(self, *args: Any, **kwargs: Any) -> AsyncIterator[Union[_TResult, BaseException]]: |
221 | | - return self._notify(*args, **kwargs) |
222 | | - |
223 | | - |
224 | | -def _get_name_prefix( |
225 | | - descriptor: AsyncEventDescriptorBase[Any, Any, Any], |
226 | | -) -> str: |
227 | | - if descriptor._owner is None: |
228 | | - return type(descriptor).__qualname__ |
229 | | - |
230 | | - return f"{descriptor._owner.__qualname__}.{descriptor._owner_name}" |
231 | | - |
232 | | - |
233 | | -class AsyncTaskingEvent(AsyncTaskingEventResultIteratorBase[_TCallable, _TResult]): |
234 | | - async def __call__(self, *args: Any, **kwargs: Any) -> List[Union[_TResult, BaseException]]: |
235 | | - return [a async for a in self._notify(*args, **kwargs)] |
236 | | - |
237 | | - |
238 | | -class async_tasking_event_iterator( # noqa: N801 |
239 | | - AsyncEventDescriptorBase[_TCallable, Any, AsyncTaskingEventIterator[_TCallable, Any]] |
240 | | -): |
241 | | - def __init__(self, _func: _TCallable) -> None: |
242 | | - super().__init__( |
243 | | - _func, |
244 | | - AsyncTaskingEventIterator[_TCallable, Any], |
245 | | - task_name_prefix=lambda: _get_name_prefix(self), |
246 | | - ) |
247 | | - |
248 | | - |
249 | | -class async_tasking_event(AsyncEventDescriptorBase[_TCallable, Any, AsyncTaskingEvent[_TCallable, Any]]): # noqa: N801 |
250 | | - def __init__(self, _func: _TCallable) -> None: |
251 | | - super().__init__( |
252 | | - _func, |
253 | | - AsyncTaskingEvent[_TCallable, Any], |
254 | | - task_name_prefix=lambda: _get_name_prefix(self), |
255 | | - ) |
256 | | - |
257 | | - |
258 | | -async def check_canceled() -> bool: |
259 | | - await asyncio.sleep(0) |
260 | | - |
261 | | - return True |
262 | | - |
263 | 30 |
|
264 | 31 | def check_canceled_sync() -> bool: |
265 | 32 | info = get_current_future_info() |
|
0 commit comments