@@ -91,7 +91,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
9191
9292static void Dag_StoreOutputsFromModelRunCtx (RedisAI_RunInfo * rinfo , RAI_DagOp * currentOp ) {
9393
94- RAI_ContextReadLock (rinfo );
94+ RAI_ContextWriteLock (rinfo );
9595 const size_t noutputs = RAI_ModelRunCtxNumOutputs (currentOp -> mctx );
9696 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
9797 RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (currentOp -> mctx , outputNumber );
@@ -348,16 +348,20 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
348348 return 1 ;
349349}
350350
351- int RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
351+ bool RedisAI_DagDeviceComplete (RedisAI_RunInfo * rinfo ) {
352352 return rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ;
353353}
354354
355- int RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
355+ bool RedisAI_DagComplete (RedisAI_RunInfo * rinfo ) {
356356 int completeOpCount = __atomic_load_n (rinfo -> dagCompleteOpCount , __ATOMIC_RELAXED );
357357
358358 return completeOpCount == rinfo -> dagOpCount ;
359359}
360360
361+ bool RedisAI_DagError (RedisAI_RunInfo * rinfo ) {
362+ return __atomic_load_n (rinfo -> dagError , __ATOMIC_RELAXED ) != 0 ;
363+ }
364+
361365RAI_DagOp * RedisAI_DagCurrentOp (RedisAI_RunInfo * rinfo ) {
362366 if (rinfo -> dagDeviceCompleteOpCount == rinfo -> dagDeviceOpCount ) {
363367 return NULL ;
@@ -366,21 +370,21 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
366370 return rinfo -> dagDeviceOps [rinfo -> dagDeviceCompleteOpCount ];
367371}
368372
369- void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , int * currentOpReady ,
370- int * currentOpBatchable ) {
373+ void RedisAI_DagCurrentOpInfo (RedisAI_RunInfo * rinfo , bool * currentOpReady ,
374+ bool * currentOpBatchable ) {
371375 RAI_DagOp * currentOp_ = RedisAI_DagCurrentOp (rinfo );
372376
373- * currentOpReady = 0 ;
374- * currentOpBatchable = 0 ;
377+ * currentOpReady = false ;
378+ * currentOpBatchable = false ;
375379
376380 if (currentOp_ == NULL ) {
377381 return ;
378382 }
379383
380384 if (currentOp_ -> mctx && currentOp_ -> mctx -> model -> opts .batchsize > 0 ) {
381- * currentOpBatchable = 1 ;
385+ * currentOpBatchable = true ;
382386 }
383- * currentOpReady = 1 ;
387+ * currentOpReady = true ;
384388 // If this is a single op dag, the op is definitely ready.
385389 if (rinfo -> single_op_dag == 1 )
386390 return ;
@@ -391,7 +395,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
391395 for (int i = 0 ; i < n_inkeys ; i ++ ) {
392396 if (AI_dictFind (rinfo -> dagTensorsContext , currentOp_ -> inkeys [i ]) == NULL ) {
393397 RAI_ContextUnlock (rinfo );
394- * currentOpReady = 0 ;
398+ * currentOpReady = false ;
395399 return ;
396400 }
397401 }
@@ -588,7 +592,7 @@ static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
588592 const size_t noutputs = RAI_ScriptRunCtxNumOutputs (op -> sctx );
589593 for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
590594 RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (op -> sctx , outputNumber );
591- tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
595+ // tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
592596 if (tensor )
593597 _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [outputNumber ], false);
594598 }
0 commit comments