diff --git a/packages/nvidia_nat_eval/src/nat/plugins/eval/runtime/evaluate.py b/packages/nvidia_nat_eval/src/nat/plugins/eval/runtime/evaluate.py
index 42db91d993..f8f5e7473b 100644
--- a/packages/nvidia_nat_eval/src/nat/plugins/eval/runtime/evaluate.py
+++ b/packages/nvidia_nat_eval/src/nat/plugins/eval/runtime/evaluate.py
@@ -117,6 +117,12 @@ def __init__(self, config: EvaluationRunConfig, callback_manager: "EvalCallbackM
         # Pre-generated OTEL root span_ids for eager trace linking (item_id -> span_id)
         self._item_span_ids: dict[str, int] = {}
 
+    def _write_checkpoint_item(self, checkpoint_file: Path, item_dict: dict[str, Any]) -> None:
+        """Helper to write a single JSONL line to disk. Called via to_thread to avoid blocking."""
+        with open(checkpoint_file, "a", encoding="utf-8") as f:
+            f.write(json.dumps(item_dict) + "\n")
+            f.flush()
+
     def _compute_usage_stats(self, item: EvalInputItem):
         """Compute usage stats for a single item using the intermediate steps"""
         # get the prompt and completion tokens from the intermediate steps
@@ -172,7 +178,7 @@ def _compute_usage_stats(self, item: EvalInputItem):
                                  llm_latency=llm_latency)
         return self.usage_stats.usage_stats_items[item.id]
 
-    async def run_workflow_local(self, session_manager: SessionManager):
+    async def run_workflow_local(self, session_manager: SessionManager, dataset_handler: DatasetHandler) -> None:
         '''
         Launch the workflow with the specified questions and extract the output using the jsonpath
         '''
@@ -272,6 +278,27 @@ async def cancel_pending_tasks():
                     usage_stats_item = self._compute_usage_stats(item)
                     self.weave_eval.log_prediction(item, output)
                     await self.weave_eval.log_usage_stats(item, usage_stats_item)
+
+                    # START INCREMENTAL CHECKPOINTING
+                    if self.config.write_output:
+                        try:
+                            output_dir = self.eval_config.general.output_dir
+                            output_dir.mkdir(parents=True, exist_ok=True)
+                            checkpoint_file = output_dir / "workflow_output.jsonl"
+
+                            step_filter = self.eval_config.general.output.workflow_output_step_filter \
+                                if self.eval_config.general.output else None
+
+                            from nat.data_models.evaluator import EvalInput
+                            temp_input = EvalInput(eval_input_items=[item])
+                            item_json_list = dataset_handler.publish_eval_input(temp_input, step_filter)
+                            item_dict = json.loads(item_json_list)[0]
+
+                            # Use to_thread to prevent blocking the event loop
+                            await asyncio.to_thread(self._write_checkpoint_item, checkpoint_file, item_dict)
+                        except Exception:
+                            logger.exception("Failed to write incremental checkpoint for item %s", item.id)
+                    # END INCREMENTAL CHECKPOINTING
                 finally:
                     if root_span_token is not None:
                         ctx_state._root_span_id.reset(root_span_token)
@@ -292,14 +319,35 @@ async def wrapped_run(item: EvalInputItem) -> None:
         await asyncio.gather(*[wrapped_run(item) for item in eval_input_items])
         pbar.close()
 
-    async def run_workflow_remote(self):
+    async def run_workflow_remote(self, dataset_handler: DatasetHandler) -> None:
         from nat.plugins.eval.runtime.remote_workflow import EvaluationRemoteWorkflowHandler
         handler = EvaluationRemoteWorkflowHandler(self.config, self.eval_config.general.max_concurrency)
         await handler.run_workflow_remote(self.eval_input)
         for item in self.eval_input.eval_input_items:
             usage_stats_item = self._compute_usage_stats(item)
             self.weave_eval.log_prediction(item, item.output_obj)
             await self.weave_eval.log_usage_stats(item, usage_stats_item)
+
+        # Incremental checkpointing for remote runs. NOTE(review): the weave
+        # logging loop above is intentionally kept as-is; removing it would
+        # silently stop recording predictions and usage stats for remote runs.
+        if self.config.write_output:
+            try:
+                output_dir = self.eval_config.general.output_dir
+                output_dir.mkdir(parents=True, exist_ok=True)
+                checkpoint_file = output_dir / "workflow_output.jsonl"
+                step_filter = self.eval_config.general.output.workflow_output_step_filter if self.eval_config.general.output else None
+
+                # Import hoisted out of the per-item loop.
+                from nat.data_models.evaluator import EvalInput
+                for item in self.eval_input.eval_input_items:
+                    temp_input = EvalInput(eval_input_items=[item])
+                    item_dict = json.loads(dataset_handler.publish_eval_input(temp_input, step_filter))[0]
+
+                    # Use to_thread so file writes do not block the event loop
+                    await asyncio.to_thread(self._write_checkpoint_item, checkpoint_file, item_dict)
+            except Exception:
+                logger.exception("Failed to write remote checkpoint items")
 
     async def profile_workflow(self) -> ProfilerResults:
         """
@@ -720,7 +768,7 @@ async def run_and_evaluate(self,
         local_session_manager: SessionManager | None = None
         try:
             if self.config.endpoint:
-                await self.run_workflow_remote()
+                await self.run_workflow_remote(dataset_handler)
             elif not self.config.skip_workflow:
                 if session_manager is None:
                     session_manager = await SessionManager.create(
@@ -728,7 +776,7 @@ async def run_and_evaluate(self,
                         shared_builder=eval_workflow,
                         max_concurrency=self.eval_config.general.max_concurrency)
                     local_session_manager = session_manager
-                    await self.run_workflow_local(session_manager)
+                    await self.run_workflow_local(session_manager, dataset_handler)
 
         # Pre-evaluation process the workflow output
         self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)