Skip to content

Commit d672e6a

Browse files
author
DvirDukhan
committed
test pass
1 parent 3891559 commit d672e6a

File tree

4 files changed

+64
-51
lines changed

4 files changed

+64
-51
lines changed

src/libtorch_c/torch_extensions/torch_redis.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,8 @@ torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args )
5959
RedisModule_FreeCallReply(reply);
6060
return value;
6161
}
62+
63+
64+
torch::List<torch::IValue> asList(torch::IValue v) {
65+
return v.toList();
66+
}

src/libtorch_c/torch_extensions/torch_redis.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
#include "torch/script.h"
33
#include "torch/csrc/jit/frontend/resolver.h"
44

5-
#include "torch_redis_value.h"
6-
75
namespace torch {
86
namespace jit {
97
namespace script {
@@ -35,7 +33,7 @@ namespace torch {
3533
// c10::intrusive_ptr<RedisValue> redisExecute(std::string fn_name, std::vector<std::string> args );
3634

3735
torch::IValue redisExecute(std::string fn_name, std::vector<std::string> args );
36+
torch::List<torch::IValue> asList(torch::IValue);
3837

39-
40-
41-
static auto registry = torch::RegisterOperators("redis::execute", &redisExecute);
38+
static auto registry = torch::RegisterOperators("redis::execute", &redisExecute).op("redis::asList", &asList);
39+
// registry = torch::RegisterOperators("torch::asList", &asList);

tests/flow/test_data/redis_scripts.py

Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,68 +7,55 @@ def redis_string_float_to_tensor(redis_value: Any):
77
return torch.tensor(float(str((redis_value))))
88

99

10-
# def redis_int_to_tensor(redis_value: RedisValue):
11-
# return tensor(redis_value.intValue())
10+
def redis_int_to_tensor(redis_value: int):
11+
return torch.tensor(redis_value)
1212

1313

14-
# def redis_int_list_to_tensor(redis_value: RedisValue):
15-
# len = len(redis_value.getList())
16-
# l = []
17-
# for v in redis_value.getList():
18-
# l.append(redis_string_to_int(v))
19-
# return torch.cat(l, dim=0)
14+
def redis_int_list_to_tensor(redis_value: Any):
15+
values = redis.asList(redis_value)
16+
l = [torch.tensor(int(str(v))).reshape(1,1) for v in values]
17+
return torch.cat(l, dim=0)
2018

2119

22-
# def redis_float_list_to_tensor(redis_value: RedisValue):
23-
# len = len(redis_value.getList())
24-
# l = []
25-
# for v in redis_value.getList():
26-
# l.append(redis_string_to_float(v))
27-
# return torch.cat(l, dim=0)
20+
def redis_hash_to_tensor(redis_value: Any):
21+
values = redis.asList(redis_value)
22+
l = [torch.tensor(int(str(v))).reshape(1,1) for v in values]
23+
return torch.cat(l, dim=0)
2824

29-
30-
# def redis_hash_to_tensor(redis_value: RedisValue):
31-
# len = len(redis_value.getList())
32-
# l = []
33-
# for v in redis_value.getList():
34-
# l.append(redis_string_to_float(v.getList()[1]))
35-
# return torch.cat(l, dim=0)
36-
37-
# def test_redis_error():
38-
# res = redis.executeCommand("SET", "x")
39-
# return tensor(res.getValueType())
25+
def test_redis_error():
26+
redis.execute("SET", "x")
4027

4128
def test_int_set_get():
4229
redis.execute("SET", "x", "1")
4330
res = redis.execute("GET", "x",)
4431
redis.execute("DEL", "x")
4532
return redis_string_int_to_tensor(res)
4633

34+
def test_int_set_incr():
35+
redis.execute("SET", "x", "1")
36+
res = redis.execute("INCR", "x")
37+
redis.execute("DEL", "x")
38+
return redis_string_int_to_tensor(res)
39+
4740
def test_float_set_get():
4841
redis.execute("SET", "x", "1.1")
4942
res = redis.execute("GET", "x",)
5043
redis.execute("DEL", "x")
5144
return redis_string_float_to_tensor(res)
5245

53-
# def test_int_list():
54-
# redis.executeCommand("LPUSH", "x", "1")
55-
# redis.executeCommand("LPUSH", "x", "2")
56-
# res = redis.executeCommand("LRANGE", "x")
57-
# redis.executeCommand("DEL", "x")
58-
# return redis_int_list_to_tensor(res)
59-
60-
# def test_float_list():
61-
# redis.executeCommand("LPUSH", "x", "1.1")
62-
# redis.executeCommand("LPUSH", "x", "2.2")
63-
# res = redis.executeCommand("LRANGE", "x")
64-
# redis.executeCommand("DEL", "x")
65-
# return redis_float_list_to_tensor(res)
66-
67-
# def test_hash():
68-
# redis.executeCommand("HSET", "x", "1", "2.2)
69-
# res = redis.executeCommand("HGETALL", "x")
70-
# redis.executeCommand("DEL", "x")
71-
# return redis_float_list_to_tensor(res)
46+
def test_int_list():
47+
redis.execute("RPUSH", "x", "1")
48+
redis.execute("RPUSH", "x", "2")
49+
res = redis.execute("LRANGE", "x", "0", "2")
50+
redis.execute("DEL", "x")
51+
return redis_int_list_to_tensor(res)
52+
53+
54+
def test_hash():
55+
redis.execute("HSET", "x", "field1", "1", "field2", "2")
56+
res = redis.execute("HVALS", "x")
57+
redis.execute("DEL", "x")
58+
return redis_hash_to_tensor(res)
7259

7360

7461
def test_set_key():

tests/flow/test_torchscript_extensions.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,29 @@ def __init__(self):
2929
self.env.assertEqual(ret, b'OK')
3030
# self.env.ensureSlaveSynced(self.con, self.env)
3131

32+
def test_redis_error(self):
33+
try:
34+
self.con.execute_command(
35+
'AI.SCRIPTRUN', 'redis_scripts', 'test_redis_error')
36+
self.env.assertTrue(False)
37+
except:
38+
pass
39+
3240
def test_simple_test_set(self):
3341
self.con.execute_command(
3442
'AI.SCRIPTRUN', 'redis_scripts', 'test_set_key')
3543
self.env.assertEqual(b"1", self.con.get("x"))
3644

37-
def test_int_get_set(self):
45+
def test_int_set_get(self):
3846
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_get', 'OUTPUTS', 'y')
3947
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
4048
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [1]] )
4149

50+
def test_int_set_incr(self):
51+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_set_incr', 'OUTPUTS', 'y')
52+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
53+
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [], b"values", [2]] )
54+
4255
def test_float_get_set(self):
4356
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_float_set_get', 'OUTPUTS', 'y')
4457
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
@@ -47,4 +60,14 @@ def test_float_get_set(self):
4760
self.env.assertEqual(y[2], b"shape")
4861
self.env.assertEqual(y[3], [])
4962
self.env.assertEqual(y[4], b"values")
50-
self.env.assertAlmostEqual(float(y[5][0]), 1.1, 0.1)
63+
self.env.assertAlmostEqual(float(y[5][0]), 1.1, 0.1)
64+
65+
def test_int_list(self):
66+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_int_list', 'OUTPUTS', 'y')
67+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
68+
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [2, 1], b"values", [1, 2]] )
69+
70+
def test_hash(self):
71+
self.con.execute_command('AI.SCRIPTRUN', 'redis_scripts', 'test_hash', 'OUTPUTS', 'y')
72+
y = self.con.execute_command('AI.TENSORGET', 'y', 'meta' ,'VALUES')
73+
self.env.assertEqual(y, [b"dtype", b"INT64", b"shape", [2, 1], b"values", [1, 2]] )

0 commit comments

Comments
 (0)