diff --git a/requirements.txt b/requirements.txt index 0ca1de9..640c433 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ fastapi==0.88.0 pydantic==1.10.4 -python-ms-core==0.0.23 +python-ms-core==0.0.24 uvicorn==0.20.0 html_testRunner==1.2.1 geopandas==0.14.4 diff --git a/src/config.py b/src/config.py index 7c92892..f9d84ad 100644 --- a/src/config.py +++ b/src/config.py @@ -17,7 +17,8 @@ class Settings(BaseSettings): app_name: str = 'python-osw-validation' event_bus = EventBusSettings() auth_permission_url: str = os.environ.get('AUTH_PERMISSION_URL', None) - max_concurrent_messages: int = os.environ.get('MAX_CONCURRENT_MESSAGES', 2) + max_concurrent_messages: int = os.environ.get('MAX_CONCURRENT_MESSAGES', 1) + max_receivable_messages: int = os.environ.get('MAX_RECEIVABLE_MESSAGES',-1) # -1 means no limit @property def auth_provider(self) -> str: diff --git a/src/main.py b/src/main.py index b8cfc2e..f60b321 100644 --- a/src/main.py +++ b/src/main.py @@ -22,6 +22,7 @@ async def startup_event(settings: Settings = Depends(get_settings)) -> None: try: # OSWValidator() app.validator = OSWValidator() + except: print('\n\n\x1b[31m Application startup failed due to missing or invalid .env file \x1b[0m') print('\x1b[31m Please provide the valid .env file and .env file should contains following parameters\x1b[0m') diff --git a/src/osw_validator.py b/src/osw_validator.py index cfcc5f7..86d41b4 100644 --- a/src/osw_validator.py +++ b/src/osw_validator.py @@ -1,5 +1,8 @@ import gc import logging +import os +import signal +import time import urllib.parse from typing import List from python_ms_core import Core @@ -21,6 +24,9 @@ class OSWValidator: def __init__(self): self.core = Core() + + ## Print the core version + print(f'Core version: {self.core.__version__}') options = { 'provider': self._settings.auth_provider, 'api_url': self._settings.auth_permission_url @@ -31,7 +37,8 @@ def __init__(self): self.logger = self.core.get_logger() self.storage_client = self.core.get_storage_client() self.auth = self.core.get_authorizer(config=options) - self.listener_thread = threading.Thread(target=self.start_listening) + self._shutdown_triggered = threading.Event() + self.listener_thread = threading.Thread(target=self.start_listening, daemon=True) self.listener_thread.start() def start_listening(self): @@ -41,10 +48,11 @@ def process(message) -> None: upload_message = Upload.data_from(queue_message) self.validate(received_message=upload_message) - self.listening_topic.subscribe(subscription=self.subscription_name, callback=process) + self.listening_topic.subscribe(subscription=self.subscription_name, callback=process, max_receivable_messages=self._settings.max_receivable_messages) def validate(self, received_message: Upload): tdei_record_id: str = '' + status_sent = False try: tdei_record_id = received_message.message_id logger.info(f'Received message for : {tdei_record_id} Message received for OSW validation !') @@ -66,6 +74,7 @@ def validate(self, received_message: Upload): validation_result = Validation(file_path=file_upload_path, storage_client=self.storage_client) result = validation_result.validate() self.send_status(result=result, upload_message=received_message) + status_sent = True else: raise Exception('File entity not found') except Exception as e: @@ -74,6 +83,13 @@ def validate(self, received_message: Upload): result.is_valid = False result.validation_message = f'Error occurred while validating OSW request {e}' self.send_status(result=result, upload_message=received_message) + status_sent = True + finally: + if status_sent: + logger.info('Triggering server shutdown after status send.') + else: + logger.warning('Server shutdown skipped because status was not sent.') + self._stop_server_and_container(delay_seconds=2) def send_status(self, result: ValidationResult, upload_message: Upload): upload_message.data.success = result.is_valid @@ -90,6 +106,7 @@ def send_status(self, result: ValidationResult, upload_message: Upload): 'data': resp_data }) try: + logger.info('Sending validation result to response topic.') self.core.get_topic(topic_name=self._settings.event_bus.validation_topic).publish(data=data) logger.info(f'Publishing message for : {upload_message.message_id}') except Exception as e: @@ -113,4 +130,30 @@ def has_permission(self, roles: List[str], queue_message: Upload) -> bool: return False def stop_listening(self): - self.listener_thread.join(timeout=0) # Stop the thread during shutdown.Its still an attempt. Not sure if this will work. \ No newline at end of file + self._stop_server_and_container() + if hasattr(self, 'listener_thread'): + self.listener_thread.join(timeout=0) # Stop the thread during shutdown.Its still an attempt. Not sure if this will work. + + def _stop_server_and_container(self, delay_seconds: float = 0.0): + """ + Attempt to gracefully stop the current process (stopping FastAPI/uvicorn and the Docker container). + """ + logger.info('Gracefully stopping FastAPI/uvicorn and Docker container') + if self._shutdown_triggered.is_set(): + logger.info('Server stop already in progress; skipping duplicate trigger.') + return + self._shutdown_triggered.set() + logger.info('Server stop triggered; scheduling shutdown.') + def _terminate(): + if delay_seconds: + time.sleep(delay_seconds) + try: + logger.info('Sending SIGTERM to stop server/container.') + os.kill(os.getpid(), signal.SIGTERM) + except Exception as err: + logger.warning(f'Error occurred while sending SIGTERM: {err}') + finally: + logger.info('Forcing process exit to stop server/container.') + os._exit(0) + + threading.Thread(target=_terminate, daemon=True).start() diff --git a/tests/unit_tests/interface/test_validator_abstract.py b/tests/unit_tests/interface/test_validator_abstract.py index 5f95071..1f22573 100644 --- a/tests/unit_tests/interface/test_validator_abstract.py +++ b/tests/unit_tests/interface/test_validator_abstract.py @@ -12,6 +12,11 @@ def validate(self, message: QueueMessage) -> None: pass +class SuperCallingValidator(ValidatorAbstract): + def validate(self, message: QueueMessage) -> None: + return super().validate(message) + + class TestValidatorAbstract(unittest.TestCase): def test_abstract_method_enforcement(self): @@ -37,6 +42,14 @@ def test_validate_method_called(self): # Assert that the mocked message object is a valid argument self.assertTrue(hasattr(message, '__class__')) + def test_abstract_base_method_body_returns_none(self): + message = MagicMock(spec=QueueMessage) + validator = SuperCallingValidator() + + result = validator.validate(message) + + self.assertIsNone(result) + if __name__ == '__main__': unittest.main() diff --git a/tests/unit_tests/models/test_queue_message_content.py b/tests/unit_tests/models/test_queue_message_content.py index 4ac1e2c..3316f63 100644 --- a/tests/unit_tests/models/test_queue_message_content.py +++ b/tests/unit_tests/models/test_queue_message_content.py @@ -9,8 +9,8 @@ TEST_JSON_FILE = os.path.join(parent_dir, 'src/assets/osw-upload.json') -TEST_FILE = open(TEST_JSON_FILE) -TEST_DATA = json.loads(TEST_FILE.read()) +with open(TEST_JSON_FILE) as test_file: + TEST_DATA = json.load(test_file) class TestUpload(unittest.TestCase): diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index 6beac7a..33d6201 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -1,6 +1,9 @@ import unittest +import asyncio +from unittest.mock import MagicMock, patch from fastapi import status from fastapi.testclient import TestClient +import src.main as main from src.main import app, get_settings @@ -22,6 +25,43 @@ def test_get_settings(self): settings = get_settings() self.assertIsNotNone(settings) + @patch('src.main.OSWValidator') + def test_startup_event_sets_validator(self, mock_validator): + validator = MagicMock() + mock_validator.return_value = validator + main.app.validator = None + + asyncio.run(main.startup_event()) + + self.assertIs(main.app.validator, validator) + + @patch('builtins.print') + @patch('src.main.psutil.Process') + @patch('src.main.os.getpid', return_value=123) + @patch('src.main.OSWValidator', side_effect=Exception('boom')) + def test_startup_event_handles_validator_init_failure(self, mock_validator, mock_getpid, mock_process, mock_print): + child_one = MagicMock() + child_two = MagicMock() + parent = MagicMock() + parent.children.return_value = [child_one, child_two] + mock_process.return_value = parent + + asyncio.run(main.startup_event()) + + parent.children.assert_called_once_with(recursive=True) + child_one.kill.assert_called_once() + child_two.kill.assert_called_once() + parent.kill.assert_called_once() + self.assertGreaterEqual(mock_print.call_count, 6) + + def test_shutdown_event_stops_validator(self): + validator = MagicMock() + main.app.validator = validator + + asyncio.run(main.shutdown_event()) + + validator.stop_listening.assert_called_once() + if __name__ == '__main__': unittest.main() diff --git a/tests/unit_tests/test_osw_validator.py b/tests/unit_tests/test_osw_validator.py index 626b749..ed00980 100644 --- a/tests/unit_tests/test_osw_validator.py +++ b/tests/unit_tests/test_osw_validator.py @@ -10,8 +10,8 @@ TEST_JSON_FILE = os.path.join(parent_dir, 'src/assets/osw-upload.json') -TEST_FILE = open(TEST_JSON_FILE) -TEST_DATA = json.loads(TEST_FILE.read()) +with open(TEST_JSON_FILE) as test_file: + TEST_DATA = json.load(test_file) class PermissionResponse: diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 1c16397..13d2791 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -7,9 +7,10 @@ class TestOSWValidatorService(unittest.TestCase): + @patch('src.osw_validator.threading.Thread') @patch('src.osw_validator.Settings') @patch('src.osw_validator.Core') - def setUp(self, mock_core, mock_settings): + def setUp(self, mock_core, mock_settings, mock_thread): # Mock Settings mock_settings.return_value.event_bus.upload_subscription = 'test_subscription' mock_settings.return_value.event_bus.upload_topic = 'test_request_topic' @@ -19,14 +20,19 @@ def setUp(self, mock_core, mock_settings): mock_settings.return_value.event_bus.container_name = 'test_container' # Mock Core + mock_core.__version__ = 'test-core-version' + mock_core.return_value.__version__ = 'test-core-version' mock_core.return_value.get_topic.return_value = MagicMock() mock_core.return_value.get_storage_client.return_value = MagicMock() + self.mock_listener_thread = MagicMock() + mock_thread.return_value = self.mock_listener_thread # Initialize OSWValidator with mocked dependencies self.service = OSWValidator() self.service.storage_client = MagicMock() self.service.container_name = 'test_container' self.service.auth = MagicMock() + self.service._stop_server_and_container = MagicMock() # Define a sample message with proper strings self.sample_message = { @@ -41,11 +47,12 @@ def setUp(self, mock_core, mock_settings): @patch('src.osw_validator.QueueMessage') @patch('src.osw_validator.Upload') - def test_subscribe_with_valid_message(self, mock_request_message, mock_queue_message): + def test_subscribe_with_valid_message(self, mock_upload, mock_queue_message): # Arrange mock_message = MagicMock() mock_queue_message.to_dict.return_value = self.sample_message - mock_request_message.from_dict.return_value = mock_request_message + mock_upload_message = MagicMock() + mock_upload.data_from.return_value = mock_upload_message self.service.validate = MagicMock() # Act @@ -54,7 +61,7 @@ def test_subscribe_with_valid_message(self, mock_request_message, mock_queue_mes callback(mock_message) # Assert - self.service.validate.assert_called_once_with(received_message=mock_request_message.data_from()) + self.service.validate.assert_called_once_with(received_message=mock_upload_message) @patch('src.osw_validator.Validation') def test_validate_with_valid_file_path(self, mock_validation): @@ -164,19 +171,16 @@ def test_validate_with_validation_only_in_message_type(self, mock_has_permission self.assertTrue(actual_result.is_valid) self.assertEqual(actual_upload_message, mock_request_message) - @patch('src.osw_validator.threading.Thread') - def test_stop_listening(self, mock_thread): + def test_stop_listening(self): # Arrange - mock_thread_instance = MagicMock() - mock_thread.return_value = mock_thread_instance - - self.service.listener_thread = mock_thread_instance + self.service.listener_thread = self.mock_listener_thread # Act result = self.service.stop_listening() # Assert - mock_thread_instance.join.assert_called_once_with(timeout=0) + self.mock_listener_thread.join.assert_called_once_with(timeout=0) + self.service._stop_server_and_container.assert_called_once() self.assertIsNone(result) def test_has_permission_success(self): diff --git a/tests/unit_tests/test_validation.py b/tests/unit_tests/test_validation.py index 2238151..493adc8 100644 --- a/tests/unit_tests/test_validation.py +++ b/tests/unit_tests/test_validation.py @@ -122,13 +122,18 @@ def test_validate_invalid_file_with_errors(self, mock_download_file, mock_clean_ for expected, error in zip(expected_errors, errors): self.assertEqual(error['filename'], error_in_file) self.assertEqual(error['feature_index'], expected['feature_index']) - self.assertEqual(error['error_message'][0], expected['error_message']) + self.assertTrue( + error['error_message'][0].startswith( + "Additional properties are not allowed ('crossing' was unexpected)" + ) + ) # Ensure clean_up is called twice (once for the file, once for the folder) self.assertEqual(mock_clean_up.call_count, 2) + @patch('src.validation.Validation.download_single_file', return_value=None) @patch('src.validation.OSWValidation') @patch('src.validation.Validation.clean_up') - def test_validate_invalid_zip(self, mock_clean_up, mock_osw_validation): + def test_validate_invalid_zip(self, mock_clean_up, mock_osw_validation, mock_download_file): """Test validate method for invalid zip file with errors.""" # Mock the OSWValidation validate method to return errors mock_validation_result = MagicMock()