Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import socket
import ssl
import warnings
from contextlib import suppress
from unittest.mock import AsyncMock
from hashlib import blake2b
from typing import (
Expand Down Expand Up @@ -1211,6 +1212,8 @@ def __init__(
self.metadata_version_hex = "0x0f000000" # v15
self._initializing = False
self._mock = _mock
self.startup_runtime_task: Optional[asyncio.Task] = None
self.startup_block_hash: Optional[str] = None

async def __aenter__(self):
if not self._mock:
Expand All @@ -1230,8 +1233,12 @@ async def _initialize(self) -> None:
if not self._chain:
chain = await self.rpc_request("system_chain", [])
self._chain = chain.get("result")
runtime = await self.init_runtime()
self.startup_block_hash = block_hash = await self.get_chain_head()
self.startup_runtime_task = asyncio.create_task(
self.init_runtime(block_hash=block_hash, init=True)
)
if self.ss58_format is None:
runtime = await self.init_runtime(block_hash)
# Check and apply runtime constants
ss58_prefix_constant = await self.get_constant(
"System", "SS58Prefix", runtime=runtime
Expand Down Expand Up @@ -1438,7 +1445,10 @@ async def decode_scale(
return obj

async def init_runtime(
self, block_hash: Optional[str] = None, block_id: Optional[int] = None
self,
block_hash: Optional[str] = None,
block_id: Optional[int] = None,
init: bool = False,
) -> Runtime:
"""
This method is used by all other methods that deals with metadata and types defined in the type registry.
Expand All @@ -1455,6 +1465,13 @@ async def init_runtime(
Returns:
Runtime object
"""
if (
not init
and self.startup_runtime_task is not None
and block_hash == self.startup_block_hash
):
await self.startup_runtime_task
self.startup_runtime_task = None

if block_id and block_hash:
raise ValueError("Cannot provide block_hash and block_id at the same time")
Expand Down Expand Up @@ -4322,6 +4339,10 @@ async def close(self):
Closes the substrate connection, and the websocket connection.
"""
try:
if self.startup_runtime_task is not None:
self.startup_runtime_task.cancel()
with suppress(asyncio.CancelledError):
await self.startup_runtime_task
await self.ws.shutdown()
except AttributeError:
pass
Expand Down
16 changes: 8 additions & 8 deletions tests/integration_tests/test_disk_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,50 +88,50 @@ async def test_disk_cache():
start = time.monotonic()
new_block_hash = await disk_cached_substrate.get_block_hash(current_block)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002

start = time.monotonic()
new_parent_block_hash = await disk_cached_substrate.get_parent_block_hash(
block_hash
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
start = time.monotonic()
new_block_runtime_info = await disk_cached_substrate.get_block_runtime_info(
block_hash
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
start = time.monotonic()
new_block_runtime_version_for = (
await disk_cached_substrate.get_block_runtime_version_for(block_hash)
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
start = time.monotonic()
new_block_hash_from_cache = await disk_cached_substrate.get_block_hash(
current_block
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
start = time.monotonic()
new_parent_block_hash_from_cache = (
await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache)
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
start = time.monotonic()
new_block_runtime_info_from_cache = (
await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache)
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
start = time.monotonic()
new_block_runtime_version_from_cache = (
await disk_cached_substrate.get_block_runtime_version_for(
block_hash_from_cache
)
)
new_time = time.monotonic()
assert new_time - start < 0.001
assert new_time - start < 0.002
print("Disk Cache tests passed")
8 changes: 6 additions & 2 deletions tests/unit_tests/asyncio_/test_substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ async def test_get_account_next_index_cached_mode_uses_internal_cache():
substrate.supports_rpc_method = AsyncMock(return_value=True)
substrate.rpc_request = AsyncMock(return_value={"result": 5})

first = await substrate.get_account_next_index("5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA")
first = await substrate.get_account_next_index(
"5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA"
)
second = await substrate.get_account_next_index(
"5F3sa2TJAWMqDhXG6jhV4N8ko9NoFz5Y2s8vS8uM9f7v7mA"
)
Expand Down Expand Up @@ -331,7 +333,9 @@ async def test_get_account_next_index_bypass_mode_does_not_create_or_mutate_cach
async def test_get_account_next_index_bypass_mode_raises_on_rpc_error():
substrate = AsyncSubstrateInterface("ws://localhost", _mock=True)
substrate.supports_rpc_method = AsyncMock(return_value=True)
substrate.rpc_request = AsyncMock(return_value={"error": {"message": "rpc failure"}})
substrate.rpc_request = AsyncMock(
return_value={"error": {"message": "rpc failure"}}
)

with pytest.raises(SubstrateRequestException, match="rpc failure"):
await substrate.get_account_next_index(
Expand Down
Loading