|
2 | 2 |
|
3 | 3 | from includes import * |
4 | 4 | import os |
| 5 | +from functools import wraps |
5 | 6 |
|
6 | 7 | ''' |
7 | 8 | python -m RLTest --test tests_llapi.py --module path/to/redisai.so |
|
10 | 11 | goal_dir = os.path.join(os.getcwd(), "../module/LLAPI.so") |
11 | 12 | TEST_MODULE_PATH = os.path.abspath(goal_dir) |
12 | 13 |
|
| 14 | + |
| 15 | +def skip_if_gears_not_loaded(f): |
| 16 | + @wraps(f) |
| 17 | + def wrapper(env, *args, **kwargs): |
| 18 | + con = env.getConnection() |
| 19 | + modules = con.execute_command("MODULE", "LIST") |
| 20 | + if "rg" in [module[1] for module in modules]: |
| 21 | + return f(env, *args, **kwargs) |
| 22 | + try: |
| 23 | + redisgears_path = os.path.join(os.path.dirname(__file__), '../../../RedisGears/redisgears.so') |
| 24 | + ret = con.execute_command('MODULE', 'LOAD', redisgears_path) |
| 25 | + env.assertEqual(ret, b'OK') |
| 26 | + return f(env, *args, **kwargs) |
| 27 | + except Exception as e: |
| 28 | + env.debugPrint("skipping since RedisGears not loaded", force=True) |
| 29 | + return |
| 30 | + return wrapper |
| 31 | + |
| 32 | + |
13 | 33 | def test_basic_check(env): |
14 | 34 |
|
15 | 35 | con = env.getConnection() |
@@ -38,3 +58,51 @@ def test_model_run_async(env): |
38 | 58 | con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3) |
39 | 59 | ret = con.execute_command("RAI_llapi.modelRun") |
40 | 60 | env.assertEqual(ret, b'Async run success') |
| 61 | + |
| 62 | + |
| 63 | +@skip_if_gears_not_loaded |
| 64 | +def test_model_run_async_via_gears(env): |
| 65 | + script = ''' |
| 66 | +import redisAI |
| 67 | +
|
| 68 | +async def RedisAIModelRun(record): |
| 69 | + keys = ['a{1}', 'b{1}'] |
| 70 | + tensors = redisAI.mgetTensorsFromKeyspace(keys) |
| 71 | + modelRunner = redisAI.createModelRunner('m{1}') |
| 72 | + redisAI.modelRunnerAddInput(modelRunner, 'a', tensors[0]) |
| 73 | + redisAI.modelRunnerAddInput(modelRunner, 'b', tensors[1]) |
| 74 | + redisAI.modelRunnerAddOutput(modelRunner, 'mul') |
| 75 | + res = await redisAI.modelRunnerRunAsync(modelRunner) |
| 76 | + if len(res[1]) > 0: |
| 77 | + raise Exception(res[1][0]) |
| 78 | + redisAI.setTensorInKey('c{1}', res[0][0]) |
| 79 | + return "OK" |
| 80 | +
|
| 81 | +GB("CommandReader").map(RedisAIModelRun).register(trigger="ModelRunAsyncTest") |
| 82 | + ''' |
| 83 | + con = env.getConnection() |
| 84 | + ret = con.execute_command('rg.pyexecute', script) |
| 85 | + env.assertEqual(ret, b'OK') |
| 86 | + |
| 87 | + test_data_path = os.path.join(os.path.dirname(__file__), 'test_data') |
| 88 | + model_filename = os.path.join(test_data_path, 'graph.pb') |
| 89 | + |
| 90 | + with open(model_filename, 'rb') as f: |
| 91 | + model_pb = f.read() |
| 92 | + |
| 93 | + ret = con.execute_command('AI.MODELSET', 'm{1}', 'TF', DEVICE, |
| 94 | + 'INPUTS', 'a', 'b', 'OUTPUTS', 'mul', 'BLOB', model_pb) |
| 95 | + env.assertEqual(ret, b'OK') |
| 96 | + |
| 97 | + ret = con.execute_command('AI.MODELGET', 'm{1}', 'META') |
| 98 | + env.assertEqual(len(ret), 14) |
| 99 | + |
| 100 | + con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', |
| 101 | + 2, 2, 'VALUES', 2, 3, 2, 3) |
| 102 | + con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', |
| 103 | + 2, 2, 'VALUES', 2, 3, 2, 3) |
| 104 | + |
| 105 | + ret = con.execute_command('rg.trigger', 'ModelRunAsyncTest') |
| 106 | + env.assertEqual(ret[0], b'OK') |
| 107 | + values = con.execute_command('AI.TENSORGET', 'c{1}', 'VALUES') |
| 108 | + env.assertEqual(values, [b'4', b'9', b'4', b'9']) |
0 commit comments