diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 90ac49d..9901619 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -11,6 +11,7 @@ import socket import ssl import warnings +from contextlib import suppress from unittest.mock import AsyncMock from hashlib import blake2b from typing import ( @@ -1211,6 +1212,8 @@ def __init__( self.metadata_version_hex = "0x0f000000" # v15 self._initializing = False self._mock = _mock + self.startup_runtime_task: Optional[asyncio.Task] = None + self.startup_block_hash: Optional[str] = None async def __aenter__(self): if not self._mock: @@ -1230,8 +1233,12 @@ async def _initialize(self) -> None: if not self._chain: chain = await self.rpc_request("system_chain", []) self._chain = chain.get("result") - runtime = await self.init_runtime() + self.startup_block_hash = block_hash = await self.get_chain_head() + self.startup_runtime_task = asyncio.create_task( + self.init_runtime(block_hash=block_hash, init=True) + ) if self.ss58_format is None: + runtime = await self.init_runtime(block_hash) # Check and apply runtime constants ss58_prefix_constant = await self.get_constant( "System", "SS58Prefix", runtime=runtime @@ -1438,7 +1445,10 @@ async def decode_scale( return obj async def init_runtime( - self, block_hash: Optional[str] = None, block_id: Optional[int] = None + self, + block_hash: Optional[str] = None, + block_id: Optional[int] = None, + init: bool = False, ) -> Runtime: """ This method is used by all other methods that deals with metadata and types defined in the type registry. @@ -1455,6 +1465,13 @@ async def init_runtime( Returns: Runtime object """ + if ( + not init + and self.startup_runtime_task is not None + and block_hash == self.startup_block_hash + ): + await self.startup_runtime_task + self.startup_runtime_task = None if block_id and block_hash: raise ValueError("Cannot provide block_hash and block_id at the same time") @@ -4322,6 +4339,10 @@ async def close(self): Closes the substrate connection, and the websocket connection. """ try: + if self.startup_runtime_task is not None: + self.startup_runtime_task.cancel() + with suppress(asyncio.CancelledError): + await self.startup_runtime_task await self.ws.shutdown() except AttributeError: pass diff --git a/tests/integration_tests/test_disk_cache.py b/tests/integration_tests/test_disk_cache.py index 063eca1..11e7312 100644 --- a/tests/integration_tests/test_disk_cache.py +++ b/tests/integration_tests/test_disk_cache.py @@ -88,44 +88,44 @@ async def test_disk_cache(): start = time.monotonic() new_block_hash = await disk_cached_substrate.get_block_hash(current_block) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_parent_block_hash = await disk_cached_substrate.get_parent_block_hash( block_hash ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_block_runtime_info = await disk_cached_substrate.get_block_runtime_info( block_hash ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_block_runtime_version_for = ( await disk_cached_substrate.get_block_runtime_version_for(block_hash) ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_block_hash_from_cache = await disk_cached_substrate.get_block_hash( current_block ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_parent_block_hash_from_cache = ( await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache) ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_block_runtime_info_from_cache = ( await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache) ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 start = time.monotonic() new_block_runtime_version_from_cache = ( await disk_cached_substrate.get_block_runtime_version_for( @@ -133,5 +133,5 @@ async def test_disk_cache(): ) ) new_time = time.monotonic() - assert new_time - start < 0.001 + assert new_time - start < 0.002 print("Disk Cache tests passed") diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py index afefe7a..a0ac123 100644 --- a/tests/unit_tests/asyncio_/test_substrate_interface.py +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -296,7 +296,9 @@ async def test_get_account_next_index_cached_mode_uses_internal_cache(): substrate.supports_rpc_method = AsyncMock(return_value=True) substrate.rpc_request = AsyncMock(return_value={"result": 5}) - first = await substrate.get_account_next_index("5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA") + first = await substrate.get_account_next_index( + "5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA" + ) second = await substrate.get_account_next_index( "5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA" ) @@ -331,7 +333,9 @@ async def test_get_account_next_index_bypass_mode_does_not_create_or_mutate_cach async def test_get_account_next_index_bypass_mode_raises_on_rpc_error(): substrate = AsyncSubstrateInterface("ws://localhost", _mock=True) substrate.supports_rpc_method = AsyncMock(return_value=True) - substrate.rpc_request = AsyncMock(return_value={"error": {"message": "rpc failure"}}) + substrate.rpc_request = AsyncMock( + return_value={"error": {"message": "rpc failure"}} + ) with pytest.raises(SubstrateRequestException, match="rpc failure"): await substrate.get_account_next_index(