@@ -31,9 +31,11 @@ void _SetTensorsInDagLocalContext(RedisAI_RunInfo *rinfo) {
3131int _MangleTensorsNames (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
3232
3333 int res = REDISMODULE_ERR ;
34- AI_dict * mangled_tensors = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
34+ AI_dict * occurrences_counter = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
3535
3636 {
37+ // We first save the tensors' names that were indicated in the LOAD phase.
38+ // These tensors where loaded and kept in dagTensorsContext with their "mangled" name.
3739 AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
3840 AI_dictEntry * entry = AI_dictNext (iter );
3941 while (entry ) {
@@ -43,7 +45,7 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
4345 RedisModuleString * demangled_key = RedisModule_CreateString (NULL , key_str , key_len - 4 );
4446 int * instance = RedisModule_Alloc (sizeof (int ));
4547 * instance = 1 ;
46- AI_dictAdd (mangled_tensors , (void * )demangled_key , (void * )instance );
48+ AI_dictAdd (occurrences_counter , (void * )demangled_key , (void * )instance );
4749 RedisModule_FreeString (NULL , demangled_key );
4850 entry = AI_dictNext (iter );
4951 }
@@ -57,7 +59,7 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
5759 array_new (RedisModuleString * , array_len (currentOp -> inkeys ));
5860 for (long long j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
5961 RedisModuleString * key = currentOp -> inkeys [j ];
60- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
62+ AI_dictEntry * entry = AI_dictFind (occurrences_counter , key );
6163 if (!entry ) {
6264 array_free (mangled_inkeys );
6365 RedisModule_ReplyWithError (ctx , "ERR INPUT key cannot be found in DAG" );
@@ -75,15 +77,15 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
7577 array_new (RedisModuleString * , array_len (currentOp -> outkeys ));
7678 for (long long j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
7779 RedisModuleString * key = currentOp -> outkeys [j ];
78- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
80+ AI_dictEntry * entry = AI_dictFind (occurrences_counter , key );
7981 int * instance = NULL ;
8082 if (entry ) {
8183 instance = AI_dictGetVal (entry );
8284 * instance += 1 ;
8385 } else {
8486 instance = RedisModule_Alloc (sizeof (int ));
8587 * instance = 1 ;
86- AI_dictAdd (mangled_tensors , (void * )key , (void * )instance );
88+ AI_dictAdd (occurrences_counter , (void * )key , (void * )instance );
8789 }
8890 char buf [16 ];
8991 sprintf (buf , "%04d" , * instance );
@@ -92,31 +94,30 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
9294 mangled_outkeys = array_append (mangled_outkeys , mangled_key );
9395 }
9496
95- if (currentOp -> inkeys ) {
96- for (size_t j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
97- RedisModule_FreeString (NULL , currentOp -> inkeys [j ]);
98- }
99- array_free (currentOp -> inkeys );
97+ for (size_t j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
98+ RedisModule_FreeString (NULL , currentOp -> inkeys [j ]);
10099 }
100+ array_free (currentOp -> inkeys );
101101
102- if (currentOp -> outkeys ) {
103- for (size_t j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
104- RedisModule_FreeString (NULL , currentOp -> outkeys [j ]);
105- }
106- array_free (currentOp -> outkeys );
102+ for (size_t j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
103+ RedisModule_FreeString (NULL , currentOp -> outkeys [j ]);
107104 }
105+ array_free (currentOp -> outkeys );
108106
109107 currentOp -> inkeys = mangled_inkeys ;
110108 currentOp -> outkeys = mangled_outkeys ;
111109 }
112110
111+ // If we need to persist a certain tensor under a specified key, we need to take it
112+ // from the last op in which this key appears (that is, the tensor associated with
113+ // the "maximal" mangled name generated from that key).
113114 AI_dict * mangled_persisted = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
114115 {
115116 AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
116117 AI_dictEntry * entry = AI_dictNext (iter );
117118 while (entry ) {
118119 RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
119- AI_dictEntry * mangled_entry = AI_dictFind (mangled_tensors , key );
120+ AI_dictEntry * mangled_entry = AI_dictFind (occurrences_counter , key );
120121 if (!mangled_entry ) {
121122 AI_dictRelease (mangled_persisted );
122123 AI_dictReleaseIterator (iter );
@@ -126,10 +127,8 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
126127 int * instance = AI_dictGetVal (mangled_entry );
127128 char buf [16 ];
128129 sprintf (buf , "%04d" , * instance );
129- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
130- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
131- AI_dictAdd (mangled_persisted , (void * )mangled_key , (void * )1 );
132- RedisModule_FreeString (NULL , mangled_key );
130+ RedisModule_StringAppendBuffer (NULL , key , buf , strlen (buf ));
131+ AI_dictAdd (mangled_persisted , (void * )key , (void * )1 );
133132 entry = AI_dictNext (iter );
134133 }
135134 AI_dictReleaseIterator (iter );
@@ -146,7 +145,7 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
146145 res = REDISMODULE_OK ;
147146
148147cleanup : {
149- AI_dictIterator * iter = AI_dictGetSafeIterator (mangled_tensors );
148+ AI_dictIterator * iter = AI_dictGetSafeIterator (occurrences_counter );
150149 AI_dictEntry * entry = AI_dictNext (iter );
151150 while (entry ) {
152151 int * val = (int * )AI_dictGetVal (entry );
@@ -155,7 +154,7 @@ cleanup : {
155154 }
156155 AI_dictReleaseIterator (iter );
157156}
158- AI_dictRelease (mangled_tensors );
157+ AI_dictRelease (occurrences_counter );
159158 return res ;
160159}
161160
@@ -193,12 +192,13 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
193192
194193 // Go over the given args and load the tensors from keyspace.
195194 for (size_t argpos = 2 ; argpos < argc && number_loaded_keys < n_keys ; argpos ++ ) {
196- const char * arg_string = RedisModule_StringPtrLen (argv [argpos ], & arg_len );
195+ RedisModuleString * key_name = argv [argpos ];
196+ const char * arg_string = RedisModule_StringPtrLen (key_name , & arg_len );
197197 if (!strcasecmp (arg_string , chaining_operator ))
198198 break ;
199199 RAI_Tensor * t ;
200200 RedisModuleKey * key ;
201- const int status = RAI_GetTensorFromKeyspace (ctx , argv [ argpos ] , & key , & t , REDISMODULE_READ );
201+ const int status = RAI_GetTensorFromKeyspace (ctx , key_name , & key , & t , REDISMODULE_READ );
202202 if (status == REDISMODULE_ERR ) {
203203 RedisModule_Log (ctx , "warning" ,
204204 "on DAGRUN's LOAD could not load tensor %s from keyspace" , arg_string );
@@ -208,10 +208,8 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
208208 // Add the tensor under its "mangled" key name to the DAG local context dict.
209209 char buf [16 ];
210210 sprintf (buf , "%04d" , 1 );
211- RedisModuleString * dictKey = RedisModule_CreateStringFromString (NULL , argv [argpos ]);
212- RedisModule_StringAppendBuffer (NULL , dictKey , buf , strlen (buf ));
213- AI_dictAdd (* localContextDict , (void * )dictKey , (void * )RAI_TensorGetShallowCopy (t ));
214- RedisModule_FreeString (NULL , dictKey );
211+ RedisModule_StringAppendBuffer (NULL , key_name , buf , strlen (buf ));
212+ AI_dictAdd (* localContextDict , (void * )key_name , (void * )RAI_TensorGetShallowCopy (t ));
215213 number_loaded_keys ++ ;
216214 }
217215
0 commit comments