diff --git a/environments/math_python/math_python.py b/environments/math_python/math_python.py
index 87cc69335..4f9a9d106 100644
--- a/environments/math_python/math_python.py
+++ b/environments/math_python/math_python.py
@@ -31,7 +31,7 @@ def load_environment(
     parser = vf.Parser(extract_fn=extract_boxed_answer)
     math_rubric = vf.MathRubric(parser=parser)
 
-    vf_env = vf.PythonEnv(
+    return vf.PythonEnv(
         dataset=dataset,
         system_prompt=system_prompt,
         parser=parser,
@@ -50,7 +50,3 @@ def load_environment(
         sandbox_client_max_workers=sandbox_client_max_workers,
         **kwargs,
     )
-    assert vf_env.tools is not None
-    tool_rubric = vf.ToolRubric(tools=vf_env.tools)
-    vf_env.rubric = vf.RubricGroup(rubrics=[tool_rubric, vf_env.rubric])
-    return vf_env
diff --git a/tests/test_env_group.py b/tests/test_env_group.py
index 72289fcee..0d834bac7 100644
--- a/tests/test_env_group.py
+++ b/tests/test_env_group.py
@@ -47,7 +47,7 @@ def func3(completion, **kwargs):
         assert rubric.env_map == env_map
 
         # Should have all unique reward function names
-        assert set(rubric.all_reward_names) == {"func1", "func2", "func3"}
+        assert set(rubric.all_reward_names) == {"num_turns", "func1", "func2", "func3"}
 
     @pytest.mark.asyncio
     async def test_env_group_rubric_score_rollout(self, mock_openai_client):
diff --git a/verifiers/__init__.py b/verifiers/__init__.py
index f148395b2..528c94fe4 100644
--- a/verifiers/__init__.py
+++ b/verifiers/__init__.py
@@ -29,7 +29,6 @@
 from .parsers.xml_parser import XMLParser
 from .rubrics.judge_rubric import JudgeRubric
 from .rubrics.rubric_group import RubricGroup
-from .rubrics.tool_rubric import ToolRubric
 from .utils.data_utils import (
     extract_boxed_answer,
     extract_hash_answer,
@@ -84,7 +83,6 @@ def setup_logging(
     "Rubric",
     "JudgeRubric",
     "RubricGroup",
-    "ToolRubric",
     "MathRubric",
     "TextArenaEnv",
     "ReasoningGymEnv",
diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index 662fa155a..2eb176caa 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -1058,6 +1058,14 @@ def set_kwargs(self, **kwargs) -> None:
             else:
                 setattr(self, key, value)
 
+    def add_rubric(self, rubric: Rubric) -> None:
+        if self.rubric is None:
+            self.rubric = rubric
+        elif isinstance(self.rubric, vf.RubricGroup):
+            self.rubric.rubrics.append(rubric)
+        else:
+            self.rubric = vf.RubricGroup(rubrics=[self.rubric, rubric])
+
     def set_max_seq_len(self, max_seq_len: int | None) -> None:
         """Set the maximum sequence length for this environment."""
         self.max_seq_len = max_seq_len
diff --git a/verifiers/envs/multiturn_env.py b/verifiers/envs/multiturn_env.py
index 1aec4dd7a..9841cf6d8 100644
--- a/verifiers/envs/multiturn_env.py
+++ b/verifiers/envs/multiturn_env.py
@@ -23,11 +23,22 @@
 logger = logging.getLogger(__name__)
 
 
+class MultiTurnMonitorRubric(vf.Rubric):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_metric(self.num_turns)
+
+    async def num_turns(self, state: State) -> int:
+        return len(state["trajectory"])
+
+
 class MultiTurnEnv(vf.Environment):
     def __init__(self, max_turns: int = -1, **kwargs):
         super().__init__(**kwargs)
         self.max_turns = max_turns
 
+        self.add_rubric(MultiTurnMonitorRubric())
+
     @abstractmethod
     async def env_response(
         self, messages: Messages, state: State, **kwargs
diff --git a/verifiers/envs/python_env.py b/verifiers/envs/python_env.py
index 8c695e1b2..b5651ec82 100644
--- a/verifiers/envs/python_env.py
+++ b/verifiers/envs/python_env.py
@@ -17,6 +17,7 @@
 class PythonWorkerState(TypedDict):
     ready: bool
     execution_count: int
+    ready_wait_time: float
 
 
 class PythonWorkerNotReadyError(vf.SandboxError): ...
@@ -28,6 +29,15 @@ class PythonWorkerRequestError(vf.SandboxError): ...
 class PythonWorkerDeadError(vf.SandboxError): ...
 
 
+class PythonMonitorRubric(vf.Rubric):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_metric(self.python_ready_wait_time)
+
+    async def python_ready_wait_time(self, state: vf.State) -> float:
+        return state["python_state"]["ready_wait_time"]
+
+
 class PythonEnv(SandboxEnv):
     """Sandbox-backed environment exposing a persistent Python REPL."""
 
@@ -189,6 +199,7 @@ def __init__(
             start_command=start_command,
             **kwargs,
         )
+        self.add_rubric(PythonMonitorRubric())
         self.add_tool(
             self.python, args_to_skip=["sandbox_id", "sandbox_state", "python_state"]
         )
@@ -199,6 +210,7 @@ async def setup_state(self, state: vf.State, **kwargs: Any) -> vf.State:
         state["python_state"] = {
             "ready": False,
             "execution_count": 0,
+            "ready_wait_time": -1.0,
         }
         return state
 
@@ -229,7 +241,7 @@ async def python(
     ) -> str:
         """Execute `code` inside persistent Python REPL."""
         if not python_state["ready"]:
-            await self._wait_for_worker_ready(sandbox_state, sandbox_id)
+            await self._wait_for_worker_ready(sandbox_id, sandbox_state, python_state)
             python_state["ready"] = True
         self.logger.debug(f"Executing code\n{code}")
         sandbox_response = await self._send_worker_request(
@@ -242,7 +254,10 @@ async def cleanup_python_state(self, state: vf.State):
         state.pop("python_state", None)
 
     async def _wait_for_worker_ready(
-        self, sandbox_state: SandboxState, sandbox_id: str
+        self,
+        sandbox_id: str,
+        sandbox_state: SandboxState,
+        python_state: PythonWorkerState,
     ) -> None:
         s = time.time()
         try:
@@ -260,11 +275,13 @@ async def _wait_for_worker_ready(
             )
             if result.exit_code != 0:
                 raise RuntimeError(result.stderr)
-            self.logger.debug(
-                f"Waited {time.time() - s:.1f}s for Python worker to be ready"
-            )
         except Exception as e:
             raise PythonWorkerNotReadyError from e
+        ready_wait_time = time.time() - s
+        python_state["ready_wait_time"] = ready_wait_time
+        self.logger.debug(
+            f"Waited {ready_wait_time:.1f}s for Python worker to be ready"
+        )
 
     async def _send_worker_request(
         self,
diff --git a/verifiers/envs/sandbox_env.py b/verifiers/envs/sandbox_env.py
index ff7fc52b8..794b88c16 100644
--- a/verifiers/envs/sandbox_env.py
+++ b/verifiers/envs/sandbox_env.py
@@ -89,6 +89,8 @@ def teardown(self, wait: bool = True) -> None:
 
 class SandboxState(TypedDict):
     ready: bool
+    ready_wait_time: float
+    command_execution_times: list[float]
 
 
 class SandboxCreationError(vf.SandboxError): ...
@@ -97,6 +99,24 @@ class SandboxCreationError(vf.SandboxError): ...
 class SandboxNotReadyError(vf.SandboxError): ...
 
 
+class SandboxMonitorRubric(vf.Rubric):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.add_metric(self.sandbox_ready_wait_time)
+        self.add_metric(self.sandbox_command_execution_time)
+
+    async def sandbox_ready_wait_time(self, state: vf.State) -> float:
+        return state["sandbox_state"]["ready_wait_time"]
+
+    async def sandbox_command_execution_time(self, state: vf.State) -> float:
+        command_execution_times = state["sandbox_state"]["command_execution_times"]
+        return (
+            sum(command_execution_times) / len(command_execution_times)
+            if len(command_execution_times) > 0
+            else 0.0
+        )
+
+
 class SandboxEnv(vf.StatefulToolEnv):
     def __init__(
         self,
@@ -127,6 +147,7 @@ def __init__(
             stop_errors=stop_errors if stop_errors is not None else [vf.SandboxError],
             **kwargs,
         )
+        self.add_rubric(SandboxMonitorRubric())
         self.timeout_per_command_seconds = timeout_per_command_seconds
         self.sandbox_client = ThreadedAsyncSandboxClient(
             max_workers=sandbox_client_max_workers,
@@ -173,7 +194,9 @@ async def _wait_for_sandbox_ready(
             sandbox_state["ready"] = True
         except Exception as e:
             raise SandboxNotReadyError(e)
-        self.logger.debug(f"Waited {time.time() - s:.1f}s for sandbox to be ready")
+        ready_wait_time = time.time() - s
+        sandbox_state["ready_wait_time"] = ready_wait_time
+        self.logger.debug(f"Waited {ready_wait_time:.1f}s for sandbox to be ready")
 
     async def bash(
         self,
@@ -197,13 +220,16 @@ async def bash(
                 timeout=self.timeout_per_command_seconds,
             )
         except CommandTimeoutError:
-            e = time.time()
             timeout_msg = f"Command timed out after {self.timeout_per_command_seconds}s"
             self.logger.warning(f"{timeout_msg} in sandbox {sandbox_id}")
+            sandbox_state["command_execution_times"].append(
+                self.timeout_per_command_seconds
+            )
             return f"Error: {timeout_msg}"
         except Exception as e:
             raise vf.SandboxError from e
-        e = time.time()
+        command_execution_time = time.time() - s
+        sandbox_state["command_execution_times"].append(command_execution_time)
         stdout = results.stdout.strip()
         stderr = (results.stderr or "").strip()
         combined = stdout
@@ -213,7 +239,9 @@ async def bash(
         else:
             combined = f"stderr:\n{stderr}"
         output = combined or "(no output)"
-        self.logger.debug(f"Executed command in {e - s:.1f}s. Got output: {output}")
+        self.logger.debug(
+            f"Executed command in {command_execution_time:.1f}s. Got output: {output}"
+        )
         return output
 
     async def post_rollout(self, state: vf.State):
@@ -252,7 +280,11 @@ async def setup_state(self, state: vf.State, **kwargs) -> vf.State:
         self.active_sandboxes.add(sandbox.id)
         self.logger.debug(f"Created sandbox {sandbox.id}")
         state["sandbox_id"] = sandbox.id
-        state["sandbox_state"] = {"ready": False}
+        state["sandbox_state"] = {
+            "ready": False,
+            "ready_wait_time": -1.0,
+            "command_execution_times": [],
+        }
         return await super().setup_state(state, **kwargs)
 
     def update_tool_args(
diff --git a/verifiers/envs/tool_env.py b/verifiers/envs/tool_env.py
index 6acd4cca6..fa622c5a7 100644
--- a/verifiers/envs/tool_env.py
+++ b/verifiers/envs/tool_env.py
@@ -4,10 +4,57 @@
 from openai.types.chat import ChatCompletionAssistantMessageParam
 
 import verifiers as vf
+from verifiers.types import Messages
 from verifiers.utils.async_utils import maybe_await
 from verifiers.utils.tool_utils import convert_func_to_oai_tool
 
 
+class ToolMonitorRubric(vf.Rubric):
+    def __init__(self, tools: list[Callable] | None = None, **kwargs):
+        super().__init__(**kwargs)
+
+        self.tools = tools or []
+        self.tool_names = [tool.__name__ for tool in self.tools]  # type: ignore[union-attr]
+
+        # add tool metrics
+        self.add_metric(self.total_tool_calls)
+        for tool_name in self.tool_names:
+            self.add_metric(self.get_tool_call_count_func(tool_name))
+
+    async def total_tool_calls(self, completion: Messages) -> float:
+        """Count the total number of tool calls."""
+        total = 0
+        assert isinstance(completion, list)
+        for msg in completion:
+            if msg["role"] == "assistant" and "tool_calls" in msg:
+                assistant_msg = cast(ChatCompletionAssistantMessageParam, msg)  # type: ignore[redundant-cast]
+                tool_calls = assistant_msg.get("tool_calls", [])
+                if isinstance(tool_calls, list):
+                    total += len(tool_calls)
+        return float(total)
+
+    def get_tool_call_count_func(self, tool_name: str) -> Callable:
+        """Create a metric that counts calls to a specific tool."""
+
+        async def tool_call_count_func(completion: Messages) -> int:
+            """Count calls to {tool_name} tool."""
+            count = 0
+            # Find tool calls in assistant messages
+            assert isinstance(completion, list)
+            for msg in completion:
+                if msg["role"] == "assistant" and "tool_calls" in msg:
+                    assistant_msg = cast(ChatCompletionAssistantMessageParam, msg)  # type: ignore[redundant-cast]
+                    tool_calls = assistant_msg.get("tool_calls", [])
+                    for tool_call in tool_calls:
+                        if tool_call.get("function", {}).get("name") == tool_name:
+                            count += 1
+
+            return count
+
+        tool_call_count_func.__name__ = f"{tool_name}_calls"
+        return tool_call_count_func
+
+
 class ToolEnv(vf.MultiTurnEnv):
     def __init__(
         self,
@@ -28,6 +75,8 @@ def __init__(
         }
         super().__init__(oai_tools=self.oai_tools, max_turns=max_turns, **kwargs)
 
+        self.add_rubric(ToolMonitorRubric(tools=self.tools))
+
     def _should_stop_for_error(self, err: Exception) -> bool:
         """Check if error is in stop_errors."""
         return any(isinstance(err, err_type) for err_type in self.stop_errors)
diff --git a/verifiers/rubrics/tool_rubric.py b/verifiers/rubrics/tool_rubric.py
deleted file mode 100644
index 210020926..000000000
--- a/verifiers/rubrics/tool_rubric.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import Callable, cast
-
-from openai.types.chat import ChatCompletionAssistantMessageParam
-
-from verifiers.rubrics.rubric import Rubric
-from verifiers.types import Messages
-from verifiers.utils.tool_utils import convert_func_to_oai_tool
-
-
-class ToolRubric(Rubric):
-    """Simple rubric that counts tool calls in completion messages."""
-
-    def __init__(self, tools: list[Callable] | None = None):
-        self.tools = tools or []
-        self.oai_tools = [convert_func_to_oai_tool(tool) for tool in self.tools]
-        self.tool_names = [tool.__name__ for tool in self.tools]  # type: ignore[union-attr]
-
-        # Build initial reward functions and weights
-        reward_funcs = []
-        reward_funcs.append(self.total_tool_calls)
-        reward_weights = [0.0]
-
-        for tool_name in self.tool_names:
-            reward_funcs.append(self.get_tool_call_count_func(tool_name))
-            reward_weights.append(0.0)
-
-        # Pass them to parent class
-        super().__init__(funcs=reward_funcs, weights=reward_weights)
-
-    async def total_tool_calls(self, completion: Messages) -> float:
-        """Count the total number of tool calls across all assistant messages."""
-        total = 0
-        assert isinstance(completion, list)
-        for msg in completion:
-            if msg["role"] == "assistant" and "tool_calls" in msg:
-                assistant_msg = cast(ChatCompletionAssistantMessageParam, msg)  # type: ignore[redundant-cast]
-                tool_calls = assistant_msg.get("tool_calls", [])
-                if isinstance(tool_calls, list):
-                    total += len(tool_calls)
-        return float(total)
-
-    def get_tool_call_count_func(self, tool_name: str) -> Callable:
-        """Create a reward function that counts calls to a specific tool."""
-
-        async def tool_call_count_func(completion: Messages) -> float:
-            """Count calls to {tool_name} tool."""
-            count = 0
-            # Find tool calls in assistant messages
-            assert isinstance(completion, list)
-            for msg in completion:
-                if msg["role"] == "assistant" and "tool_calls" in msg:
-                    assistant_msg = cast(ChatCompletionAssistantMessageParam, msg)  # type: ignore[redundant-cast]
-                    tool_calls = assistant_msg.get("tool_calls", [])
-                    for tool_call in tool_calls:
-                        if tool_call.get("function", {}).get("name") == tool_name:
-                            count += 1
-
-            return float(count)
-
-        tool_call_count_func.__name__ = f"{tool_name}_calls"
-        return tool_call_count_func