Skip to content

Commit f566fcd

Browse files
authored
Avoid using arr in ModelRun API, align redisai.h ModelRun signature (#361)
1 parent 1220e8a commit f566fcd

File tree

5 files changed

+30
-32
lines changed

5 files changed

+30
-32
lines changed

src/dag.c

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,11 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) {
6969
&(rinfo->dagTensorsContext), 0, NULL, currentOp->err);
7070

7171
if (parse_result > 0) {
72-
RAI_ModelRunCtx **mctxs = NULL;
73-
mctxs = array_new(RAI_ModelRunCtx *, 1);
74-
mctxs = array_append(mctxs, currentOp->mctx);
72+
RAI_ModelRunCtx *mctxs[1];
73+
mctxs[0] = currentOp->mctx;
7574
currentOp->result = REDISMODULE_OK;
7675
const long long start = ustime();
77-
currentOp->result = RAI_ModelRun(mctxs, currentOp->err);
76+
currentOp->result = RAI_ModelRun(mctxs, 1, currentOp->err);
7877
currentOp->duration_us = ustime() - start;
7978
const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx);
8079
for (size_t outputNumber = 0; outputNumber < noutputs;
@@ -92,7 +91,6 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) {
9291
currentOp->result = REDISMODULE_ERR;
9392
}
9493
}
95-
array_free(mctxs);
9694
} else {
9795
currentOp->result = REDISMODULE_ERR;
9896
}

src/model.c

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -413,48 +413,55 @@ void RAI_ModelRunCtxFree(RAI_ModelRunCtx* mctx) {
413413
RedisModule_Free(mctx);
414414
}
415415

416-
int RAI_ModelRun(RAI_ModelRunCtx** mctxs, RAI_Error* err) {
416+
int RAI_ModelRun(RAI_ModelRunCtx** mctxs, long long n, RAI_Error* err) {
417417
int ret;
418418

419-
if (array_len(mctxs) == 0) {
419+
if (n == 0) {
420420
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Nothing to run");
421421
return REDISMODULE_ERR;
422422
}
423423

424-
switch (mctxs[0]->model->backend) {
424+
RAI_ModelRunCtx** mctxs_arr = array_newlen(RAI_ModelRunCtx*, n);
425+
for (int i=0; i<n; i++) {
426+
mctxs_arr[i] = mctxs[i];
427+
}
428+
429+
switch (mctxs_arr[0]->model->backend) {
425430
case RAI_BACKEND_TENSORFLOW:
426431
if (!RAI_backends.tf.model_run) {
427432
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TF");
428433
return REDISMODULE_ERR;
429434
}
430-
ret = RAI_backends.tf.model_run(mctxs, err);
435+
ret = RAI_backends.tf.model_run(mctxs_arr, err);
431436
break;
432437
case RAI_BACKEND_TFLITE:
433438
if (!RAI_backends.tflite.model_run) {
434439
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TFLITE");
435440
return REDISMODULE_ERR;
436441
}
437-
ret = RAI_backends.tflite.model_run(mctxs, err);
442+
ret = RAI_backends.tflite.model_run(mctxs_arr, err);
438443
break;
439444
case RAI_BACKEND_TORCH:
440445
if (!RAI_backends.torch.model_run) {
441446
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: TORCH");
442447
return REDISMODULE_ERR;
443448
}
444-
ret = RAI_backends.torch.model_run(mctxs, err);
449+
ret = RAI_backends.torch.model_run(mctxs_arr, err);
445450
break;
446451
case RAI_BACKEND_ONNXRUNTIME:
447452
if (!RAI_backends.onnx.model_run) {
448453
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Backend not loaded: ONNX");
449454
return REDISMODULE_ERR;
450455
}
451-
ret = RAI_backends.onnx.model_run(mctxs, err);
456+
ret = RAI_backends.onnx.model_run(mctxs_arr, err);
452457
break;
453458
default:
454459
RAI_SetError(err, RAI_EUNSUPPORTEDBACKEND, "ERR Unsupported backend");
455460
return REDISMODULE_ERR;
456461
}
457462

463+
array_free(mctxs_arr);
464+
458465
return ret;
459466
}
460467

@@ -641,4 +648,4 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx,
641648
return -1;
642649
}
643650
return argpos;
644-
}
651+
}

src/model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,13 @@ RAI_Tensor* RAI_ModelRunCtxOutputTensor(RAI_ModelRunCtx* mctx, size_t index);
150150
* definition.
151151
*
152152
* @param mctxs array on input model contexts
153+
* @param n length of input model contexts array
153154
* @param error error data structure to store error message in the case of
154155
* failures
155156
* @return REDISMODULE_OK if the underlying backend `model_run` runned
156157
* successfully, or REDISMODULE_ERR if failed.
157158
*/
158-
int RAI_ModelRun(RAI_ModelRunCtx** mctxs, RAI_Error* err);
159+
int RAI_ModelRun(RAI_ModelRunCtx** mctxs, long long n, RAI_Error* err);
159160

160161
/**
161162
* Every call to this function, will make the RAI_Model 'model' requiring an

src/model_script_run_session.c

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,26 @@ void *RAI_ModelRunScriptRunSession(RedisAI_RunInfo **batch_rinfo) {
3434
return NULL;
3535
}
3636

37-
RAI_ModelRunCtx **mctxs = NULL;
3837
RAI_ScriptRunCtx *sctx = NULL;
3938

4039
RAI_Error *err = RedisModule_Calloc(1, sizeof(RAI_Error));
4140
long long rtime;
4241
int status;
42+
bool is_model;
4343
if (batch_rinfo[0]->mctx) {
44-
mctxs = array_new(RAI_ModelRunCtx *, batch_size);
44+
RAI_ModelRunCtx *mctxs[batch_size];
4545
for (long long i = 0; i < batch_size; i++) {
46-
mctxs = array_append(mctxs, batch_rinfo[i]->mctx);
46+
mctxs[i] = batch_rinfo[i]->mctx;
4747
}
48+
const long long start = ustime();
49+
status = RAI_ModelRun(mctxs, batch_size, err);
50+
rtime = ustime() - start;
4851
} else if (batch_rinfo[0]->sctx) {
4952
sctx = batch_rinfo[0]->sctx;
50-
}
51-
52-
const long long start = ustime();
53-
if (mctxs) {
54-
status = RAI_ModelRun(mctxs, err);
55-
} else if (sctx) {
53+
const long long start = ustime();
5654
status = RAI_ScriptRun(sctx, err);
55+
rtime = ustime() - start;
5756
}
58-
rtime = ustime() - start;
5957

6058
for (long long i = 0; i < batch_size; i++) {
6159
struct RedisAI_RunInfo *rinfo = batch_rinfo[i];
@@ -76,12 +74,6 @@ void *RAI_ModelRunScriptRunSession(RedisAI_RunInfo **batch_rinfo) {
7674
}
7775
}
7876

79-
if (mctxs) {
80-
array_free(mctxs);
81-
} else if (sctx) {
82-
// No batching for scripts for now
83-
}
84-
8577
return NULL;
8678
}
8779

@@ -159,4 +151,4 @@ void RedisAI_FreeData(RedisModuleCtx *ctx, void *rinfo) {}
159151
void RedisAI_Disconnected(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc) {
160152
RedisModule_Log(ctx, "warning", "Blocked client %p disconnected!",
161153
(void *)bc);
162-
}
154+
}

src/redisai.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ int MODULE_API_FUNC(RedisAI_ModelRunCtxAddOutput)(RAI_ModelRunCtx* mctx, const c
6565
size_t MODULE_API_FUNC(RedisAI_ModelRunCtxNumOutputs)(RAI_ModelRunCtx* mctx);
6666
RAI_Tensor* MODULE_API_FUNC(RedisAI_ModelRunCtxOutputTensor)(RAI_ModelRunCtx* mctx, size_t index);
6767
void MODULE_API_FUNC(RedisAI_ModelRunCtxFree)(RAI_ModelRunCtx* mctx);
68-
int MODULE_API_FUNC(RedisAI_ModelRun)(RAI_ModelRunCtx* mctx, RAI_Error* err);
68+
int MODULE_API_FUNC(RedisAI_ModelRun)(RAI_ModelRunCtx** mctx, long long n, RAI_Error* err);
6969
RAI_Model* MODULE_API_FUNC(RedisAI_ModelGetShallowCopy)(RAI_Model* model);
7070
int MODULE_API_FUNC(RedisAI_ModelSerialize)(RAI_Model *model, char **buffer, size_t *len, RAI_Error *err);
7171

0 commit comments

Comments
 (0)