diff --git a/redis/commands/core.py b/redis/commands/core.py index 6e1af05635..9793809858 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -56,7 +56,9 @@ if TYPE_CHECKING: import redis.asyncio.client + import redis.asyncio.cluster import redis.client + import redis.cluster class ACLCommands(CommandsProtocol): @@ -5914,7 +5916,11 @@ class Script: An executable Lua script object returned by ``register_script`` """ - def __init__(self, registered_client: "redis.client.Redis", script: ScriptTextT): + def __init__( + self, + registered_client: Union["redis.client.Redis", "redis.cluster.RedisCluster"], + script: ScriptTextT, + ): self.registered_client = registered_client self.script = script # Precalculate and store the SHA1 hex digest of the script. @@ -5930,7 +5936,7 @@ def __call__( self, keys: Union[Sequence[KeyT], None] = None, args: Union[Iterable[EncodableT], None] = None, - client: Union["redis.client.Redis", None] = None, + client: Union["redis.client.Redis", "redis.cluster.RedisCluster", None] = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -5979,7 +5985,9 @@ class AsyncScript: def __init__( self, - registered_client: "redis.asyncio.client.Redis", + registered_client: Union[ + "redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster" + ], script: ScriptTextT, ): self.registered_client = registered_client @@ -6001,7 +6009,9 @@ async def __call__( self, keys: Union[Sequence[KeyT], None] = None, args: Union[Iterable[EncodableT], None] = None, - client: Union["redis.asyncio.client.Redis", None] = None, + client: Union[ + "redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster", None + ] = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -6234,7 +6244,10 @@ def script_load(self, script: ScriptTextT) -> ResponseT: """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self: "redis.client.Redis", script: ScriptTextT) -> Script: + def register_script( + self: Union["redis.client.Redis", "redis.cluster.RedisCluster"], + script: ScriptTextT, + ) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. Returns a Script object that is callable and hides the complexity of @@ -6249,7 +6262,9 @@ async def script_debug(self, *args) -> None: return super().script_debug() def register_script( - self: "redis.asyncio.client.Redis", + self: Union[ + "redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster" + ], script: ScriptTextT, ) -> AsyncScript: """ diff --git a/tests/test_asyncio/test_scripting.py b/tests/test_asyncio/test_scripting.py index b8e100c04a..95e0c80e85 100644 --- a/tests/test_asyncio/test_scripting.py +++ b/tests/test_asyncio/test_scripting.py @@ -1,6 +1,8 @@ import pytest import pytest_asyncio from redis import exceptions +from redis.asyncio.cluster import RedisCluster +from redis.commands.core import AsyncScript from tests.conftest import skip_if_server_version_lt multiply_script = """ @@ -150,3 +152,25 @@ async def test_eval_msgpack_pipeline_error_in_lua(self, r): with pytest.raises(exceptions.ResponseError) as excinfo: await pipe.execute() assert excinfo.type == exceptions.ResponseError + + +class TestAsyncScriptTypeHints: + """Tests for AsyncScript type hints with RedisCluster support.""" + + @pytest.mark.asyncio() + async def test_async_script_with_cluster_client(self): + """Test that AsyncScript class accepts RedisCluster as registered_client. + + This verifies the type hints fix for register_script to support RedisCluster. + We use a mock-like approach since we don't need actual cluster connection. + """ + from unittest.mock import MagicMock + + # Create a mock RedisCluster instance + mock_cluster = MagicMock(spec=RedisCluster) + test_script = "return 1" + + # AsyncScript should accept RedisCluster without type errors + script = AsyncScript(mock_cluster, test_script) + assert script.registered_client is mock_cluster + assert script.script == test_script diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 899dc69482..ca8eeac5e4 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,6 +1,7 @@ import pytest import redis from redis import exceptions +from redis.cluster import RedisCluster from redis.commands.core import Script from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt @@ -53,6 +54,22 @@ def test_encoder(self, r, script_bytes): assert encoder is not None assert encoder.encode("fake-script") == b"fake-script" + def test_script_with_cluster_client(self, script_str): + """Test that Script class accepts RedisCluster as registered_client. + + This verifies the type hints fix for register_script to support RedisCluster. + We use a mock-like approach since we don't need actual cluster connection. + """ + from unittest.mock import MagicMock + + # Create a mock RedisCluster instance + mock_cluster = MagicMock(spec=RedisCluster) + + # Script should accept RedisCluster without type errors + script = Script(mock_cluster, script_str) + assert script.registered_client is mock_cluster + assert script.script == script_str + class TestScripting: @pytest.fixture(autouse=True)