4545#include "util/dict.h"
4646#include "util/queue.h"
4747#include "dag_parser.h"
48+ #include "util/string_utils.h"
4849
4950/**
5051 * Execution of a TENSORSET DAG step.
@@ -59,7 +60,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
5960 const int parse_result =
6061 RAI_parseTensorSetArgs (NULL , currentOp -> argv , currentOp -> argc , & t , 0 , currentOp -> err );
6162 if (parse_result > 0 ) {
62- const char * key_string = RedisModule_StringPtrLen ( currentOp -> outkeys [0 ], NULL ) ;
63+ RedisModuleString * key_string = currentOp -> outkeys [0 ];
6364 RAI_ContextWriteLock (rinfo );
6465 AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , t );
6566 RAI_ContextUnlock (rinfo );
@@ -78,7 +79,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
7879 * @return
7980 */
8081void RedisAI_DagRunSession_TensorGet_Step (RedisAI_RunInfo * rinfo , RAI_DagOp * currentOp ) {
81- const char * key_string = RedisModule_StringPtrLen ( currentOp -> inkeys [0 ], NULL ) ;
82+ RedisModuleString * key_string = currentOp -> inkeys [0 ];
8283 RAI_Tensor * t = NULL ;
8384 RAI_ContextReadLock (rinfo );
8485 currentOp -> result = RAI_getTensorFromLocalContext (NULL , rinfo -> dagTensorsContext , key_string ,
@@ -102,8 +103,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
102103 for (uint i = 0 ; i < n_inkeys ; i ++ ) {
103104 RAI_Tensor * inputTensor ;
104105 const int get_result = RAI_getTensorFromLocalContext (
105- NULL , rinfo -> dagTensorsContext , RedisModule_StringPtrLen (currentOp -> inkeys [i ], NULL ),
106- & inputTensor , currentOp -> err );
106+ NULL , rinfo -> dagTensorsContext , currentOp -> inkeys [i ], & inputTensor , currentOp -> err );
107107 if (get_result == REDISMODULE_ERR ) {
108108 // We check for this outside the function
109109 // this check cannot be covered by tests
@@ -141,9 +141,8 @@ static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *c
141141 const size_t noutputs = RAI_ModelRunCtxNumOutputs (currentOp -> mctx );
142142 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
143143 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (currentOp -> mctx , outputNumber );
144- const char * key_string = RedisModule_StringPtrLen (currentOp -> outkeys [outputNumber ], NULL );
145144 tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
146- AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , tensor );
145+ AI_dictReplace (rinfo -> dagTensorsContext , (void * )currentOp -> outkeys [ outputNumber ] , tensor );
147146 }
148147 RAI_ContextUnlock (rinfo );
149148}
@@ -244,8 +243,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
244243 for (uint i = 0 ; i < n_inkeys ; i ++ ) {
245244 RAI_Tensor * inputTensor ;
246245 const int get_result = RAI_getTensorFromLocalContext (
247- NULL , rinfo -> dagTensorsContext , RedisModule_StringPtrLen (currentOp -> inkeys [i ], NULL ),
248- & inputTensor , currentOp -> err );
246+ NULL , rinfo -> dagTensorsContext , currentOp -> inkeys [i ], & inputTensor , currentOp -> err );
249247 if (get_result == REDISMODULE_ERR ) {
250248 // We check for this outside the function
251249 // this check cannot be covered by tests
@@ -275,7 +273,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
275273 const size_t noutputs = RAI_ScriptRunCtxNumOutputs (currentOp -> sctx );
276274 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
277275 RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (currentOp -> sctx , outputNumber );
278- const char * key_string = RedisModule_StringPtrLen ( currentOp -> outkeys [outputNumber ], NULL ) ;
276+ RedisModuleString * key_string = currentOp -> outkeys [outputNumber ];
279277 tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
280278 AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , tensor );
281279 }
@@ -304,8 +302,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
304302 if (rinfo -> single_op_dag ) {
305303 input = op -> mctx -> inputs [i ].tensor ;
306304 } else {
307- RAI_getTensorFromLocalContext (NULL , rinfo -> dagTensorsContext ,
308- RedisModule_StringPtrLen (op -> inkeys [i ], NULL ), & input ,
305+ RAI_getTensorFromLocalContext (NULL , rinfo -> dagTensorsContext , op -> inkeys [i ], & input ,
309306 op -> err );
310307 }
311308 // We are expecting input != NULL, because we only reach this function if all inputs
@@ -354,16 +351,14 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
354351 if (rinfo1 -> single_op_dag == 1 ) {
355352 input1 = op1 -> mctx -> inputs [i ].tensor ;
356353 } else {
357- RAI_getTensorFromLocalContext (NULL , rinfo1 -> dagTensorsContext ,
358- RedisModule_StringPtrLen (op1 -> inkeys [i ], NULL ), & input1 ,
354+ RAI_getTensorFromLocalContext (NULL , rinfo1 -> dagTensorsContext , op1 -> inkeys [i ], & input1 ,
359355 op1 -> err );
360356 }
361357 RAI_Tensor * input2 ;
362358 if (rinfo2 -> single_op_dag == 1 ) {
363359 input2 = op2 -> mctx -> inputs [i ].tensor ;
364360 } else {
365- RAI_getTensorFromLocalContext (NULL , rinfo2 -> dagTensorsContext ,
366- RedisModule_StringPtrLen (op2 -> inkeys [i ], NULL ), & input2 ,
361+ RAI_getTensorFromLocalContext (NULL , rinfo2 -> dagTensorsContext , op2 -> inkeys [i ], & input2 ,
367362 op2 -> err );
368363 }
369364 if (input1 == NULL || input2 == NULL ) {
@@ -439,8 +434,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
439434 RAI_ContextReadLock (rinfo );
440435
441436 for (int i = 0 ; i < n_inkeys ; i ++ ) {
442- if (AI_dictFind (rinfo -> dagTensorsContext ,
443- RedisModule_StringPtrLen (currentOp_ -> inkeys [i ], NULL )) == NULL ) {
437+ if (AI_dictFind (rinfo -> dagTensorsContext , currentOp_ -> inkeys [i ]) == NULL ) {
444438 RAI_ContextUnlock (rinfo );
445439 * currentOpReady = 0 ;
446440 return ;
@@ -543,17 +537,22 @@ void RedisAI_BatchedDagRunSessionStep(RedisAI_RunInfo **batched_rinfo, const cha
543537}
544538
545539static int _StoreTensorInKeySpace (RedisModuleCtx * ctx , RAI_Tensor * tensor ,
546- const char * persist_key_name , bool mangled_name ) {
540+ RedisModuleString * persist_key_name , bool mangled_name ) {
541+
547542 int ret = REDISMODULE_ERR ;
548543 RedisModuleKey * key ;
549- char * demangled_key_name = RedisModule_Strdup (persist_key_name );
550- if (mangled_name )
551- demangled_key_name [strlen (persist_key_name ) - 4 ] = 0 ;
552- RedisModuleString * tensor_keyname =
553- RedisModule_CreateString (ctx , demangled_key_name , strlen (demangled_key_name ));
544+ size_t persist_key_len ;
545+ const char * persist_key_str = RedisModule_StringPtrLen (persist_key_name , & persist_key_len );
546+
547+ RedisModuleString * demangled_key_name ;
548+ if (mangled_name ) {
549+ demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len - 4 );
550+ } else {
551+ demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len );
552+ }
553+
554554 const int status =
555- RAI_OpenKey_Tensor (ctx , tensor_keyname , & key , REDISMODULE_READ | REDISMODULE_WRITE );
556- RedisModule_Free (demangled_key_name );
555+ RAI_OpenKey_Tensor (ctx , demangled_key_name , & key , REDISMODULE_READ | REDISMODULE_WRITE );
557556 if (status == REDISMODULE_ERR ) {
558557 RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
559558 goto clean_up ;
@@ -565,17 +564,19 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
565564 }
566565 }
567566 ret = REDISMODULE_OK ;
567+
568568clean_up :
569569 RedisModule_CloseKey (key );
570- RedisAI_ReplicateTensorSet (ctx , tensor_keyname , tensor );
570+ RedisAI_ReplicateTensorSet (ctx , demangled_key_name , tensor );
571+ RedisModule_FreeString (NULL , demangled_key_name );
571572 return ret ;
572573}
573574
574575static void PersistTensors (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
575576 AI_dictIterator * persist_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
576577 AI_dictEntry * persist_entry = AI_dictNext (persist_iter );
577578 while (persist_entry ) {
578- const char * persist_key_name = AI_dictGetKey (persist_entry );
579+ RedisModuleString * persist_key_name = AI_dictGetKey (persist_entry );
579580
580581 AI_dictEntry * tensor_entry = AI_dictFind (rinfo -> dagTensorsContext , persist_key_name );
581582
@@ -586,8 +587,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
586587 persist_entry = AI_dictNext (persist_iter );
587588 continue ;
588589 }
589- bool mangled = true;
590- if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , mangled ) == REDISMODULE_ERR )
590+ if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , true) == REDISMODULE_ERR )
591591 rinfo -> dagReplyLength ++ ;
592592
593593 } else {
@@ -602,7 +602,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
602602 AI_dictIterator * local_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
603603 AI_dictEntry * local_entry = AI_dictNext (local_iter );
604604 while (local_entry ) {
605- const char * localcontext_key_name = AI_dictGetKey (local_entry );
605+ RedisModuleString * localcontext_key_name = AI_dictGetKey (local_entry );
606606 RedisModule_Log (ctx , "warning" , "DAG's local context key (%s)" ,
607607 localcontext_key_name );
608608 local_entry = AI_dictNext (local_iter );
@@ -623,11 +623,9 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
623623 const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
624624 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
625625 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
626- const char * key_string = RedisModule_StringPtrLen (op -> outkeys [outputNumber ], NULL );
627626 tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
628- bool mangled = false;
629627 if (tensor )
630- _StoreTensorInKeySpace (ctx , tensor , key_string , mangled );
628+ _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [ outputNumber ], false );
631629 }
632630}
633631
@@ -693,8 +691,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
693691 case REDISAI_DAG_CMD_MODELRUN : {
694692 rinfo -> dagReplyLength ++ ;
695693 struct RedisAI_RunStats * rstats = NULL ;
696- const char * runkey = RedisModule_StringPtrLen (currentOp -> runkey , NULL );
697- RAI_GetRunStats (runkey , & rstats );
694+ RAI_GetRunStats (currentOp -> runkey , & rstats );
698695 if (currentOp -> result == REDISMODULE_ERR ) {
699696 RAI_SafeAddDataPoint (rstats , 0 , 1 , 1 , 0 );
700697 RedisModule_ReplyWithError (ctx , currentOp -> err -> detail_oneline );
@@ -719,8 +716,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
719716 case REDISAI_DAG_CMD_SCRIPTRUN : {
720717 rinfo -> dagReplyLength ++ ;
721718 struct RedisAI_RunStats * rstats = NULL ;
722- const char * runkey = RedisModule_StringPtrLen (currentOp -> runkey , NULL );
723- RAI_GetRunStats (runkey , & rstats );
719+ RAI_GetRunStats (currentOp -> runkey , & rstats );
724720 if (currentOp -> result == REDISMODULE_ERR ) {
725721 RAI_SafeAddDataPoint (rstats , 0 , 1 , 1 , 0 );
726722 RedisModule_ReplyWithError (ctx , currentOp -> err -> detail_oneline );
0 commit comments