Skip to content

Commit 34c4e3e

Browse files
committed
Add test with gears, PR fixes.
1 parent 2b539d4 commit 34c4e3e

File tree

8 files changed

+92
-24
lines changed

8 files changed

+92
-24
lines changed

src/command_parser.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
102102
* @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise
103103
*/
104104

105-
static int ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
106-
RedisModuleString **outkeys, RAI_ModelRunCtx *mctx) {
105+
static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkeys,
106+
RedisModuleString **outkeys, RAI_ModelRunCtx *mctx) {
107107

108108
RAI_Model *model = mctx->model;
109109
RAI_Tensor *t;
@@ -152,7 +152,7 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModul
152152
if (rinfo->single_op_dag) {
153153
rinfo->timeout = timeout;
154154
// Set params in ModelRunCtx, bring inputs from key space.
155-
if (ModelRunCtx_SetParams(ctx, inkeys, outkeys, mctx) == REDISMODULE_ERR)
155+
if (_ModelRunCtx_SetParams(ctx, inkeys, outkeys, mctx) == REDISMODULE_ERR)
156156
goto cleanup;
157157
}
158158
if (RAI_InitDagOp(&currentOp) == REDISMODULE_ERR) {
@@ -177,7 +177,7 @@ int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModul
177177
if (runkey)
178178
RedisModule_FreeString(NULL, runkey);
179179
if (mctx)
180-
RAI_ModelRunCtxFree(mctx, true);
180+
RAI_ModelRunCtxFree(mctx);
181181
RAI_FreeRunInfo(rinfo);
182182
return REDISMODULE_ERR;
183183
}

src/modelRun_ctx.c

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,14 @@ RAI_Tensor *RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx *mctx, size_t index) {
4444
return mctx->outputs[index].tensor;
4545
}
4646

47-
void RAI_ModelRunCtxFree(RAI_ModelRunCtx *mctx, int freeTensors) {
48-
if (freeTensors) {
49-
for (size_t i = 0; i < array_len(mctx->inputs); ++i) {
50-
RAI_TensorFree(mctx->inputs[i].tensor);
51-
}
47+
void RAI_ModelRunCtxFree(RAI_ModelRunCtx *mctx) {
48+
for (size_t i = 0; i < array_len(mctx->inputs); ++i) {
49+
RAI_TensorFree(mctx->inputs[i].tensor);
50+
}
5251

53-
for (size_t i = 0; i < array_len(mctx->outputs); ++i) {
54-
if (mctx->outputs[i].tensor) {
55-
RAI_TensorFree(mctx->outputs[i].tensor);
56-
}
52+
for (size_t i = 0; i < array_len(mctx->outputs); ++i) {
53+
if (mctx->outputs[i].tensor) {
54+
RAI_TensorFree(mctx->outputs[i].tensor);
5755
}
5856
}
5957

src/modelRun_ctx.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ RAI_ModelRunCtx *RAI_ModelRunCtxCreate(RAI_Model *model);
1818
* @param mctx
1919
* @param freeTensors free input and output tensors or leave them allocated
2020
*/
21-
void RAI_ModelRunCtxFree(RAI_ModelRunCtx *mctx, int freeTensors);
21+
void RAI_ModelRunCtxFree(RAI_ModelRunCtx *mctxs);
2222

2323
/**
2424
* Allocates a RAI_ModelCtxParam data structure, and enforces a shallow copy of

src/redisai.c

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,7 @@ int RedisAI_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
567567
if (RedisModule_IsKeysPositionRequest(ctx)) {
568568
return RedisAI_ModelRun_IsKeysPositionRequest_ReportKeys(ctx, argv, argc);
569569
}
570-
bool ro_dag = false;
571-
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_MODELRUN, ro_dag);
570+
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_MODELRUN, false);
572571
}
573572

574573
/**
@@ -584,8 +583,7 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
584583
return RedisModule_WrongArity(ctx);
585584

586585
// Convert The script run command into a DAG command that contains a single op.
587-
bool ro_dag = false;
588-
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTRUN, ro_dag);
586+
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_SCRIPTRUN, false);
589587
}
590588

591589
/**
@@ -892,8 +890,7 @@ int RedisAI_DagRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i
892890
if (RedisModule_IsKeysPositionRequest(ctx)) {
893891
return RedisAI_DagRun_IsKeysPositionRequest_ReportKeys(ctx, argv, argc);
894892
}
895-
bool ro_only = false;
896-
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_DAGRUN, ro_only);
893+
return RedisAI_ExecuteCommand(ctx, argv, argc, CMD_DAGRUN, false);
897894
}
898895

899896
/**

src/redisai.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ typedef void (*RAI_OnFinishCB)(RAI_OnFinishCtx *ctx, void *private_data);
4848
#define RedisAI_ErrorCode_ETENSORSET 14
4949
#define RedisAI_ErrorCode_ETENSORGET 15
5050
#define RedisAI_ErrorCode_EDAGRUN 16
51+
#define RedisAI_ErrorCode_EFINISHCTX 17
5152

5253
enum RedisAI_DataFmt { REDISAI_DATA_BLOB = 0, REDISAI_DATA_VALUES, REDISAI_DATA_NONE };
5354

@@ -93,7 +94,7 @@ int MODULE_API_FUNC(RedisAI_ModelRunCtxAddInput)(RAI_ModelRunCtx *mctx, const ch
9394
int MODULE_API_FUNC(RedisAI_ModelRunCtxAddOutput)(RAI_ModelRunCtx *mctx, const char *outputName);
9495
size_t MODULE_API_FUNC(RedisAI_ModelRunCtxNumOutputs)(RAI_ModelRunCtx *mctx);
9596
RAI_Tensor *MODULE_API_FUNC(RedisAI_ModelRunCtxOutputTensor)(RAI_ModelRunCtx *mctx, size_t index);
96-
void MODULE_API_FUNC(RedisAI_ModelRunCtxFree)(RAI_ModelRunCtx *mctx, bool freeTensors);
97+
void MODULE_API_FUNC(RedisAI_ModelRunCtxFree)(RAI_ModelRunCtx *mctx);
9798
int MODULE_API_FUNC(RedisAI_ModelRun)(RAI_ModelRunCtx **mctx, long long n, RAI_Error *err);
9899
RAI_Model *MODULE_API_FUNC(RedisAI_ModelGetShallowCopy)(RAI_Model *model);
99100
int MODULE_API_FUNC(RedisAI_ModelSerialize)(RAI_Model *model, char **buffer, size_t *len,

src/run_info.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void RAI_FreeDagOp(RAI_DagOp *dagOp) {
179179
array_free(dagOp->outTensors);
180180

181181
if (dagOp->mctx) {
182-
RAI_ModelRunCtxFree(dagOp->mctx, true);
182+
RAI_ModelRunCtxFree(dagOp->mctx);
183183
}
184184
if (dagOp->sctx) {
185185
RAI_ScriptRunCtxFree(dagOp->sctx, true);
@@ -347,6 +347,10 @@ int RAI_RunInfoBatchable(struct RAI_DagOp *op1, struct RAI_DagOp *op2) {
347347
RAI_ModelRunCtx *RAI_GetAsModelRunCtx(RedisAI_RunInfo *rinfo, RAI_Error *err) {
348348

349349
RAI_DagOp *op = rinfo->dagOps[0];
350+
if (!rinfo->single_op_dag || !op->mctx) {
351+
RAI_SetError(err, RedisAI_ErrorCode_EFINISHCTX, "Finish ctx is not a model run ctx");
352+
return NULL;
353+
}
350354
RAI_SetError(err, RAI_GetErrorCode(op->err), RAI_GetError(op->err));
351355
RAI_ModelRunCtx *mctx = op->mctx;
352356
rinfo->dagOps[0]->mctx = NULL;

tests/flow/tests_llapi.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from includes import *
44
import os
5+
from functools import wraps
56

67
'''
78
python -m RLTest --test tests_llapi.py --module path/to/redisai.so
@@ -10,6 +11,25 @@
1011
goal_dir = os.path.join(os.getcwd(), "../module/LLAPI.so")
1112
TEST_MODULE_PATH = os.path.abspath(goal_dir)
1213

14+
15+
def skip_if_gears_not_loaded(f):
16+
@wraps(f)
17+
def wrapper(env, *args, **kwargs):
18+
con = env.getConnection()
19+
modules = con.execute_command("MODULE", "LIST")
20+
if "rg" in [module[1] for module in modules]:
21+
return f(env, *args, **kwargs)
22+
try:
23+
redisgears_path = os.path.join(os.path.dirname(__file__), '../../../RedisGears/redisgears.so')
24+
ret = con.execute_command('MODULE', 'LOAD', redisgears_path)
25+
env.assertEqual(ret, b'OK')
26+
return f(env, *args, **kwargs)
27+
except Exception as e:
28+
env.debugPrint("skipping since RedisGears not loaded", force=True)
29+
return
30+
return wrapper
31+
32+
1333
def test_basic_check(env):
1434

1535
con = env.getConnection()
@@ -38,3 +58,51 @@ def test_model_run_async(env):
3858
con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
3959
ret = con.execute_command("RAI_llapi.modelRun")
4060
env.assertEqual(ret, b'Async run success')
61+
62+
63+
@skip_if_gears_not_loaded
64+
def test_model_run_async_via_gears(env):
65+
script = '''
66+
import redisAI
67+
68+
async def RedisAIModelRun(record):
69+
keys = ['a{1}', 'b{1}']
70+
tensors = redisAI.mgetTensorsFromKeyspace(keys)
71+
modelRunner = redisAI.createModelRunner('m{1}')
72+
redisAI.modelRunnerAddInput(modelRunner, 'a', tensors[0])
73+
redisAI.modelRunnerAddInput(modelRunner, 'b', tensors[1])
74+
redisAI.modelRunnerAddOutput(modelRunner, 'mul')
75+
res = await redisAI.modelRunnerRunAsync(modelRunner)
76+
if len(res[1]) > 0:
77+
raise Exception(res[1][0])
78+
redisAI.setTensorInKey('c{1}', res[0][0])
79+
return "OK"
80+
81+
GB("CommandReader").map(RedisAIModelRun).register(trigger="ModelRunAsyncTest")
82+
'''
83+
con = env.getConnection()
84+
ret = con.execute_command('rg.pyexecute', script)
85+
env.assertEqual(ret, b'OK')
86+
87+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
88+
model_filename = os.path.join(test_data_path, 'graph.pb')
89+
90+
with open(model_filename, 'rb') as f:
91+
model_pb = f.read()
92+
93+
ret = con.execute_command('AI.MODELSET', 'm{1}', 'TF', DEVICE,
94+
'INPUTS', 'a', 'b', 'OUTPUTS', 'mul', 'BLOB', model_pb)
95+
env.assertEqual(ret, b'OK')
96+
97+
ret = con.execute_command('AI.MODELGET', 'm{1}', 'META')
98+
env.assertEqual(len(ret), 14)
99+
100+
con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT',
101+
2, 2, 'VALUES', 2, 3, 2, 3)
102+
con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT',
103+
2, 2, 'VALUES', 2, 3, 2, 3)
104+
105+
ret = con.execute_command('rg.trigger', 'ModelRunAsyncTest')
106+
env.assertEqual(ret[0], b'OK')
107+
values = con.execute_command('AI.TENSORGET', 'c{1}', 'VALUES')
108+
env.assertEqual(values, [b'4', b'9', b'4', b'9'])

tests/module/LLAPI.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ static int _ExecuteModelRunAsync(RedisModuleCtx *ctx, RAI_ModelRunCtx* mctx) {
7070
pthread_mutex_lock(&global_lock);
7171
if (RedisAI_ModelRunAsync(mctx, ModelFinishFunc, &status) != REDISMODULE_OK) {
7272
pthread_mutex_unlock(&global_lock);
73-
RedisAI_ModelRunCtxFree(mctx, true);
73+
RedisAI_ModelRunCtxFree(mctx);
7474
RedisModule_ReplyWithError(ctx, "Async run could not start");
7575
return LLAPI_RUN_NONE;
7676
}
7777

7878
// Wait until the onFinish callback returns.
7979
pthread_cond_wait(&global_cond, &global_lock);
8080
pthread_mutex_unlock(&global_lock);
81-
RedisAI_ModelRunCtxFree(mctx, true);
81+
RedisAI_ModelRunCtxFree(mctx);
8282
return status;
8383
}
8484

0 commit comments

Comments
 (0)