Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,29 @@ def _update_job_progress(job_id: int, stage: str, progress_percentage: float) ->
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()

# Clean up async resources for completed jobs that use NATS/Redis
# Only ML jobs with async_pipeline_workers enabled use these resources
if stage == "results" and progress_percentage >= 1.0:
job = Job.objects.get(pk=job_id) # Re-fetch outside transaction
_cleanup_job_if_needed(job)


def _cleanup_job_if_needed(job) -> None:
"""
Clean up async resources (NATS/Redis) if this job type uses them.

Only ML jobs with async_pipeline_workers enabled use NATS/Redis resources.
This function is safe to call for any job - it checks if cleanup is needed.

Args:
job: The Job instance
"""
if job.job_type_key == "ml" and job.project and job.project.feature_flags.async_pipeline_workers:
# import here to avoid circular imports
from ami.ml.orchestration.jobs import cleanup_async_job_resources

cleanup_async_job_resources(job)


@task_prerun.connect(sender=run_job)
def pre_update_job_status(sender, task_id, task, **kwargs):
Expand Down Expand Up @@ -201,6 +224,10 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):

job.update_status(state)

# Clean up async resources for revoked jobs
if state == JobState.REVOKED:
_cleanup_job_if_needed(job)


@task_failure.connect(sender=run_job, retry=False)
def update_job_failure(sender, task_id, exception, *args, **kwargs):
Expand All @@ -213,6 +240,9 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs):

job.save()

# Clean up async resources for failed jobs
_cleanup_job_if_needed(job)


def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]:
"""
Expand Down
44 changes: 37 additions & 7 deletions ami/ml/orchestration/jobs.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,58 @@
import logging

from asgiref.sync import async_to_sync

from ami.jobs.models import Job, JobState, logger
from ami.jobs.models import Job, JobState
from ami.main.models import SourceImage
from ami.ml.orchestration.nats_queue import TaskQueueManager
from ami.ml.orchestration.task_state import TaskStateManager
from ami.ml.schemas import PipelineProcessingTask

logger = logging.getLogger(__name__)


# TODO CGJS: (Issue #1083) Call this once a job is fully complete (all images processed and saved)
def cleanup_async_job_resources(job: "Job") -> bool:
    """
    Clean up NATS JetStream and Redis resources for a completed job.

    Two independent cleanups are attempted, in order:
    1. Redis task state (via TaskStateManager.cleanup)
    2. NATS JetStream resources (via TaskQueueManager.cleanup_job_resources)

    Each cleanup is attempted even if the other fails; failures are logged
    but never raised, because the job's results are already saved by the
    time this runs.

    Args:
        job: The Job instance
    Returns:
        bool: True if both cleanups succeeded, False otherwise
    """
    redis_success = False
    nats_success = False

    # Cleanup Redis state
    try:
        state_manager = TaskStateManager(job.pk)
        state_manager.cleanup()
        job.logger.info(f"Cleaned up Redis state for job {job.pk}")
        redis_success = True
    except Exception as e:
        job.logger.error(f"Error cleaning up Redis state for job {job.pk}: {e}")

    # Cleanup NATS resources. The queue manager API is async, so bridge it
    # into this synchronous (Celery) context with async_to_sync.
    async def cleanup():
        async with TaskQueueManager() as manager:
            return await manager.cleanup_job_resources(job.pk)

    try:
        # cleanup_job_resources reports success via its boolean return value;
        # it may also raise (e.g. connection errors), handled below.
        nats_success = async_to_sync(cleanup)()
        if nats_success:
            job.logger.info(f"Cleaned up NATS resources for job {job.pk}")
        else:
            job.logger.warning(f"Failed to clean up NATS resources for job {job.pk}")
    except Exception as e:
        job.logger.error(f"Error cleaning up NATS resources for job {job.pk}: {e}")

    return redis_success and nats_success


def queue_images_to_nats(job: "Job", images: list[SourceImage]):
Expand Down
Empty file.
212 changes: 212 additions & 0 deletions ami/ml/orchestration/tests/test_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""Integration tests for async job resource cleanup (NATS and Redis)."""

from asgiref.sync import async_to_sync
from django.core.cache import cache
from django.test import TestCase
from nats.js.errors import NotFoundError

from ami.jobs.models import Job, JobState, MLJob
from ami.jobs.tasks import _update_job_progress, update_job_failure, update_job_status
from ami.main.models import Project, ProjectFeatureFlags, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.orchestration.jobs import queue_images_to_nats
from ami.ml.orchestration.nats_queue import TaskQueueManager
from ami.ml.orchestration.task_state import TaskStateManager


class TestCleanupAsyncJobResources(TestCase):
    """Test cleanup of NATS and Redis resources for async ML jobs.

    Covers the three job-termination paths that must release async resources:
    successful completion, task failure, and revocation/cancellation.
    NOTE(review): these tests talk to a live NATS server and the Django cache
    backend (Redis) — presumably provided by the test environment; confirm.
    """

    def setUp(self):
        """Set up test fixtures with async_pipeline_workers enabled."""
        # Create project with async_pipeline_workers feature flag enabled
        # (cleanup only runs for ML jobs on projects with this flag).
        self.project = Project.objects.create(
            name="Test Cleanup Project",
            feature_flags=ProjectFeatureFlags(async_pipeline_workers=True),
        )

        # Create pipeline
        self.pipeline = Pipeline.objects.create(
            name="Test Cleanup Pipeline",
            slug="test-cleanup-pipeline",
            description="Pipeline for cleanup tests",
        )
        self.pipeline.projects.add(self.project)

        # Create source image collection with images
        self.collection = SourceImageCollection.objects.create(
            name="Test Cleanup Collection",
            project=self.project,
        )

        # Create test images
        self.images = [
            SourceImage.objects.create(
                path=f"test_image_{i}.jpg",
                public_base_url="https://example.com",
                project=self.project,
            )
            for i in range(3)
        ]
        for image in self.images:
            self.collection.images.add(image)

    def _verify_resources_created(self, job_id: int):
        """
        Verify that both Redis and NATS resources were created.

        Args:
            job_id: The job ID to check
        """
        # Verify Redis keys exist. Reaching into private helpers
        # (_get_pending_key, _total_key) is deliberate here: the test probes
        # the raw cache keys rather than the manager's public API.
        state_manager = TaskStateManager(job_id)
        for stage in state_manager.STAGES:
            pending_key = state_manager._get_pending_key(stage)
            self.assertIsNotNone(cache.get(pending_key), f"Redis key {pending_key} should exist")
        total_key = state_manager._total_key
        self.assertIsNotNone(cache.get(total_key), f"Redis key {total_key} should exist")

        # Verify NATS stream and consumer exist
        async def check_nats_resources():
            async with TaskQueueManager() as manager:
                stream_name = manager._get_stream_name(job_id)
                consumer_name = manager._get_consumer_name(job_id)

                # Try to get stream info - should succeed if created
                stream_exists = True
                try:
                    await manager.js.stream_info(stream_name)
                except NotFoundError:
                    stream_exists = False

                # Try to get consumer info - should succeed if created
                consumer_exists = True
                try:
                    await manager.js.consumer_info(stream_name, consumer_name)
                except NotFoundError:
                    consumer_exists = False

                return stream_exists, consumer_exists

        stream_exists, consumer_exists = async_to_sync(check_nats_resources)()

        self.assertTrue(stream_exists, f"NATS stream for job {job_id} should exist")
        self.assertTrue(consumer_exists, f"NATS consumer for job {job_id} should exist")

    def _create_job_with_queued_images(self) -> Job:
        """
        Helper to create an ML job and queue images to NATS/Redis.

        Returns:
            Job instance with images queued to NATS and state initialized in Redis
        """
        job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="Test Cleanup Job",
            pipeline=self.pipeline,
            source_image_collection=self.collection,
        )

        # Queue images to NATS (also initializes Redis state)
        queue_images_to_nats(job, self.images)

        # Verify resources were actually created, so the later "cleaned up"
        # assertions cannot pass vacuously.
        self._verify_resources_created(job.pk)

        return job

    def _verify_resources_cleaned(self, job_id: int):
        """
        Verify that both Redis and NATS resources are cleaned up.

        Args:
            job_id: The job ID to check
        """
        # Verify Redis keys are deleted (mirror of _verify_resources_created)
        state_manager = TaskStateManager(job_id)
        for stage in state_manager.STAGES:
            pending_key = state_manager._get_pending_key(stage)
            self.assertIsNone(cache.get(pending_key), f"Redis key {pending_key} should be deleted")
        total_key = state_manager._total_key
        self.assertIsNone(cache.get(total_key), f"Redis key {total_key} should be deleted")

        # Verify NATS stream and consumer are deleted
        async def check_nats_resources():
            async with TaskQueueManager() as manager:
                stream_name = manager._get_stream_name(job_id)
                consumer_name = manager._get_consumer_name(job_id)

                # Try to get stream info - should fail if deleted
                stream_exists = True
                try:
                    await manager.js.stream_info(stream_name)
                except NotFoundError:
                    stream_exists = False

                # Try to get consumer info - should fail if deleted
                consumer_exists = True
                try:
                    await manager.js.consumer_info(stream_name, consumer_name)
                except NotFoundError:
                    consumer_exists = False

                return stream_exists, consumer_exists

        stream_exists, consumer_exists = async_to_sync(check_nats_resources)()

        self.assertFalse(stream_exists, f"NATS stream for job {job_id} should be deleted")
        self.assertFalse(consumer_exists, f"NATS consumer for job {job_id} should be deleted")

    def test_cleanup_on_job_completion(self):
        """Test that resources are cleaned up when job completes successfully."""
        job = self._create_job_with_queued_images()

        # Simulate job completion by updating progress to 100% in results stage
        # (the "results" stage at >= 1.0 progress is the completion trigger).
        _update_job_progress(job.pk, stage="results", progress_percentage=1.0)

        # Verify cleanup happened
        self._verify_resources_cleaned(job.pk)

    def test_cleanup_on_job_failure(self):
        """Test that resources are cleaned up when job fails."""
        job = self._create_job_with_queued_images()

        # Set task_id so the failure handler can find the job
        job.task_id = "test-task-failure-123"
        job.save()

        # Simulate job failure by calling the failure signal handler
        update_job_failure(
            sender=None,
            task_id=job.task_id,
            exception=Exception("Test failure"),
        )

        # Verify cleanup happened
        self._verify_resources_cleaned(job.pk)

    def test_cleanup_on_job_revoked(self):
        """Test that resources are cleaned up when job is revoked/cancelled."""
        job = self._create_job_with_queued_images()

        # Create a mock task request object for the signal handler.
        # Only task.request.kwargs["job_id"] is read by update_job_status.
        class MockRequest:
            def __init__(self):
                self.kwargs = {"job_id": job.pk}

        class MockTask:
            def __init__(self, job_id):
                self.request = MockRequest()
                self.request.kwargs["job_id"] = job_id

        # Simulate job revocation by calling the postrun signal handler with REVOKED state
        update_job_status(
            sender=None,
            task_id="test-task-revoked-456",
            task=MockTask(job.pk),
            state=JobState.REVOKED,
        )

        # Verify cleanup happened
        self._verify_resources_cleaned(job.pk)