Skip to content

Commit f45b8d8

Browse files
author
DvirDukhan
authored
Merge branch 'master' into skip_test_on_valgrind
2 parents 242dbd9 + 34ee494 commit f45b8d8

File tree

13 files changed

+468
-147
lines changed

13 files changed

+468
-147
lines changed

src/DAG/dag.c

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -237,53 +237,52 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
237237
uint n_inkeys = array_len(currentOp->inkeys);
238238
uint n_outkeys = array_len(currentOp->outkeys);
239239

240-
RAI_ContextReadLock(rinfo);
240+
if (!rinfo->single_op_dag) {
241241

242-
RAI_Tensor *inputTensors[n_inkeys];
243-
for (uint i = 0; i < n_inkeys; i++) {
244-
RAI_Tensor *inputTensor;
245-
const int get_result = RAI_getTensorFromLocalContext(
246-
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
247-
if (get_result == REDISMODULE_ERR) {
248-
// We check for this outside the function
249-
// this check cannot be covered by tests
250-
currentOp->result = REDISMODULE_ERR;
251-
RAI_ContextUnlock(rinfo);
252-
return;
242+
RAI_ContextReadLock(rinfo);
243+
RAI_Tensor *inputTensors[n_inkeys];
244+
for (uint i = 0; i < n_inkeys; i++) {
245+
RAI_Tensor *inputTensor;
246+
const int get_result = RAI_getTensorFromLocalContext(
247+
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
248+
if (get_result == REDISMODULE_ERR) {
249+
// We check for this outside the function
250+
// this check cannot be covered by tests
251+
currentOp->result = REDISMODULE_ERR;
252+
RAI_ContextUnlock(rinfo);
253+
return;
254+
}
255+
inputTensors[i] = inputTensor;
253256
}
254-
inputTensors[i] = inputTensor;
255-
}
256-
257-
RAI_ContextUnlock(rinfo);
258-
259-
for (uint i = 0; i < n_inkeys; i++) {
260-
RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err);
261-
}
257+
RAI_ContextUnlock(rinfo);
262258

263-
for (uint i = 0; i < n_outkeys; i++) {
264-
RAI_ScriptRunCtxAddOutput(currentOp->sctx);
259+
for (uint i = 0; i < n_inkeys; i++) {
260+
RAI_ScriptRunCtxAddInput(currentOp->sctx, inputTensors[i], currentOp->err);
261+
}
262+
for (uint i = 0; i < n_outkeys; i++) {
263+
RAI_ScriptRunCtxAddOutput(currentOp->sctx);
264+
}
265265
}
266266

267267
const long long start = ustime();
268268
int result = RAI_ScriptRun(currentOp->sctx, currentOp->err);
269269
const long long end = ustime();
270270

271-
RAI_ContextWriteLock(rinfo);
272-
273-
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx);
274-
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
275-
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber);
276-
RedisModuleString *key_string = currentOp->outkeys[outputNumber];
277-
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
278-
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor);
279-
}
280-
281271
currentOp->result = result;
282272
currentOp->duration_us = end - start;
283273

284-
RAI_ContextUnlock(rinfo);
274+
if (!rinfo->single_op_dag) {
285275

286-
return;
276+
RAI_ContextWriteLock(rinfo);
277+
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx);
278+
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
279+
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber);
280+
RedisModuleString *key_string = currentOp->outkeys[outputNumber];
281+
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
282+
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor);
283+
}
284+
RAI_ContextUnlock(rinfo);
285+
}
287286
}
288287

289288
size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
@@ -572,17 +571,16 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
572571
return ret;
573572
}
574573

575-
static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
574+
static void _PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
575+
576576
AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
577577
AI_dictEntry *persist_entry = AI_dictNext(persist_iter);
578+
578579
while (persist_entry) {
579580
RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry);
580-
581581
AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name);
582-
583582
if (tensor_entry) {
584583
RAI_Tensor *tensor = AI_dictGetVal(tensor_entry);
585-
586584
if (tensor == NULL) {
587585
persist_entry = AI_dictNext(persist_iter);
588586
continue;
@@ -594,17 +592,17 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
594592
RedisModule_ReplyWithError(ctx,
595593
"ERR specified persistent key that was not used in DAG");
596594
rinfo->dagReplyLength++;
597-
598595
RedisModule_Log(ctx, "warning",
599-
"on DAGRUN's PERSIST pecified persistent key (%s) that "
596+
"on DAGRUN's PERSIST specified persistent key (%s) that "
600597
"was not used on DAG. Logging all local context keys",
601-
persist_key_name);
598+
RedisModule_StringPtrLen(persist_key_name, NULL));
602599
AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
603600
AI_dictEntry *local_entry = AI_dictNext(local_iter);
601+
604602
while (local_entry) {
605603
RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry);
606604
RedisModule_Log(ctx, "warning", "DAG's local context key (%s)",
607-
localcontext_key_name);
605+
RedisModule_StringPtrLen(localcontext_key_name, NULL));
608606
local_entry = AI_dictNext(local_iter);
609607
}
610608
AI_dictReleaseIterator(local_iter);
@@ -619,7 +617,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
619617
AI_dictReleaseIterator(persist_iter);
620618
}
621619

622-
static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
620+
static void _ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
623621
const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx);
624622
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
625623
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(op->mctx, outputNumber);
@@ -629,6 +627,16 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
629627
}
630628
}
631629

630+
static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
631+
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx);
632+
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
633+
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(op->sctx, outputNumber);
634+
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
635+
if (tensor)
636+
_StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false);
637+
}
638+
}
639+
632640
int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
633641
REDISMODULE_NOT_USED(argv);
634642
REDISMODULE_NOT_USED(argc);
@@ -650,7 +658,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
650658
return REDISMODULE_OK;
651659
}
652660

653-
if (rinfo->single_op_dag == 0) {
661+
if (!rinfo->single_op_dag) {
654662
RedisModule_ReplyWithArray(ctx, REDISMODULE_POSTPONED_ARRAY_LEN);
655663
}
656664

@@ -745,18 +753,20 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
745753
return REDISMODULE_ERR;
746754
}
747755

748-
// TODO: Take care of script single op
749-
if (rinfo->single_op_dag == 0 || rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN) {
756+
if (!rinfo->single_op_dag) {
750757
// Save the required tensors in redis key space.
751-
PersistTensors(ctx, rinfo);
752-
if (rinfo->single_op_dag == 0)
753-
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
758+
_PersistTensors(ctx, rinfo);
759+
RedisModule_ReplySetArrayLength(ctx, rinfo->dagReplyLength);
754760
} else {
755-
ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
761+
if (rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_MODELRUN) {
762+
_ModelSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
763+
} else {
764+
RedisModule_Assert(rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN);
765+
_ScriptSingleOp_PersistTensors(ctx, rinfo->dagOps[0]);
766+
}
756767
}
757768

758769
RAI_FreeRunInfo(rinfo);
759-
760770
return REDISMODULE_OK;
761771
}
762772

src/DAG/dag_parser.c

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,9 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
136136
int chainingOpCount = 0;
137137
bool load_complete = false;
138138
bool persist_complete = false;
139-
int arg_pos = 1;
140-
141-
// If we're parsing a AI.SCRIPTRUN command, we don't expect there to be a chaining |> operator
142-
if (!strcasecmp(RedisModule_StringPtrLen(argv[0], NULL), "AI.SCRIPTRUN")) {
143-
arg_pos = 0;
144-
chainingOpCount++;
145-
rinfo->single_op_dag = 1;
146-
rinfo->single_device_dag = 1;
147-
}
148139

149140
// The first arg is "AI.DAGRUN", so we go over from the next arg.
150-
for (; arg_pos < argc; arg_pos++) {
141+
for (int arg_pos = 1; arg_pos < argc; arg_pos++) {
151142
const char *arg_string = RedisModule_StringPtrLen(argv[arg_pos], NULL);
152143

153144
if (!strcasecmp(arg_string, "LOAD") && !load_complete) {
@@ -303,35 +294,6 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
303294
}
304295
}
305296

306-
if (rinfo->single_op_dag && rinfo->dagOps[0]->commandType == REDISAI_DAG_CMD_SCRIPTRUN) {
307-
RAI_DagOp *op = rinfo->dagOps[0];
308-
RAI_Tensor *t;
309-
RedisModuleKey *key;
310-
for (size_t i = 0; i < array_len(op->inkeys); i++) {
311-
RedisModuleString *inkey = op->inkeys[i];
312-
const int status = RAI_GetTensorFromKeyspace(ctx, inkey, &key, &t, REDISMODULE_READ);
313-
if (status == REDISMODULE_ERR) {
314-
RedisModule_Log(ctx, "warning",
315-
"on DAGRUN's LOAD could not load tensor %s from keyspace",
316-
RedisModule_StringPtrLen(inkey, NULL));
317-
return REDISMODULE_ERR;
318-
}
319-
char buf[16];
320-
sprintf(buf, "%04d", 1);
321-
RedisModuleString *dictKey = RedisModule_CreateStringFromString(NULL, inkey);
322-
RedisModule_StringAppendBuffer(NULL, dictKey, buf, strlen(buf));
323-
AI_dictAdd(rinfo->dagTensorsContext, (void *)dictKey,
324-
(void *)RAI_TensorGetShallowCopy(t));
325-
AI_dictAdd(rinfo->dagTensorsLoadedContext, (void *)dictKey, (void *)1);
326-
RedisModule_Free(dictKey);
327-
}
328-
329-
for (size_t i = 0; i < array_len(op->outkeys); i++) {
330-
RedisModuleString *outkey = op->outkeys[i];
331-
AI_dictAdd(rinfo->dagTensorsPersistedContext, (void *)outkey, (void *)1);
332-
}
333-
}
334-
335297
// At this point, we have built a sequence of DAG operations, each with its own
336298
// input and output keys. The names of the keys will be used to look whether the
337299
// inputs to a DAG operation have all been realized by previous operations (or if
@@ -462,4 +424,4 @@ int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, b
462424
}
463425
}
464426
return REDISMODULE_OK;
465-
}
427+
}

0 commit comments

Comments
 (0)