Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def __init__(self, config: EvaluationRunConfig, callback_manager: "EvalCallbackM
# Pre-generated OTEL root span_ids for eager trace linking (item_id -> span_id)
self._item_span_ids: dict[str, int] = {}

def _write_checkpoint_item(self, checkpoint_file: Path, item_dict: dict[str, Any]) -> None:
"""Helper to write a single JSONL line to disk. Called via to_thread to avoid blocking."""
with open(checkpoint_file, "a", encoding="utf-8") as f:
f.write(json.dumps(item_dict) + "\n")
f.flush()

def _compute_usage_stats(self, item: EvalInputItem):
"""Compute usage stats for a single item using the intermediate steps"""
# get the prompt and completion tokens from the intermediate steps
Expand Down Expand Up @@ -172,7 +178,7 @@ def _compute_usage_stats(self, item: EvalInputItem):
llm_latency=llm_latency)
return self.usage_stats.usage_stats_items[item.id]

async def run_workflow_local(self, session_manager: SessionManager):
async def run_workflow_local(self, session_manager: SessionManager, dataset_handler: DatasetHandler) -> None:
'''
Launch the workflow with the specified questions and extract the output using the jsonpath
'''
Expand Down Expand Up @@ -272,6 +278,27 @@ async def cancel_pending_tasks():

self.weave_eval.log_prediction(item, output)
await self.weave_eval.log_usage_stats(item, usage_stats_item)

# START INCREMENTAL CHECKPOINTING
if self.config.write_output:
try:
output_dir = self.eval_config.general.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_file = output_dir / "workflow_output.jsonl"

step_filter = self.eval_config.general.output.workflow_output_step_filter \
if self.eval_config.general.output else None

from nat.data_models.evaluator import EvalInput
temp_input = EvalInput(eval_input_items=[item])
item_json_list = dataset_handler.publish_eval_input(temp_input, step_filter)
item_dict = json.loads(item_json_list)[0]

# Use to_thread to prevent blocking the event loop
await asyncio.to_thread(self._write_checkpoint_item, checkpoint_file, item_dict)
except Exception:
logger.exception("Failed to write incremental checkpoint for item %s", item.id)
# END INCREMENTAL CHECKPOINTING
finally:
if root_span_token is not None:
ctx_state._root_span_id.reset(root_span_token)
Expand All @@ -292,14 +319,28 @@ async def wrapped_run(item: EvalInputItem) -> None:
await asyncio.gather(*[wrapped_run(item) for item in eval_input_items])
pbar.close()

async def run_workflow_remote(self, dataset_handler: DatasetHandler) -> None:
    """Run the workflow against the configured remote endpoint and log results.

    After the remote handler populates ``self.eval_input``, this logs a
    prediction and usage stats for every item, then — when
    ``self.config.write_output`` is set — mirrors the incremental
    checkpointing done in ``run_workflow_local`` so a remote run also leaves
    a ``workflow_output.jsonl`` on disk.

    Args:
        dataset_handler: Used to serialize each item for the checkpoint file.
    """
    from nat.plugins.eval.runtime.remote_workflow import EvaluationRemoteWorkflowHandler
    handler = EvaluationRemoteWorkflowHandler(self.config, self.eval_config.general.max_concurrency)
    await handler.run_workflow_remote(self.eval_input)
    for item in self.eval_input.eval_input_items:
        usage_stats_item = self._compute_usage_stats(item)
        self.weave_eval.log_prediction(item, item.output_obj)
        await self.weave_eval.log_usage_stats(item, usage_stats_item)

    # Checkpoint each item individually; a failure here is logged but must not
    # abort the evaluation run (the checkpoint is best-effort).
    if self.config.write_output:
        try:
            # Hoisted out of the per-item loop: the import is loop-invariant.
            from nat.data_models.evaluator import EvalInput

            output_dir = self.eval_config.general.output_dir
            output_dir.mkdir(parents=True, exist_ok=True)
            checkpoint_file = output_dir / "workflow_output.jsonl"
            step_filter = (self.eval_config.general.output.workflow_output_step_filter
                           if self.eval_config.general.output else None)

            for item in self.eval_input.eval_input_items:
                # Serialize one item at a time through the same publishing path
                # used for the final output, so formats stay consistent.
                single_item_input = EvalInput(eval_input_items=[item])
                item_dict = json.loads(dataset_handler.publish_eval_input(single_item_input, step_filter))[0]

                # Use to_thread to keep the blocking file write off the event loop.
                await asyncio.to_thread(self._write_checkpoint_item, checkpoint_file, item_dict)
        except Exception:
            logger.exception("Failed to write remote checkpoint items")

async def profile_workflow(self) -> ProfilerResults:
"""
Expand Down Expand Up @@ -720,15 +761,15 @@ async def run_and_evaluate(self,
local_session_manager: SessionManager | None = None
try:
if self.config.endpoint:
await self.run_workflow_remote()
await self.run_workflow_remote(dataset_handler)
elif not self.config.skip_workflow:
if session_manager is None:
session_manager = await SessionManager.create(
config=config,
shared_builder=eval_workflow,
max_concurrency=self.eval_config.general.max_concurrency)
local_session_manager = session_manager
await self.run_workflow_local(session_manager)
await self.run_workflow_local(session_manager, dataset_handler)

# Pre-evaluation process the workflow output
self.eval_input = dataset_handler.pre_eval_process_eval_input(self.eval_input)
Expand Down