@@ -294,9 +294,9 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
294294 }
295295
296296 size_t ninputs = array_len (op -> inkeys );
297- size_t batchsize = 0 ;
297+ int batchsize = 0 ;
298298
299- if (!rinfo -> single_op_dag ) {
299+ if (!rinfo -> single_device_dag ) {
300300 RAI_ContextReadLock (rinfo );
301301 }
302302 for (size_t i = 0 ; i < ninputs ; i ++ ) {
@@ -322,7 +322,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
322322 break ;
323323 }
324324 }
325- if (!rinfo -> single_op_dag ) {
325+ if (!rinfo -> single_device_dag ) {
326326 RAI_ContextUnlock (rinfo );
327327 }
328328 return batchsize ;
@@ -334,20 +334,21 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
334334 if (op1 -> mctx == NULL || op2 -> mctx == NULL ) {
335335 return 0 ;
336336 }
337-
338337 if (op1 -> mctx -> model != op2 -> mctx -> model ) {
339338 return 0 ;
340339 }
341-
342- // const int ninputs1 = RAI_ModelRunCtxNumInputs(op1->mctx);
343- // const int ninputs2 = RAI_ModelRunCtxNumInputs(op2->mctx);
344340 const int ninputs1 = array_len (op1 -> inkeys );
345341 const int ninputs2 = array_len (op2 -> inkeys );
346342
347343 if (ninputs1 != ninputs2 ) {
348344 return 0 ;
349345 }
350-
346+ if (!rinfo1 -> single_device_dag ) {
347+ RAI_ContextReadLock (rinfo1 );
348+ }
349+ if (!rinfo2 -> single_device_dag ) {
350+ RAI_ContextReadLock (rinfo2 );
351+ }
351352 for (int i = 0 ; i < ninputs1 ; i ++ ) {
352353 RAI_Tensor * input1 ;
353354 if (rinfo1 -> single_op_dag == 1 ) {
@@ -388,6 +389,12 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
388389 }
389390 }
390391 }
392+ if (!rinfo1 -> single_device_dag ) {
393+ RAI_ContextUnlock (rinfo1 );
394+ }
395+ if (!rinfo2 -> single_device_dag ) {
396+ RAI_ContextUnlock (rinfo2 );
397+ }
391398 return 1 ;
392399}
393400
@@ -462,8 +469,6 @@ void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI
462469 * batched = 0 ;
463470 * inbatchsize = 0 ;
464471
465- RAI_ContextReadLock (rinfo2 );
466-
467472 if (op2 -> mctx ) {
468473 int match = RAI_DagOpBatchable (op1 , rinfo1 , op2 , rinfo2 );
469474
@@ -472,8 +477,6 @@ void RedisAI_DagOpBatchingMatch(RedisAI_RunInfo *rinfo1, RAI_DagOp *op1, RedisAI
472477 * inbatchsize = RAI_DagOpBatchSize (op2 , rinfo2 );
473478 }
474479 }
475-
476- RAI_ContextUnlock (rinfo2 );
477480}
478481
479482void RedisAI_DagRunSessionStep (RedisAI_RunInfo * rinfo , const char * devicestr ) {
0 commit comments