diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index a5c63fa39..c76977688 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -386,7 +386,11 @@ def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_w progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers) executor = AsyncConcurrentExecutor(max_workers=max_workers, **executor_kwargs) work_items = [ - (generator.agenerate(record), {"index": i}) for i, record in self.batch_manager.iter_current_batch() + ( + generator.agenerate(record), + {"index": i, "column_name": generator.config.name}, + ) + for i, record in self.batch_manager.iter_current_batch() ] executor.run(work_items) self._finalize_fan_out(progress_tracker) @@ -397,7 +401,11 @@ def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers) with ConcurrentThreadExecutor(max_workers=max_workers, **executor_kwargs) as executor: for i, record in self.batch_manager.iter_current_batch(): - executor.submit(lambda record: generator.generate(record), record, context={"index": i}) + executor.submit( + lambda record: generator.generate(record), + record, + context={"index": i, "column_name": generator.config.name}, + ) self._finalize_fan_out(progress_tracker) def _make_result_callback(self, progress_tracker: ProgressTracker) -> Callable[[dict], None]: @@ -484,12 +492,61 @@ def _cleanup_dropped_record_images(self, dropped_indices: set[int]) -> None: for path in [paths] if isinstance(paths, str) else paths: media_storage.delete_image(path) + @staticmethod + def _extract_failure_detail(exc: Exception) -> str: + detail = getattr(exc, "detail", None) + if isinstance(detail, str): + normalized_detail = " ".join(detail.split()).strip() + if normalized_detail: + return normalized_detail + exc_str = str(exc).strip() + for line in exc_str.splitlines(): + if "Cause:" in line: + return " ".join(line.split("Cause:", maxsplit=1)[1].split()).strip() + return " ".join(exc_str.split()).strip() or type(exc).__name__ + + @classmethod + def _classify_worker_failure(cls, exc: Exception) -> str: + failure_kind = getattr(exc, "failure_kind", None) + if isinstance(failure_kind, str) and failure_kind.strip(): + return failure_kind.replace("_", " ") + + detail = cls._extract_failure_detail(exc).lower() + exc_name = type(exc).__name__.lower() + + if "timeout" in exc_name or "timed out" in detail: + return "timeout" + if "rate" in exc_name and "limit" in exc_name: + return "rate limit" + if "authentication" in exc_name: + return "authentication" + if "permission" in exc_name: + return "permission denied" + if "contextwindow" in exc_name or "context width" in detail: + return "context window" + if "response_schema" in detail or "schema" in detail: + return "schema validation" + if "validation" in exc_name or "validation" in detail: + return "validation" + return "generation error" + + @classmethod + def _format_worker_failure_warning(cls, exc: Exception, *, context: dict | None = None) -> str: + record_index = context["index"] if context is not None and "index" in context else "unknown" + column_name = context.get("column_name") if context is not None else None + context_label = f" in column {column_name!r}" if column_name else "" + failure_kind = cls._classify_worker_failure(exc) + failure_detail = cls._extract_failure_detail(exc) + return ( + f"⚠️ Generation for record at index {record_index} failed{context_label} ({failure_kind}). " + f"Will omit this record from the dataset. Detail: {failure_detail}" + ) + def _worker_error_callback(self, exc: Exception, *, context: dict | None = None) -> None: """If a worker fails, we can handle the exception here.""" - logger.warning( - f"⚠️ Generation for record at index {context['index']} failed. " - f"Will omit this record from the dataset.\n{exc}" - ) + logger.warning(self._format_worker_failure_warning(exc, context=context)) + if context is None or "index" not in context: + raise RuntimeError("Worker error callback called without a valid context index.") self._records_to_drop.add(context["index"]) def _worker_result_callback(self, result: dict | list[dict], *, context: dict | None = None) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 6ad084fa7..e01b184bf 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -20,6 +20,13 @@ logger = logging.getLogger(__name__) +def _normalize_error_detail(detail: str | None) -> str | None: + if detail is None: + return None + normalized = " ".join(detail.split()).strip() + return normalized or None + + def get_exception_primary_cause(exception: BaseException) -> BaseException: """Returns the primary cause of an exception by walking backwards. @@ -38,7 +45,27 @@ def get_exception_primary_cause(exception: BaseException) -> BaseException: return get_exception_primary_cause(exception.__cause__) -class GenerationValidationFailureError(Exception): ... +class GenerationValidationFailureError(Exception): + summary: str + detail: str | None + failure_kind: str + + def __init__( + self, + summary: str, + *, + detail: str | None = None, + failure_kind: str = "validation_error", + ) -> None: + self.summary = summary.strip() + self.detail = _normalize_error_detail(detail) + self.failure_kind = failure_kind + + message = self.summary + if self.detail is not None: + message = f"{message} Validation detail: {self.detail}" + + super().__init__(message) class ModelRateLimitError(DataDesignerError): ... @@ -80,7 +107,23 @@ class ModelAPIConnectionError(DataDesignerError): ... class ModelStructuredOutputError(DataDesignerError): ... -class ModelGenerationValidationFailureError(DataDesignerError): ... +class ModelGenerationValidationFailureError(DataDesignerError): + detail: str | None + failure_kind: str | None + + def __init__( + self, + message: object | None = None, + *, + detail: str | None = None, + failure_kind: str | None = None, + ) -> None: + if message is None: + super().__init__() + else: + super().__init__(message) + self.detail = _normalize_error_detail(detail) + self.failure_kind = failure_kind class ImageGenerationError(DataDesignerError): ... @@ -214,11 +257,18 @@ def handle_llm_exceptions( # Parsing and validation errors case GenerationValidationFailureError(): + detail_text = exception.detail.rstrip(".") if exception.detail is not None else None + validation_detail = f" Validation detail: {detail_text}." if detail_text is not None else "" raise ModelGenerationValidationFailureError( FormattedLLMErrorMessage( - cause=f"The provided output schema was unable to be parsed from model {model_name!r} responses while {purpose}.", + cause=( + f"The model output from {model_name!r} could not be parsed into the requested format " + f"while {purpose}.{validation_detail}" + ), solution="This is most likely temporary as we make additional attempts. If you continue to see more of this, simplify or modify the output schema for structured output and try again. If you are attempting token-intensive tasks like generations with high-reasoning effort, ensure that max_tokens in the model config is high enough to reach completion.", - ) + ), + detail=exception.detail, + failure_kind=exception.failure_kind, ) from None case DataDesignerError(): diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index d13e96ea9..8df7beffe 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -47,6 +47,23 @@ def _identity(x: Any) -> Any: logger = logging.getLogger(__name__) +def _classify_generation_failure_kind(exc: ParserException) -> str: + detail = " ".join(str(get_exception_primary_cause(exc)).split()).lower() + if "response_schema" in detail or "model_validate" in detail: + return "schema_validation" + if "validation error" in detail or "doesn't match requested" in detail: + return "schema_validation" + return "parse_error" + + +def _build_generation_validation_error(summary: str, exc: ParserException) -> GenerationValidationFailureError: + return GenerationValidationFailureError( + summary, + detail=str(get_exception_primary_cause(exc)), + failure_kind=_classify_generation_failure_kind(exc), + ) + + # Known keyword arguments extracted into request fields for each modality. # Note: `extra_body` and `extra_headers` appear in every set but receive special # treatment in `consolidate_kwargs` (merged with provider-level overrides) and in @@ -326,8 +343,9 @@ def generate( break except ParserException as exc: if max_correction_steps == 0 and max_conversation_restarts == 0: - raise GenerationValidationFailureError( - "Unsuccessful generation attempt. No retries were attempted." + raise _build_generation_validation_error( + "Unsuccessful generation attempt. No retries were attempted.", + exc, ) from exc if curr_num_correction_steps <= max_correction_steps: @@ -341,9 +359,12 @@ def generate( tool_call_turns = checkpoint_tool_call_turns else: - raise GenerationValidationFailureError( - f"Unsuccessful generation despite {max_correction_steps} correction steps " - f"and {max_conversation_restarts} conversation restarts." + raise _build_generation_validation_error( + ( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." + ), + exc, ) from exc if not skip_usage_tracking and mcp_facade is not None: @@ -424,8 +445,9 @@ async def agenerate( break except ParserException as exc: if max_correction_steps == 0 and max_conversation_restarts == 0: - raise GenerationValidationFailureError( - "Unsuccessful generation attempt. No retries were attempted." + raise _build_generation_validation_error( + "Unsuccessful generation attempt. No retries were attempted.", + exc, ) from exc if curr_num_correction_steps <= max_correction_steps: @@ -438,9 +460,12 @@ async def agenerate( tool_call_turns = checkpoint_tool_call_turns else: - raise GenerationValidationFailureError( - f"Unsuccessful generation despite {max_correction_steps} correction steps " - f"and {max_conversation_restarts} conversation restarts." + raise _build_generation_validation_error( + ( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." + ), + exc, ) from exc if not skip_usage_tracking and mcp_facade is not None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py index 4f602d6a9..095e548e6 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_column_wise_builder.py @@ -3,6 +3,7 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING from unittest.mock import Mock, patch @@ -20,6 +21,11 @@ from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError +from data_designer.engine.models.errors import ( + FormattedLLMErrorMessage, + ModelGenerationValidationFailureError, + ModelTimeoutError, +) from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum from data_designer.engine.models.usage import ModelUsageStats, TokenUsageStats from data_designer.engine.processing.processors.base import Processor @@ -154,6 +160,77 @@ def test_column_wise_dataset_builder_records_to_drop_initialization(stub_column_ assert stub_column_wise_builder._records_to_drop == set() +def test_worker_error_callback_logs_schema_validation_detail( + stub_column_wise_builder: ColumnWiseDatasetBuilder, + caplog: pytest.LogCaptureFixture, +) -> None: + exc = ModelGenerationValidationFailureError( + FormattedLLMErrorMessage( + cause=( + "The model output from 'test-model' could not be parsed into the requested format while " + "running generation for column 'test_column'. Validation detail: Response doesn't match " + "requested 'name' is a required property." + ), + solution="Simplify the schema and retry.", + ), + detail="Response doesn't match requested 'name' is a required property.", + failure_kind="schema_validation", + ) + + with caplog.at_level(logging.WARNING): + stub_column_wise_builder._worker_error_callback(exc, context={"index": 248, "column_name": "test_column"}) + + assert "record at index 248" in caplog.text + assert "column 'test_column'" in caplog.text + assert "(schema validation)" in caplog.text + assert "Response doesn't match requested 'name' is a required property." in caplog.text + assert 248 in stub_column_wise_builder._records_to_drop + + +def test_worker_error_callback_logs_timeout_detail( + stub_column_wise_builder: ColumnWiseDatasetBuilder, + caplog: pytest.LogCaptureFixture, +) -> None: + exc = ModelTimeoutError( + FormattedLLMErrorMessage( + cause="The request to model 'test-model' timed out while running generation for column 'test_column'.", + solution="Increase the timeout setting for the model and retry.", + ) + ) + + with caplog.at_level(logging.WARNING): + stub_column_wise_builder._worker_error_callback(exc, context={"index": 17, "column_name": "test_column"}) + + assert "record at index 17" in caplog.text + assert "column 'test_column'" in caplog.text + assert "(timeout)" in caplog.text + assert ( + "The request to model 'test-model' timed out while running generation for column 'test_column'." in caplog.text + ) + assert 17 in stub_column_wise_builder._records_to_drop + + +def test_worker_error_callback_requires_context_index( + stub_column_wise_builder: ColumnWiseDatasetBuilder, + caplog: pytest.LogCaptureFixture, +) -> None: + exc = ModelTimeoutError( + FormattedLLMErrorMessage( + cause="The request to model 'test-model' timed out while running generation for column 'test_column'.", + solution="Increase the timeout setting for the model and retry.", + ) + ) + + with ( + caplog.at_level(logging.WARNING), + pytest.raises(RuntimeError, match="Worker error callback called without a valid context index."), + ): + stub_column_wise_builder._worker_error_callback(exc, context=None) + + assert "record at index unknown" in caplog.text + assert len(stub_column_wise_builder._records_to_drop) == 0 + + def test_column_wise_dataset_builder_batch_manager_initialization(stub_column_wise_builder, stub_resource_provider): assert stub_column_wise_builder.batch_manager is not None assert stub_column_wise_builder.batch_manager.artifact_storage == stub_resource_provider.artifact_storage @@ -420,6 +497,86 @@ def test_fan_out_with_threads_uses_early_shutdown_settings_from_resource_provide assert call_kwargs["disable_early_shutdown"] == disable_early_shutdown +@patch("data_designer.engine.dataset_builders.column_wise_builder.ConcurrentThreadExecutor") +def test_fan_out_with_threads_passes_column_name_in_context( + mock_executor_class: Mock, + stub_resource_provider: Mock, + stub_test_config_builder: DataDesignerConfigBuilder, +) -> None: + builder = ColumnWiseDatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + mock_executor = Mock() + mock_executor_class.return_value.__enter__ = Mock(return_value=mock_executor) + mock_executor_class.return_value.__exit__ = Mock(return_value=False) + + mock_generator = Mock() + mock_generator.get_generation_strategy.return_value = GenerationStrategy.CELL_BY_CELL + mock_generator.config.name = "test_column" + mock_generator.config.column_type = "llm_text" + mock_generator.config.tool_alias = None + + builder.batch_manager = Mock() + builder.batch_manager.num_records_batch = 2 + builder.batch_manager.num_records_in_buffer = 2 + builder.batch_manager.iter_current_batch.return_value = [(0, {"seed": "a"}), (1, {"seed": "b"})] + + builder._fan_out_with_threads(mock_generator, max_workers=2) + + submitted_contexts = [call.kwargs["context"] for call in mock_executor.submit.call_args_list] + assert submitted_contexts == [ + {"index": 0, "column_name": "test_column"}, + {"index": 1, "column_name": "test_column"}, + ] + + +@patch("data_designer.engine.dataset_builders.column_wise_builder.AsyncConcurrentExecutor", create=True) +def test_fan_out_with_async_passes_column_name_in_context( + mock_executor_class: Mock, + stub_resource_provider: Mock, + stub_test_config_builder: DataDesignerConfigBuilder, +) -> None: + builder = ColumnWiseDatasetBuilder( + data_designer_config=stub_test_config_builder.build(), + resource_provider=stub_resource_provider, + ) + + mock_executor = Mock() + + def _run(work_items: list[tuple[object, dict[str, int | str]]]) -> None: + for coro, _context in work_items: + coro.close() + + mock_executor.run.side_effect = _run + mock_executor_class.return_value = mock_executor + + mock_generator = Mock() + mock_generator.get_generation_strategy.return_value = GenerationStrategy.CELL_BY_CELL + mock_generator.config.name = "test_column" + mock_generator.config.column_type = "llm_text" + mock_generator.config.tool_alias = None + + async def _agenerate(record: dict[str, str]) -> dict[str, str]: + return record + + mock_generator.agenerate.side_effect = _agenerate + + builder.batch_manager = Mock() + builder.batch_manager.num_records_batch = 2 + builder.batch_manager.iter_current_batch.return_value = [(0, {"seed": "a"}), (1, {"seed": "b"})] + + builder._fan_out_with_async(mock_generator, max_workers=2) + + work_items = mock_executor.run.call_args.args[0] + submitted_contexts = [context for _coro, context in work_items] + assert submitted_contexts == [ + {"index": 0, "column_name": "test_column"}, + {"index": 1, "column_name": "test_column"}, + ] + + def test_full_column_custom_generator_error_is_descriptive(stub_resource_provider, stub_model_configs): @custom_column_generator(required_columns=["some_id"]) def bad_fn(df: pd.DataFrame) -> pd.DataFrame: diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 0e6169ae7..84e91325e 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -119,6 +119,57 @@ def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: assert captured_messages[0] == expected_messages +@patch.object(ModelFacade, "completion", autospec=True) +def test_generate_includes_parser_validation_detail_in_user_facing_error( + mock_completion: Any, + stub_model_facade: ModelFacade, +) -> None: + mock_completion.return_value = _make_response("bad response") + + def _failing_parser(response: str) -> str: + raise ParserException("Response doesn't match requested \n'name' is a required property") + + with pytest.raises( + ModelGenerationValidationFailureError, + match="Validation detail: Response doesn't match requested 'name' is a required property.", + ) as exc_info: + stub_model_facade.generate( + prompt="foo", + parser=_failing_parser, + max_correction_steps=0, + max_conversation_restarts=0, + ) + + assert exc_info.value.detail == "Response doesn't match requested 'name' is a required property" + assert exc_info.value.failure_kind == "schema_validation" + + +@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) +@pytest.mark.asyncio +async def test_agenerate_includes_parser_validation_detail_in_user_facing_error( + mock_acompletion: AsyncMock, + stub_model_facade: ModelFacade, +) -> None: + mock_acompletion.return_value = _make_response("bad response") + + def _failing_parser(response: str) -> str: + raise ParserException("Response doesn't match requested \n'name' is a required property") + + with pytest.raises( + ModelGenerationValidationFailureError, + match="Validation detail: Response doesn't match requested 'name' is a required property.", + ) as exc_info: + await stub_model_facade.agenerate( + prompt="foo", + parser=_failing_parser, + max_correction_steps=0, + max_conversation_restarts=0, + ) + + assert exc_info.value.detail == "Response doesn't match requested 'name' is a required property" + assert exc_info.value.failure_kind == "schema_validation" + + @pytest.mark.parametrize( "raw_content,expected", [ diff --git a/packages/data-designer-engine/tests/engine/models/test_model_errors.py b/packages/data-designer-engine/tests/engine/models/test_model_errors.py index e8372bded..d325838c5 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_errors.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_errors.py @@ -114,9 +114,16 @@ f"Cause: One or more of the parameters you provided were found to be unsupported by model '{stub_model_name}' while {stub_purpose}.", ), ( - GenerationValidationFailureError("Generation validation failure"), + GenerationValidationFailureError( + "Generation validation failure", + detail="Response doesn't match requested : 'name' is a required property", + ), ModelGenerationValidationFailureError, - f"Cause: The provided output schema was unable to be parsed from model '{stub_model_name}' responses while {stub_purpose}.", + ( + f"Cause: The model output from '{stub_model_name}' could not be parsed into the requested format " + f"while {stub_purpose}. Validation detail: Response doesn't match requested : " + "'name' is a required property." + ), ), ( Exception("Some unexpected error"), @@ -131,6 +138,40 @@ def test_handle_llm_exceptions(exception, expected_exception, expected_error_msg handle_llm_exceptions(exception, stub_model_name, stub_model_provider_name, stub_purpose) +def test_handle_llm_exceptions_preserves_generation_failure_kind() -> None: + with pytest.raises(ModelGenerationValidationFailureError) as exc_info: + handle_llm_exceptions( + GenerationValidationFailureError( + "Generation validation failure", + detail="Response doesn't match requested : 'name' is a required property", + failure_kind="schema_validation", + ), + stub_model_name, + stub_model_provider_name, + stub_purpose, + ) + + assert exc_info.value.failure_kind == "schema_validation" + assert exc_info.value.detail == "Response doesn't match requested : 'name' is a required property" + + +def test_handle_llm_exceptions_strips_duplicate_period_from_validation_detail() -> None: + with pytest.raises(ModelGenerationValidationFailureError, match=r"Validation detail: Field required\.") as exc_info: + handle_llm_exceptions( + GenerationValidationFailureError( + "Generation validation failure", + detail="Field required.", + failure_kind="schema_validation", + ), + stub_model_name, + stub_model_provider_name, + stub_purpose, + ) + + assert "Field required.." not in str(exc_info.value) + assert exc_info.value.detail == "Field required." + + def test_catch_llm_exceptions(): @catch_llm_exceptions def stub_function(model_facade: Any, *args, **kwargs):