Skip to content

Commit 116d1f6

Browse files
author
DvirDukhan
committed
background worker refactor
1 parent d13b724 commit 116d1f6

File tree

3 files changed

+236
-403
lines changed

3 files changed

+236
-403
lines changed

src/DAG/dag.c

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
9191

9292
static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) {
9393

94-
RAI_ContextReadLock(rinfo);
94+
RAI_ContextWriteLock(rinfo);
9595
const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx);
9696
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
9797
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber);
@@ -348,16 +348,20 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
348348
return 1;
349349
}
350350

351-
int RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo) {
351+
bool RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo) {
352352
return rinfo->dagDeviceCompleteOpCount == rinfo->dagDeviceOpCount;
353353
}
354354

355-
int RedisAI_DagComplete(RedisAI_RunInfo *rinfo) {
355+
bool RedisAI_DagComplete(RedisAI_RunInfo *rinfo) {
356356
int completeOpCount = __atomic_load_n(rinfo->dagCompleteOpCount, __ATOMIC_RELAXED);
357357

358358
return completeOpCount == rinfo->dagOpCount;
359359
}
360360

361+
bool RedisAI_DagError(RedisAI_RunInfo *rinfo) {
362+
return __atomic_load_n(rinfo->dagError, __ATOMIC_RELAXED) != 0;
363+
}
364+
361365
RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
362366
if (rinfo->dagDeviceCompleteOpCount == rinfo->dagDeviceOpCount) {
363367
return NULL;
@@ -366,21 +370,21 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo) {
366370
return rinfo->dagDeviceOps[rinfo->dagDeviceCompleteOpCount];
367371
}
368372

369-
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
370-
int *currentOpBatchable) {
373+
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, bool *currentOpReady,
374+
bool *currentOpBatchable) {
371375
RAI_DagOp *currentOp_ = RedisAI_DagCurrentOp(rinfo);
372376

373-
*currentOpReady = 0;
374-
*currentOpBatchable = 0;
377+
*currentOpReady = false;
378+
*currentOpBatchable = false;
375379

376380
if (currentOp_ == NULL) {
377381
return;
378382
}
379383

380384
if (currentOp_->mctx && currentOp_->mctx->model->opts.batchsize > 0) {
381-
*currentOpBatchable = 1;
385+
*currentOpBatchable = true;
382386
}
383-
*currentOpReady = 1;
387+
*currentOpReady = true;
384388
// If this is a single op dag, the op is definitely ready.
385389
if (rinfo->single_op_dag == 1)
386390
return;
@@ -391,7 +395,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
391395
for (int i = 0; i < n_inkeys; i++) {
392396
if (AI_dictFind(rinfo->dagTensorsContext, currentOp_->inkeys[i]) == NULL) {
393397
RAI_ContextUnlock(rinfo);
394-
*currentOpReady = 0;
398+
*currentOpReady = false;
395399
return;
396400
}
397401
}
@@ -588,7 +592,7 @@ static void _ScriptSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
588592
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(op->sctx);
589593
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
590594
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(op->sctx, outputNumber);
591-
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
595+
// tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
592596
if (tensor)
593597
_StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false);
594598
}

src/DAG/dag.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,25 @@
1919
* successfully. Since rinfo carries information on what queue
2020
* it has been placed in, there's no need to pass the device identifier.
2121
* @param rinfo context in which RedisAI blocking commands operate.
22-
* @return nonzero if all ops are complete for device, 0 otherwise
22+
* @return true if all ops are complete for device, 0 otherwise
2323
*/
24-
int RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo);
24+
bool RedisAI_DagDeviceComplete(RedisAI_RunInfo *rinfo);
2525

2626
/**
2727
* Get whether all DAG ops have been executed successfully irrespective
2828
* of the device, i.e. if the DAG has been completely executed.
2929
* @param rinfo context in which RedisAI blocking commands operate.
30-
* @return nonzero of all ops in DAG are complete, 0 otherwise
30+
* @return true of all ops in DAG are complete, 0 otherwise
3131
*/
32-
int RedisAI_DagComplete(RedisAI_RunInfo *rinfo);
32+
bool RedisAI_DagComplete(RedisAI_RunInfo *rinfo);
33+
34+
/**
35+
* @brief Get an indication if an error happend during the dag run.
36+
*
37+
* @param rinfo context in which RedisAI blocking commands operate.
38+
* @return true if there was an error
39+
*/
40+
bool RedisAI_DagError(RedisAI_RunInfo *rinfo);
3341

3442
/**
3543
* Get current DAG op for the given device. An op is current if it's
@@ -50,7 +58,7 @@ RAI_DagOp *RedisAI_DagCurrentOp(RedisAI_RunInfo *rinfo);
5058
* a MODELRUN and is BATCHSIZE greater than zero
5159
* @return
5260
*/
53-
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady, int *currentOpBatchable);
61+
void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, bool *currentOpReady, bool *currentOpBatchable);
5462

5563
/**
5664
* Get batching information about a DAG op.

0 commit comments

Comments
 (0)