diff --git a/skyflow/vault/_client.py b/skyflow/vault/_client.py index 24ba524..cddb351 100644 --- a/skyflow/vault/_client.py +++ b/skyflow/vault/_client.py @@ -3,8 +3,10 @@ ''' import json import types +import typing import requests import asyncio +from concurrent.futures import Future, ThreadPoolExecutor from skyflow.vault._insert import getInsertRequestBody, processResponse, convertResponse from skyflow.vault._update import sendUpdateRequests, createUpdateResponseBody from skyflow.vault._config import Configuration, ConnectionConfig, DeleteOptions, DetokenizeOptions, GetOptions, InsertOptions, UpdateOptions, QueryOptions @@ -86,7 +88,7 @@ def detokenize(self, records: dict, options: DetokenizeOptions = DetokenizeOptio self.storedToken = tokenProviderWrapper( self.storedToken, self.tokenProvider, interface) url = self._get_complete_vault_url() + '/detokenize' - responses = asyncio.run(sendDetokenizeRequests( + responses = run_coro(sendDetokenizeRequests( records, url, self.storedToken, options)) result, partial = createDetokenizeResponseBody(records, responses, options) if partial: @@ -105,7 +107,7 @@ def get(self, records, options: GetOptions = GetOptions()): self.storedToken = tokenProviderWrapper( self.storedToken, self.tokenProvider, interface) url = self._get_complete_vault_url() - responses = asyncio.run(sendGetRequests( + responses = run_coro(sendGetRequests( records, options, url, self.storedToken)) result, partial = createGetResponseBody(responses) if partial: @@ -124,7 +126,7 @@ def get_by_id(self, records): self.storedToken = tokenProviderWrapper( self.storedToken, self.tokenProvider, interface) url = self._get_complete_vault_url() - responses = asyncio.run(sendGetByIdRequests( + responses = run_coro(sendGetByIdRequests( records, url, self.storedToken)) result, partial = createGetResponseBody(responses) if partial: @@ -201,7 +203,7 @@ def update(self, updateInput, options: UpdateOptions = UpdateOptions()): self.storedToken = tokenProviderWrapper( self.storedToken, self.tokenProvider, interface) url = self._get_complete_vault_url() - responses = asyncio.run(sendUpdateRequests( + responses = run_coro(sendUpdateRequests( updateInput, options, url, self.storedToken)) result, partial = createUpdateResponseBody(responses) if partial: @@ -290,4 +292,34 @@ def delete(self, records: dict, options: DeleteOptions = DeleteOptions()): else: log_info(InfoMessages.DELETE_DATA_SUCCESS.value, interface) - return result \ No newline at end of file + return result + + +T = typing.TypeVar('T') + +def run_coro(coro: typing.Coroutine[typing.Any, typing.Any, T]) -> T: + """ + Run a coroutine in a thread pool. This avoids the RuntimeError that occurs + when calling asyncio.run() from a thread that already has an event loop. + + Note that this isn't performant, since it create a new thread with a new + event loop for each call. + + Args: + coro: The coroutine to run. + + Returns: + The result of the coroutine. + """ + + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + with ThreadPoolExecutor() as executor: + # Must run asyncio.run in a thread. If we don't we'll get the following + # error: + # RuntimeError: asyncio.run() cannot be called from a running event loop + future: Future[T] = executor.submit(asyncio.run, coro) + return future.result() \ No newline at end of file