Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@

if TYPE_CHECKING:
import redis.asyncio.client
import redis.asyncio.cluster
import redis.client
import redis.cluster


class ACLCommands(CommandsProtocol):
Expand Down Expand Up @@ -5914,7 +5916,11 @@ class Script:
An executable Lua script object returned by ``register_script``
"""

def __init__(self, registered_client: "redis.client.Redis", script: ScriptTextT):
def __init__(
self,
registered_client: Union["redis.client.Redis", "redis.cluster.RedisCluster"],
script: ScriptTextT,
):
self.registered_client = registered_client
self.script = script
# Precalculate and store the SHA1 hex digest of the script.
Expand All @@ -5930,7 +5936,7 @@ def __call__(
self,
keys: Union[Sequence[KeyT], None] = None,
args: Union[Iterable[EncodableT], None] = None,
client: Union["redis.client.Redis", None] = None,
client: Union["redis.client.Redis", "redis.cluster.RedisCluster", None] = None,
):
"""Execute the script, passing any required ``args``"""
keys = keys or []
Expand Down Expand Up @@ -5979,7 +5985,9 @@ class AsyncScript:

def __init__(
self,
registered_client: "redis.asyncio.client.Redis",
registered_client: Union[
"redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster"
],
script: ScriptTextT,
):
self.registered_client = registered_client
Expand All @@ -6001,7 +6009,9 @@ async def __call__(
self,
keys: Union[Sequence[KeyT], None] = None,
args: Union[Iterable[EncodableT], None] = None,
client: Union["redis.asyncio.client.Redis", None] = None,
client: Union[
"redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster", None
] = None,
):
"""Execute the script, passing any required ``args``"""
keys = keys or []
Expand Down Expand Up @@ -6234,7 +6244,10 @@ def script_load(self, script: ScriptTextT) -> ResponseT:
"""
return self.execute_command("SCRIPT LOAD", script)

def register_script(self: "redis.client.Redis", script: ScriptTextT) -> Script:
def register_script(
self: Union["redis.client.Redis", "redis.cluster.RedisCluster"],
script: ScriptTextT,
) -> Script:
"""
Register a Lua ``script`` specifying the ``keys`` it will touch.
Returns a Script object that is callable and hides the complexity of
Expand All @@ -6249,7 +6262,9 @@ async def script_debug(self, *args) -> None:
return super().script_debug()

def register_script(
self: "redis.asyncio.client.Redis",
self: Union[
"redis.asyncio.client.Redis", "redis.asyncio.cluster.RedisCluster"
],
script: ScriptTextT,
) -> AsyncScript:
"""
Expand Down
24 changes: 24 additions & 0 deletions tests/test_asyncio/test_scripting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import pytest_asyncio
from redis import exceptions
from redis.asyncio.cluster import RedisCluster
from redis.commands.core import AsyncScript
from tests.conftest import skip_if_server_version_lt

multiply_script = """
Expand Down Expand Up @@ -150,3 +152,25 @@ async def test_eval_msgpack_pipeline_error_in_lua(self, r):
with pytest.raises(exceptions.ResponseError) as excinfo:
await pipe.execute()
assert excinfo.type == exceptions.ResponseError


class TestAsyncScriptTypeHints:
"""Tests for AsyncScript type hints with RedisCluster support."""

@pytest.mark.asyncio()
async def test_async_script_with_cluster_client(self):
"""Test that AsyncScript class accepts RedisCluster as registered_client.

This verifies the type hints fix for register_script to support RedisCluster.
We use a mock-like approach since we don't need actual cluster connection.
"""
from unittest.mock import MagicMock

# Create a mock RedisCluster instance
mock_cluster = MagicMock(spec=RedisCluster)
test_script = "return 1"

# AsyncScript should accept RedisCluster without type errors
script = AsyncScript(mock_cluster, test_script)
assert script.registered_client is mock_cluster
assert script.script == test_script
17 changes: 17 additions & 0 deletions tests/test_scripting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import redis
from redis import exceptions
from redis.cluster import RedisCluster
from redis.commands.core import Script
from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt

Expand Down Expand Up @@ -53,6 +54,22 @@ def test_encoder(self, r, script_bytes):
assert encoder is not None
assert encoder.encode("fake-script") == b"fake-script"

def test_script_with_cluster_client(self, script_str):
"""Test that Script class accepts RedisCluster as registered_client.

This verifies the type hints fix for register_script to support RedisCluster.
We use a mock-like approach since we don't need actual cluster connection.
"""
from unittest.mock import MagicMock

# Create a mock RedisCluster instance
mock_cluster = MagicMock(spec=RedisCluster)

# Script should accept RedisCluster without type errors
script = Script(mock_cluster, script_str)
assert script.registered_client is mock_cluster
assert script.script == script_str


class TestScripting:
@pytest.fixture(autouse=True)
Expand Down