Skip to content

Commit 2b539d4

Browse files
committed
GPU bug fix
1 parent 13bd6b7 commit 2b539d4

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

src/DAG/dag.c

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,9 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
294294
}
295295

296296
size_t ninputs = array_len(op->inkeys);
297-
size_t batchsize = 0;
297+
int batchsize = 0;
298298

299-
if (!rinfo->single_op_dag) {
299+
if (!rinfo->single_device_dag) {
300300
RAI_ContextReadLock(rinfo);
301301
}
302302
for (size_t i = 0; i < ninputs; i++) {
@@ -322,7 +322,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
322322
break;
323323
}
324324
}
325-
if (!rinfo->single_op_dag) {
325+
if (!rinfo->single_device_dag) {
326326
RAI_ContextUnlock(rinfo);
327327
}
328328
return batchsize;
@@ -334,20 +334,21 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
334334
if (op1->mctx == NULL || op2->mctx == NULL) {
335335
return 0;
336336
}
337-
338337
if (op1->mctx->model != op2->mctx->model) {
339338
return 0;
340339
}
341-
342-
// const int ninputs1 = RAI_ModelRunCtxNumInputs(op1->mctx);
343-
// const int ninputs2 = RAI_ModelRunCtxNumInputs(op2->mctx);
344340
const int ninputs1 = array_len(op1->inkeys);
345341
const int ninputs2 = array_len(op2->inkeys);
346342

347343
if (ninputs1 != ninputs2) {
348344
return 0;
349345
}
350-
346+
if (!rinfo1->single_device_dag) {
347+
RAI_ContextReadLock(rinfo1);
348+
}
349+
if (!rinfo2->single_device_dag) {
350+
RAI_ContextReadLock(rinfo2);
351+
}
351352
for (int i = 0; i < ninputs1; i++) {
352353
RAI_Tensor *input1;
353354
if (rinfo1->single_op_dag == 1) {
@@ -388,6 +389,12 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
388389
}
389390
}
390391
}
392+
if (!rinfo1->single_device_dag) {
393+
RAI_ContextUnlock(rinfo1);
394+
}
395+
if (!rinfo2->single_device_dag) {
396+
RAI_ContextUnlock(rinfo2);
397+
}
391398
return 1;
392399
}
393400

@@ -462,8 +469,6 @@ void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI
462469
*batched = 0;
463470
*inbatchsize = 0;
464471

465-
RAI_ContextReadLock(rinfo2);
466-
467472
if (op2->mctx) {
468473
int match = RAI_DagOpBatchable(op1, rinfo1, op2, rinfo2);
469474

@@ -472,8 +477,6 @@ void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI
472477
*inbatchsize = RAI_DagOpBatchSize(op2, rinfo2);
473478
}
474479
}
475-
476-
RAI_ContextUnlock(rinfo2);
477480
}
478481

479482
void RedisAI_DagRunSessionStep(RedisAI_RunInfo *rinfo, const char *devicestr) {

src/command_parser.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
3434
RAI_SetError(error, RAI_EMODELRUN, "ERR Model not found");
3535
return REDISMODULE_ERR;
3636
}
37-
RedisModule_HoldString(NULL, argv[argpos]);
37+
if (RMAPI_FUNC_SUPPORTED(RedisModule_HoldString)) {
38+
RedisModule_HoldString(NULL, argv[argpos]);
39+
} else {
40+
RedisModule_RetainString(NULL, argv[argpos]);
41+
}
3842
*runkey = argv[argpos];
3943
const char *arg_string = RedisModule_StringPtrLen(argv[++argpos], NULL);
4044

@@ -58,7 +62,11 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **a
5862
is_input = false;
5963
is_output = true;
6064
} else {
61-
RedisModule_HoldString(NULL, argv[argpos]);
65+
if (RMAPI_FUNC_SUPPORTED(RedisModule_HoldString)) {
66+
RedisModule_HoldString(NULL, argv[argpos]);
67+
} else {
68+
RedisModule_RetainString(NULL, argv[argpos]);
69+
}
6270
if (is_input) {
6371
ninputs++;
6472
*inkeys = array_append(*inkeys, argv[argpos]);

src/modelRun_ctx.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString *
9898
is_input = 1;
9999
outputs_flag_count = 1;
100100
} else {
101-
RedisModule_HoldString(ctx, argv[argpos]);
101+
if (RMAPI_FUNC_SUPPORTED(RedisModule_HoldString)) {
102+
RedisModule_HoldString(NULL, argv[argpos]);
103+
} else {
104+
RedisModule_RetainString(NULL, argv[argpos]);
105+
}
102106
if (is_input == 0) {
103107
*inkeys = array_append(*inkeys, argv[argpos]);
104108
ninputs++;

0 commit comments

Comments
 (0)