Skip to content

Commit e56046f

Browse files
committed
PR fixes
1 parent a5b5b3e commit e56046f

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

src/DAG/dag_parser.c

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ void _SetTensorsInDagLocalContext(RedisAI_RunInfo *rinfo) {
3131
int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
3232

3333
int res = REDISMODULE_ERR;
34-
AI_dict *mangled_tensors = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
34+
AI_dict *occurrences_counter = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
3535

3636
{
37+
// We first save the tensors' names that were indicated in the LOAD phase.
38+
// These tensors where loaded and kept in dagTensorsContext with their "mangled" name.
3739
AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
3840
AI_dictEntry *entry = AI_dictNext(iter);
3941
while (entry) {
@@ -43,7 +45,7 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
4345
RedisModuleString *demangled_key = RedisModule_CreateString(NULL, key_str, key_len - 4);
4446
int *instance = RedisModule_Alloc(sizeof(int));
4547
*instance = 1;
46-
AI_dictAdd(mangled_tensors, (void *)demangled_key, (void *)instance);
48+
AI_dictAdd(occurrences_counter, (void *)demangled_key, (void *)instance);
4749
RedisModule_FreeString(NULL, demangled_key);
4850
entry = AI_dictNext(iter);
4951
}
@@ -57,7 +59,7 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
5759
array_new(RedisModuleString *, array_len(currentOp->inkeys));
5860
for (long long j = 0; j < array_len(currentOp->inkeys); j++) {
5961
RedisModuleString *key = currentOp->inkeys[j];
60-
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
62+
AI_dictEntry *entry = AI_dictFind(occurrences_counter, key);
6163
if (!entry) {
6264
array_free(mangled_inkeys);
6365
RedisModule_ReplyWithError(ctx, "ERR INPUT key cannot be found in DAG");
@@ -75,15 +77,15 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
7577
array_new(RedisModuleString *, array_len(currentOp->outkeys));
7678
for (long long j = 0; j < array_len(currentOp->outkeys); j++) {
7779
RedisModuleString *key = currentOp->outkeys[j];
78-
AI_dictEntry *entry = AI_dictFind(mangled_tensors, key);
80+
AI_dictEntry *entry = AI_dictFind(occurrences_counter, key);
7981
int *instance = NULL;
8082
if (entry) {
8183
instance = AI_dictGetVal(entry);
8284
*instance += 1;
8385
} else {
8486
instance = RedisModule_Alloc(sizeof(int));
8587
*instance = 1;
86-
AI_dictAdd(mangled_tensors, (void *)key, (void *)instance);
88+
AI_dictAdd(occurrences_counter, (void *)key, (void *)instance);
8789
}
8890
char buf[16];
8991
sprintf(buf, "%04d", *instance);
@@ -92,31 +94,30 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
9294
mangled_outkeys = array_append(mangled_outkeys, mangled_key);
9395
}
9496

95-
if (currentOp->inkeys) {
96-
for (size_t j = 0; j < array_len(currentOp->inkeys); j++) {
97-
RedisModule_FreeString(NULL, currentOp->inkeys[j]);
98-
}
99-
array_free(currentOp->inkeys);
97+
for (size_t j = 0; j < array_len(currentOp->inkeys); j++) {
98+
RedisModule_FreeString(NULL, currentOp->inkeys[j]);
10099
}
100+
array_free(currentOp->inkeys);
101101

102-
if (currentOp->outkeys) {
103-
for (size_t j = 0; j < array_len(currentOp->outkeys); j++) {
104-
RedisModule_FreeString(NULL, currentOp->outkeys[j]);
105-
}
106-
array_free(currentOp->outkeys);
102+
for (size_t j = 0; j < array_len(currentOp->outkeys); j++) {
103+
RedisModule_FreeString(NULL, currentOp->outkeys[j]);
107104
}
105+
array_free(currentOp->outkeys);
108106

109107
currentOp->inkeys = mangled_inkeys;
110108
currentOp->outkeys = mangled_outkeys;
111109
}
112110

111+
// If we need to persist a certain tensor under a specified key, we need to take it
112+
// from the last op in which this key appears (that is, the tensor associated with
113+
// the "maximal" mangled name generated from that key).
113114
AI_dict *mangled_persisted = AI_dictCreate(&AI_dictTypeHeapRStrings, NULL);
114115
{
115116
AI_dictIterator *iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
116117
AI_dictEntry *entry = AI_dictNext(iter);
117118
while (entry) {
118119
RedisModuleString *key = (RedisModuleString *)AI_dictGetKey(entry);
119-
AI_dictEntry *mangled_entry = AI_dictFind(mangled_tensors, key);
120+
AI_dictEntry *mangled_entry = AI_dictFind(occurrences_counter, key);
120121
if (!mangled_entry) {
121122
AI_dictRelease(mangled_persisted);
122123
AI_dictReleaseIterator(iter);
@@ -126,10 +127,8 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
126127
int *instance = AI_dictGetVal(mangled_entry);
127128
char buf[16];
128129
sprintf(buf, "%04d", *instance);
129-
RedisModuleString *mangled_key = RedisModule_CreateStringFromString(NULL, key);
130-
RedisModule_StringAppendBuffer(NULL, mangled_key, buf, strlen(buf));
131-
AI_dictAdd(mangled_persisted, (void *)mangled_key, (void *)1);
132-
RedisModule_FreeString(NULL, mangled_key);
130+
RedisModule_StringAppendBuffer(NULL, key, buf, strlen(buf));
131+
AI_dictAdd(mangled_persisted, (void *)key, (void *)1);
133132
entry = AI_dictNext(iter);
134133
}
135134
AI_dictReleaseIterator(iter);
@@ -146,7 +145,7 @@ int _MangleTensorsNames(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
146145
res = REDISMODULE_OK;
147146

148147
cleanup : {
149-
AI_dictIterator *iter = AI_dictGetSafeIterator(mangled_tensors);
148+
AI_dictIterator *iter = AI_dictGetSafeIterator(occurrences_counter);
150149
AI_dictEntry *entry = AI_dictNext(iter);
151150
while (entry) {
152151
int *val = (int *)AI_dictGetVal(entry);
@@ -155,7 +154,7 @@ cleanup : {
155154
}
156155
AI_dictReleaseIterator(iter);
157156
}
158-
AI_dictRelease(mangled_tensors);
157+
AI_dictRelease(occurrences_counter);
159158
return res;
160159
}
161160

@@ -193,12 +192,13 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
193192

194193
// Go over the given args and load the tensors from keyspace.
195194
for (size_t argpos = 2; argpos < argc && number_loaded_keys < n_keys; argpos++) {
196-
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], &arg_len);
195+
RedisModuleString *key_name = argv[argpos];
196+
const char *arg_string = RedisModule_StringPtrLen(key_name, &arg_len);
197197
if (!strcasecmp(arg_string, chaining_operator))
198198
break;
199199
RAI_Tensor *t;
200200
RedisModuleKey *key;
201-
const int status = RAI_GetTensorFromKeyspace(ctx, argv[argpos], &key, &t, REDISMODULE_READ);
201+
const int status = RAI_GetTensorFromKeyspace(ctx, key_name, &key, &t, REDISMODULE_READ);
202202
if (status == REDISMODULE_ERR) {
203203
RedisModule_Log(ctx, "warning",
204204
"on DAGRUN's LOAD could not load tensor %s from keyspace", arg_string);
@@ -208,10 +208,8 @@ static int _ParseDAGLoadArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int
208208
// Add the tensor under its "mangled" key name to the DAG local context dict.
209209
char buf[16];
210210
sprintf(buf, "%04d", 1);
211-
RedisModuleString *dictKey = RedisModule_CreateStringFromString(NULL, argv[argpos]);
212-
RedisModule_StringAppendBuffer(NULL, dictKey, buf, strlen(buf));
213-
AI_dictAdd(*localContextDict, (void *)dictKey, (void *)RAI_TensorGetShallowCopy(t));
214-
RedisModule_FreeString(NULL, dictKey);
211+
RedisModule_StringAppendBuffer(NULL, key_name, buf, strlen(buf));
212+
AI_dictAdd(*localContextDict, (void *)key_name, (void *)RAI_TensorGetShallowCopy(t));
215213
number_loaded_keys++;
216214
}
217215

src/tensor.c

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -874,22 +874,24 @@ uint ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
874874

875875
int ReplyWithTensor(RedisModuleCtx *ctx, uint fmt, RAI_Tensor *t) {
876876

877-
if ((fmt & TENSOR_BLOB) && !(fmt & TENSOR_META)) {
878-
long long size = RAI_TensorByteSize(t);
879-
char *data = RAI_TensorData(t);
880-
RedisModule_ReplyWithStringBuffer(ctx, data, size);
881-
return REDISMODULE_OK;
882-
}
883-
if ((fmt & TENSOR_VALUES) && !(fmt & TENSOR_META)) {
884-
int ret = RAI_TensorReplyWithValues(ctx, t);
885-
if (ret == -1) {
886-
return REDISMODULE_ERR;
877+
if (!(fmt & TENSOR_META)) {
878+
if (fmt & TENSOR_BLOB) {
879+
long long size = RAI_TensorByteSize(t);
880+
char *data = RAI_TensorData(t);
881+
RedisModule_ReplyWithStringBuffer(ctx, data, size);
882+
return REDISMODULE_OK;
883+
}
884+
if (fmt & TENSOR_VALUES) {
885+
int ret = RAI_TensorReplyWithValues(ctx, t);
886+
if (ret == -1) {
887+
return REDISMODULE_ERR;
888+
}
889+
return REDISMODULE_OK;
887890
}
888-
return REDISMODULE_OK;
889891
}
890892

891893
long long resplen = 4;
892-
if ((fmt & TENSOR_BLOB) || (fmt & TENSOR_VALUES))
894+
if (fmt & (TENSOR_BLOB | TENSOR_VALUES))
893895
resplen += 2;
894896

895897
const long long ndims = RAI_TensorNumDims(t);

0 commit comments

Comments
 (0)