Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,30 @@ def reset(self, status: JobState = JobState.CREATED):
stage.progress = 0
stage.status = status

def is_complete(self) -> bool:
    """
    Check whether every stage of this job has finished processing.

    A job is considered complete when ALL of its stages have:
    - progress >= 1.0 (fully processed)
    - status in a final state (as reported by ``JobState.final_states()``)

    This method works for any job type regardless of which stages it has.
    It's used by the Celery ``task_postrun`` signal handler to determine
    whether to set the job's final SUCCESS status, or defer to async
    progress handlers.

    Related: ``Job.update_progress()`` calculates the aggregate progress
    percentage across all stages for display purposes. This method is a
    binary check for completion that considers both progress AND status.

    Returns:
        True if all stages are complete, False otherwise.
        Returns False if the job has no stages (shouldn't happen in practice).
    """
    if not self.stages:
        return False
    return all(stage.progress >= 1.0 and stage.status in JobState.final_states() for stage in self.stages)

class Config:
use_enum_values = True
as_dict = True
Expand Down
7 changes: 7 additions & 0 deletions ami/jobs/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,10 @@
required=False,
type=int,
)

# OpenAPI query parameter letting a calling processing service identify
# itself by name; optional, used by the jobs views for request logging.
processing_service_name_param = OpenApiParameter(
    name="processing_service_name",
    description="Inform the name of the calling processing service",
    required=False,
    type=str,
)
11 changes: 10 additions & 1 deletion ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def pre_update_job_status(sender, task_id, task, **kwargs):

@task_postrun.connect(sender=run_job)
def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):
from ami.jobs.models import Job
from ami.jobs.models import Job, JobState

job_id = task.request.kwargs["job_id"]
if job_id is None:
Expand All @@ -190,6 +190,15 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):
logger.error(f"No job found for task {task_id} or job_id {job_id}")
return

# Guard only SUCCESS state - let FAILURE, REVOKED, RETRY pass through immediately
# SUCCESS should only be set when all stages are actually complete
# This prevents premature SUCCESS when async workers are still processing
if state == JobState.SUCCESS and not job.progress.is_complete():
job.logger.info(
f"Job {job.pk} task completed but stages not finished - " "deferring SUCCESS status to progress handler"
)
return

job.update_status(state)


Expand Down
87 changes: 87 additions & 0 deletions ami/jobs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,93 @@ def test_create_job_with_delay(self):
self.assertEqual(job.progress.stages[0].progress, 1)
self.assertEqual(job.progress.stages[0].status, JobState.SUCCESS)

def test_job_status_guard_prevents_premature_success(self):
    """
    The task_postrun handler must not mark a job SUCCESS while its
    stages are still in flight.

    Exercises the fix for the race where the Celery task completes but
    async workers are still processing stages.
    """
    from unittest.mock import Mock

    from ami.jobs.tasks import update_job_status

    # A job whose stages will deliberately be left unfinished.
    incomplete_job = Job.objects.create(
        job_type_key=MLJob.key,
        project=self.project,
        name="Test job with incomplete stages",
        pipeline=self.pipeline,
        source_image_collection=self.source_image_collection,
    )

    # Two stages, neither in a final state.
    incomplete_job.progress.add_stage("detection")
    incomplete_job.progress.update_stage("detection", progress=0.5, status=JobState.STARTED)
    incomplete_job.progress.add_stage("classification")
    incomplete_job.progress.update_stage("classification", progress=0.0, status=JobState.CREATED)
    incomplete_job.save()

    # Sanity check: the job reports itself as incomplete.
    self.assertFalse(incomplete_job.progress.is_complete())

    fake_task = Mock()
    fake_task.request.kwargs = {"job_id": incomplete_job.pk}
    status_before = incomplete_job.status

    # Simulate the signal firing SUCCESS while stages are unfinished.
    update_job_status(
        sender=fake_task,
        task_id="test-task-id",
        task=fake_task,
        state=JobState.SUCCESS.value,  # the signal delivers the string value, not the enum
        retval=None,
    )

    # The guard should have deferred: status unchanged, not SUCCESS.
    incomplete_job.refresh_from_db()
    self.assertEqual(incomplete_job.status, status_before)
    self.assertNotEqual(incomplete_job.status, JobState.SUCCESS.value)

def test_job_status_allows_failure_states_immediately(self):
    """
    FAILURE (and other non-SUCCESS final states) must bypass the
    completion guard and be applied immediately, even when stages are
    still incomplete.
    """
    from unittest.mock import Mock

    from ami.jobs.tasks import update_job_status

    failing_job = Job.objects.create(
        job_type_key=MLJob.key,
        project=self.project,
        name="Test job for failure states",
        pipeline=self.pipeline,
        source_image_collection=self.source_image_collection,
    )

    # Leave a single stage only partially processed.
    failing_job.progress.add_stage("detection")
    failing_job.progress.update_stage("detection", progress=0.3, status=JobState.STARTED)
    failing_job.save()

    fake_task = Mock()
    fake_task.request.kwargs = {"job_id": failing_job.pk}

    # FAILURE must pass straight through despite the incomplete stage.
    update_job_status(
        sender=fake_task,
        task_id="test-task-id",
        task=fake_task,
        state=JobState.FAILURE.value,  # the signal delivers the string value, not the enum
        retval=None,
    )

    failing_job.refresh_from_db()
    self.assertEqual(failing_job.status, JobState.FAILURE.value)


class TestJobView(APITestCase):
"""
Expand Down
29 changes: 27 additions & 2 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ami.base.permissions import ObjectPermission
from ami.base.views import ProjectMixin
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param, processing_service_name_param
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
Expand Down Expand Up @@ -203,13 +203,21 @@ def get_queryset(self) -> QuerySet:
project_id_doc_param,
ids_only_param,
incomplete_only_param,
processing_service_name_param,
]
)
def list(self, request, *args, **kwargs):
    """List jobs, logging the identity of the calling processing service when provided."""
    service_name = request.query_params.get("processing_service_name", None)
    if not service_name:
        logger.debug("Jobs list requested without processing service name")
    else:
        logger.info(f"Jobs list requested by processing service: {service_name}")
    return super().list(request, *args, **kwargs)

@extend_schema(
parameters=[batch_param],
parameters=[batch_param, processing_service_name_param],
responses={200: dict},
)
@action(detail=True, methods=["get"], name="tasks")
Expand All @@ -228,6 +236,13 @@ def tasks(self, request, pk=None):
except Exception as e:
raise ValidationError({"batch": str(e)}) from e

# Get optional processing_service_name parameter
processing_service_name = request.query_params.get("processing_service_name", None)
if processing_service_name:
job.logger.info(f"Job {job.pk} tasks ({batch}) requested by processing service: {processing_service_name}")
else:
job.logger.warning(f"Job {job.pk} tasks ({batch}) requested without processing service name")

# Validate that the job has a pipeline
if not job.pipeline:
raise ValidationError("This job does not have a pipeline configured")
Expand All @@ -249,6 +264,9 @@ async def get_tasks():

return Response({"tasks": tasks})

@extend_schema(
parameters=[processing_service_name_param],
)
@action(detail=True, methods=["post"], name="result")
def result(self, request, pk=None):
"""
Expand All @@ -261,6 +279,13 @@ def result(self, request, pk=None):

job = self.get_object()

# Get optional processing_service_name parameter
processing_service_name = request.query_params.get("processing_service_name", None)
if processing_service_name:
job.logger.info(f"Job {job.pk} result received from processing service: {processing_service_name}")
else:
job.logger.warning(f"Job {job.pk} result received without processing service name")

# Validate request data is a list
if isinstance(request.data, list):
results = request.data
Expand Down
29 changes: 0 additions & 29 deletions ami/utils/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import requests
from django.forms import BooleanField, FloatField
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from requests.adapters import HTTPAdapter
from rest_framework.request import Request
from urllib3.util import Retry
Expand Down Expand Up @@ -144,30 +142,3 @@ def get_default_classification_threshold(project: "Project | None" = None, reque
return project.default_filters_score_threshold
else:
return default_threshold


project_id_doc_param = OpenApiParameter(
name="project_id",
description="Filter by project ID",
required=False,
type=int,
)

ids_only_param = OpenApiParameter(
name="ids_only",
description="Return only job IDs instead of full job objects",
required=False,
type=OpenApiTypes.BOOL,
)
incomplete_only_param = OpenApiParameter(
name="incomplete_only",
description="Filter to only show incomplete jobs (excludes SUCCESS, FAILURE, REVOKED)",
required=False,
type=OpenApiTypes.BOOL,
)
batch_param = OpenApiParameter(
name="batch",
description="Number of tasks to pull in the batch",
required=False,
type=OpenApiTypes.INT,
)