Skip to content

Commit 8fdf44d

Browse files
author
DvirDukhan
authored
Merge branch 'master' into torchscript_extensions
2 parents d655511 + 64001de commit 8fdf44d

File tree

14 files changed

+122
-234
lines changed

14 files changed

+122
-234
lines changed

src/DAG/dag.c

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -884,20 +884,3 @@ void DAG_ReplyAndUnblock(RedisAI_OnFinishCtx *ctx, void *private_data) {
884884
if (rinfo->client)
885885
RedisModule_UnblockClient(rinfo->client, rinfo);
886886
}
887-
888-
void Dag_PopulateOp(RAI_DagOp *currentOp, void *rctx, RedisModuleString **inkeys,
889-
RedisModuleString **outkeys, RedisModuleString *runkey) {
890-
891-
if (currentOp->commandType == REDISAI_DAG_CMD_MODELRUN) {
892-
currentOp->mctx = (RAI_ModelRunCtx *)rctx;
893-
currentOp->devicestr = currentOp->mctx->model->devicestr;
894-
} else {
895-
assert(currentOp->commandType == REDISAI_DAG_CMD_SCRIPTRUN);
896-
currentOp->sctx = (RAI_ScriptRunCtx *)rctx;
897-
currentOp->devicestr = currentOp->sctx->script->devicestr;
898-
}
899-
900-
currentOp->inkeys = inkeys;
901-
currentOp->outkeys = outkeys;
902-
currentOp->runkey = runkey;
903-
}

src/DAG/dag.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,4 @@ void RunInfo_FreeData(RedisModuleCtx *ctx, void *rinfo);
147147
*/
148148
void RedisAI_Disconnected(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc);
149149

150-
/**
151-
* @brief Populate a DAG modelrun/scriptrun op with its params .
152-
* @param rinfo An existing DAG to populate.
153-
* @param rctx ModelRunCtx or ScriptRunCtx that represents the single MODELRUN op.
154-
* @param inkeys The DAG operation inkeys (the input tensors).
155-
* @param outkeys The DAG operation outkeys (the output tensors).
156-
* @param runkey The model key.
157-
* @param cmd The DAG command (modelrun/scriptrun).
158-
*/
159-
160-
void Dag_PopulateOp(RAI_DagOp *currentOp, void *rctx, RedisModuleString **inkeys,
161-
RedisModuleString **outkeys, RedisModuleString *runkey);
162-
163150
#endif /* SRC_DAG_H_ */

src/DAG/dag_parser.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
229229
return REDISMODULE_ERR;
230230
}
231231
currentOp->devicestr = mto->devicestr;
232+
RAI_HoldString(NULL, argv[arg_pos + 1]);
232233
currentOp->runkey = argv[arg_pos + 1];
233234
currentOp->mctx = RAI_ModelRunCtxCreate(mto);
234235
}
@@ -249,6 +250,7 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
249250
}
250251
currentOp->devicestr = sto->devicestr;
251252
const char *functionName = RedisModule_StringPtrLen(argv[arg_pos + 2], NULL);
253+
RAI_HoldString(NULL, argv[arg_pos + 1]);
252254
currentOp->runkey = argv[arg_pos + 1];
253255
currentOp->sctx = RAI_ScriptRunCtxCreate(sto, functionName);
254256
}
@@ -395,8 +397,8 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
395397
sprintf(buf, "%04d", *instance);
396398
RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key);
397399
RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf));
398-
399400
AI_dictAdd(mangled_persisted, (void *)mangled_key, (void *)1);
401+
RedisModule_FreeString(NULL, mangled_key);
400402
entry = AI_dictNext(iter);
401403
}
402404
AI_dictReleaseIterator(iter);

src/backends/tensorflow.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
250250
char *msg = RedisModule_Calloc(60 + len, sizeof(*msg));
251251
sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]);
252252
RAI_SetError(error, RAI_EMODELIMPORT, msg);
253+
RedisModule_Free(msg);
253254
return NULL;
254255
}
255256
}

src/background_workers.c

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,11 @@ void *RedisAI_Run_ThreadMain(void *arg) {
391391

392392
// Run is over, now iterate over the run info structs in the batch
393393
// and see if any error was generated
394-
bool first_dag_error = false;
395394
for (long long i = 0; i < array_len(batch_rinfo); i++) {
396395
RedisAI_RunInfo *rinfo = batch_rinfo[i];
397396
// We record that there was an error for later on
398397
run_error = __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED);
399-
if (i == 0 && run_error == 1) {
400-
first_dag_error = true;
401-
}
398+
402399
// If there was an error and the reference count for the dag
403400
// has gone to zero and the client is still around, we unblock
404401
if (run_error) {
@@ -413,37 +410,35 @@ void *RedisAI_Run_ThreadMain(void *arg) {
413410
__atomic_add_fetch(rinfo->dagCompleteOpCount, 1, __ATOMIC_RELAXED);
414411
}
415412
}
416-
if (first_dag_error) {
417-
run_queue_len = queueLength(run_queue_info->run_queue);
418-
continue;
419-
}
420413
}
421414

422415
// We initialize variables where we'll store the fact hat, after the current
423416
// run, all ops for the device or all ops in the dag could be complete. This
424417
// way we can avoid placing the op back on the queue if there's nothing left
425418
// to do.
426-
RedisModule_Assert(run_error == 0);
427-
int device_complete_after_run = RedisAI_DagDeviceComplete(batch_rinfo[0]);
428-
int dag_complete_after_run = RedisAI_DagComplete(batch_rinfo[0]);
429-
430-
long long dagRefCount = -1;
431-
RedisAI_RunInfo *orig;
432-
if (device_complete == 1 || device_complete_after_run == 1) {
433-
RedisAI_RunInfo *evicted_rinfo = (RedisAI_RunInfo *)(evicted_items[0]->value);
434-
orig = evicted_rinfo->orig_copy;
435-
// We decrease and get the reference count for the DAG.
436-
dagRefCount = RAI_DagRunInfoFreeShallowCopy(evicted_rinfo);
437-
}
419+
int device_complete_after_run;
420+
if (run_error == 0) {
421+
device_complete_after_run = RedisAI_DagDeviceComplete(batch_rinfo[0]);
422+
int dag_complete_after_run = RedisAI_DagComplete(batch_rinfo[0]);
423+
424+
long long dagRefCount = -1;
425+
RedisAI_RunInfo *orig;
426+
if (device_complete == 1 || device_complete_after_run == 1) {
427+
RedisAI_RunInfo *evicted_rinfo = (RedisAI_RunInfo *)(evicted_items[0]->value);
428+
orig = evicted_rinfo->orig_copy;
429+
// We decrease and get the reference count for the DAG.
430+
dagRefCount = RAI_DagRunInfoFreeShallowCopy(evicted_rinfo);
431+
}
438432

439-
// If the DAG was complete, then it's time to unblock the client
440-
if (do_unblock == 1 || dag_complete_after_run == 1) {
433+
// If the DAG was complete, then it's time to unblock the client
434+
if (do_unblock == 1 || dag_complete_after_run == 1) {
441435

442-
// If the reference count for the DAG is zero and the client is still around,
443-
// then we actually unblock the client
444-
if (dagRefCount == 0) {
445-
RedisAI_OnFinishCtx *finish_ctx = orig;
446-
orig->OnFinish(finish_ctx, orig->private_data);
436+
// If the reference count for the DAG is zero and the client is still around,
437+
// then we actually unblock the client
438+
if (dagRefCount == 0) {
439+
RedisAI_OnFinishCtx *finish_ctx = orig;
440+
orig->OnFinish(finish_ctx, orig->private_data);
441+
}
447442
}
448443
}
449444

@@ -499,11 +494,8 @@ void *RedisAI_Run_ThreadMain(void *arg) {
499494

500495
// If there's nothing else to do for the DAG in the current worker or if an error
501496
// occurred in any worker, we just move on
502-
if (device_complete == 1 || device_complete_after_run == 1 || do_unblock == 1 ||
503-
run_error == 1) {
504-
for (long long i = 0; i < array_len(evicted_items); i++) {
505-
RedisModule_Free(evicted_items[i]);
506-
}
497+
for (long long i = 0; i < array_len(evicted_items); i++) {
498+
RedisModule_Free(evicted_items[i]);
507499
}
508500
run_queue_len = queueLength(run_queue_info->run_queue);
509501
}

src/command_parser.c

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -125,52 +125,41 @@ static int _ModelRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inkey
125125
int ParseModelRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleString **argv,
126126
int argc) {
127127

128+
RAI_DagOp *currentOp;
129+
RAI_InitDagOp(&currentOp);
130+
rinfo->dagOps = array_append(rinfo->dagOps, currentOp);
131+
128132
// Build a ModelRunCtx from command.
129-
RAI_Error error = {0};
130133
RAI_Model *model;
131-
RedisModuleString **inkeys = array_new(RedisModuleString *, 1);
132-
RedisModuleString **outkeys = array_new(RedisModuleString *, 1);
133-
RedisModuleString *runkey = NULL;
134-
RAI_ModelRunCtx *mctx = NULL;
135-
RAI_DagOp *currentOp;
136134

137135
long long timeout = 0;
138-
if (_ModelRunCommand_ParseArgs(ctx, argv, argc, &model, &error, &inkeys, &outkeys, &runkey,
136+
if (_ModelRunCommand_ParseArgs(ctx, argv, argc, &model, currentOp->err, &currentOp->inkeys,
137+
&currentOp->outkeys, &currentOp->runkey,
139138
&timeout) == REDISMODULE_ERR) {
140-
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&error));
139+
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(currentOp->err));
140+
goto cleanup;
141+
}
142+
143+
if (timeout > 0 && !rinfo->single_op_dag) {
144+
RedisModule_ReplyWithError(ctx, "ERR TIMEOUT not allowed within a DAG command");
141145
goto cleanup;
142146
}
143-
mctx = RAI_ModelRunCtxCreate(model);
144147

148+
RAI_ModelRunCtx *mctx = RAI_ModelRunCtxCreate(model);
145149
if (rinfo->single_op_dag) {
146150
rinfo->timeout = timeout;
147151
// Set params in ModelRunCtx, bring inputs from key space.
148-
if (_ModelRunCtx_SetParams(ctx, inkeys, outkeys, mctx) == REDISMODULE_ERR)
152+
if (_ModelRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, mctx) ==
153+
REDISMODULE_ERR)
149154
goto cleanup;
150155
}
151-
if (RAI_InitDagOp(&currentOp) == REDISMODULE_ERR) {
152-
RedisModule_ReplyWithError(
153-
ctx, "ERR Unable to allocate the memory and initialise the RAI_dagOp structure");
154-
goto cleanup;
155-
}
156+
156157
currentOp->commandType = REDISAI_DAG_CMD_MODELRUN;
157-
Dag_PopulateOp(currentOp, mctx, inkeys, outkeys, runkey);
158-
rinfo->dagOps = array_append(rinfo->dagOps, currentOp);
158+
currentOp->mctx = mctx;
159+
currentOp->devicestr = mctx->model->devicestr;
159160
return REDISMODULE_OK;
160161

161162
cleanup:
162-
for (size_t i = 0; i < array_len(inkeys); i++) {
163-
RedisModule_FreeString(NULL, inkeys[i]);
164-
}
165-
array_free(inkeys);
166-
for (size_t i = 0; i < array_len(outkeys); i++) {
167-
RedisModule_FreeString(NULL, outkeys[i]);
168-
}
169-
array_free(outkeys);
170-
if (runkey)
171-
RedisModule_FreeString(NULL, runkey);
172-
if (mctx)
173-
RAI_ModelRunCtxFree(mctx);
174163
RAI_FreeRunInfo(rinfo);
175164
return REDISMODULE_ERR;
176165
}
@@ -293,55 +282,44 @@ static int _ScriptRunCtx_SetParams(RedisModuleCtx *ctx, RedisModuleString **inke
293282
int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RedisModuleCtx *ctx, RedisModuleString **argv,
294283
int argc) {
295284

285+
RAI_DagOp *currentOp;
286+
RAI_InitDagOp(&currentOp);
287+
rinfo->dagOps = array_append(rinfo->dagOps, currentOp);
288+
296289
// Build a ScriptRunCtx from command.
297-
RAI_Error error = {0};
298290
RAI_Script *script;
299-
RedisModuleString **inkeys = array_new(RedisModuleString *, 1);
300-
RedisModuleString **outkeys = array_new(RedisModuleString *, 1);
301-
RedisModuleString *runkey = NULL;
302291
const char *func_name = NULL;
303-
RAI_ScriptRunCtx *sctx = NULL;
304-
RAI_DagOp *currentOp;
305292

306293
long long timeout = 0;
307294
int variadic = -1;
308-
if (_ScriptRunCommand_ParseArgs(ctx, argv, argc, &script, &error, &inkeys, &outkeys, &runkey,
309-
&func_name, &timeout, &variadic) == REDISMODULE_ERR) {
310-
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&error));
295+
if (_ScriptRunCommand_ParseArgs(ctx, argv, argc, &script, currentOp->err, &currentOp->inkeys,
296+
&currentOp->outkeys, &currentOp->runkey, &func_name, &timeout,
297+
&variadic) == REDISMODULE_ERR) {
298+
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(currentOp->err));
311299
goto cleanup;
312300
}
313-
sctx = RAI_ScriptRunCtxCreate(script, func_name);
301+
if (timeout > 0 && !rinfo->single_op_dag) {
302+
RedisModule_ReplyWithError(ctx, "ERR TIMEOUT not allowed within a DAG command");
303+
goto cleanup;
304+
}
305+
306+
RAI_ScriptRunCtx *sctx = RAI_ScriptRunCtxCreate(script, func_name);
314307
sctx->variadic = variadic;
315308

316309
if (rinfo->single_op_dag) {
317310
rinfo->timeout = timeout;
318311
// Set params in ScriptRunCtx, bring inputs from key space.
319-
if (_ScriptRunCtx_SetParams(ctx, inkeys, outkeys, sctx) == REDISMODULE_ERR)
312+
if (_ScriptRunCtx_SetParams(ctx, currentOp->inkeys, currentOp->outkeys, sctx) ==
313+
REDISMODULE_ERR)
320314
goto cleanup;
321315
}
322-
if (RAI_InitDagOp(&currentOp) == REDISMODULE_ERR) {
323-
RedisModule_ReplyWithError(
324-
ctx, "ERR Unable to allocate the memory and initialise the RAI_dagOp structure");
325-
goto cleanup;
326-
}
316+
currentOp->sctx = sctx;
327317
currentOp->commandType = REDISAI_DAG_CMD_SCRIPTRUN;
328-
Dag_PopulateOp(currentOp, sctx, inkeys, outkeys, runkey);
329-
rinfo->dagOps = array_append(rinfo->dagOps, currentOp);
318+
currentOp->devicestr = sctx->script->devicestr;
319+
330320
return REDISMODULE_OK;
331321

332322
cleanup:
333-
for (size_t i = 0; i < array_len(inkeys); i++) {
334-
RedisModule_FreeString(NULL, inkeys[i]);
335-
}
336-
array_free(inkeys);
337-
for (size_t i = 0; i < array_len(outkeys); i++) {
338-
RedisModule_FreeString(NULL, outkeys[i]);
339-
}
340-
array_free(outkeys);
341-
if (runkey)
342-
RedisModule_FreeString(NULL, runkey);
343-
if (sctx)
344-
RAI_ScriptRunCtxFree(sctx);
345323
RAI_FreeRunInfo(rinfo);
346324
return REDISMODULE_ERR;
347325
}

src/err.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ void RAI_SetError(RAI_Error *err, RAI_ErrorCode code, const char *detail) {
5656
int RAI_InitError(RAI_Error **result) {
5757
RAI_Error *err;
5858
err = (RAI_Error *)RedisModule_Calloc(1, sizeof(RAI_Error));
59-
if (!err) {
60-
return REDISMODULE_ERR;
61-
}
6259
err->code = 0;
6360
err->detail = NULL;
6461
err->detail_oneline = NULL;

src/model.c

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,24 @@ static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
107107
return NULL;
108108
}
109109

110+
for (size_t i = 0; i < ninputs; i++) {
111+
RedisModule_Free(inputs[i]);
112+
}
113+
for (size_t i = 0; i < noutputs; i++) {
114+
RedisModule_Free(outputs[i]);
115+
}
110116
RedisModule_Free(inputs);
111117
RedisModule_Free(outputs);
112118
RedisModule_Free(buffer);
113119

114120
RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
115121
RedisModuleString *stats_keystr =
116122
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
117-
const char *stats_devicestr = RedisModule_Strdup(devicestr);
118-
RedisModuleString *stats_tag = RAI_HoldString(NULL, tag);
119123

120-
model->infokey =
121-
RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, stats_devicestr, stats_tag);
124+
model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag);
122125

126+
RedisModule_FreeString(NULL, tag);
127+
RedisModule_Free(devicestr);
123128
RedisModule_FreeString(NULL, stats_keystr);
124129

125130
return model;
@@ -371,7 +376,6 @@ void RAI_ModelFree(RAI_Model *model, RAI_Error *err) {
371376
}
372377

373378
RedisModule_FreeString(NULL, model->tag);
374-
375379
RAI_RemoveStatsEntry(model->infokey);
376380

377381
RedisModule_Free(model);
@@ -504,19 +508,17 @@ RedisModuleType *RAI_ModelRedisType(void) { return RedisAI_ModelType; }
504508
int RAI_ModelRunAsync(RAI_ModelRunCtx *mctx, RAI_OnFinishCB ModelAsyncFinish, void *private_data) {
505509

506510
RedisAI_RunInfo *rinfo = NULL;
507-
if (RAI_InitRunInfo(&rinfo) == REDISMODULE_ERR) {
508-
return REDISMODULE_ERR;
509-
}
511+
RAI_InitRunInfo(&rinfo);
512+
510513
rinfo->single_op_dag = 1;
511514
rinfo->OnFinish = (RedisAI_OnFinishCB)ModelAsyncFinish;
512515
rinfo->private_data = private_data;
513516

514517
RAI_DagOp *op;
515-
if (RAI_InitDagOp(&op) == REDISMODULE_ERR) {
516-
return REDISMODULE_ERR;
517-
}
518+
RAI_InitDagOp(&op);
518519
op->commandType = REDISAI_DAG_CMD_MODELRUN;
519-
Dag_PopulateOp(op, mctx, NULL, NULL, NULL);
520+
op->devicestr = mctx->model->devicestr;
521+
op->mctx = mctx;
520522

521523
rinfo->dagOps = array_append(rinfo->dagOps, op);
522524
rinfo->dagOpCount = 1;

0 commit comments

Comments
 (0)