diff --git a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py
index 1064289..4439013 100644
--- a/ajet/tuner_lib/experimental/swarm_overwatch_utils.py
+++ b/ajet/tuner_lib/experimental/swarm_overwatch_utils.py
@@ -2,6 +2,19 @@
 from pydantic import BaseModel
 
 
+class RewardHistoryEntry(BaseModel):
+    """A single entry in the reward history."""
+    global_step: int
+    mean_reward: float
+    std_reward: float
+    timestamp: float  # Unix timestamp when this entry was recorded
+
+
+class RewardHistoryResponse(BaseModel):
+    """Response containing the reward history for visualization."""
+    history: List[RewardHistoryEntry] = []
+
+
 class CurrentBatchRolloutPoolInformation(BaseModel):
     sample_collection_method: str = ""
     completed_episodes: int = 0
diff --git a/ajet/tuner_lib/experimental/swarm_server.py b/ajet/tuner_lib/experimental/swarm_server.py
index aa82984..4ba500d 100644
--- a/ajet/tuner_lib/experimental/swarm_server.py
+++ b/ajet/tuner_lib/experimental/swarm_server.py
@@ -11,7 +11,11 @@
 from multiprocessing.managers import DictProxy
 from typing import Coroutine, Optional, Tuple, List
 from ajet.utils.process_killer import kill_process_tree
-from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
+from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
+    CurrentBatchRolloutPoolInformation,
+    RewardHistoryEntry,
+    RewardHistoryResponse,
+)
 from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE
 from ajet.tuner_lib.experimental.interchange_utils import (
     SyncTrainConfigRequest,
@@ -63,6 +67,14 @@ def register_enable_swarm_mode_routes(
     if "current_batch_rollout_pool_information" not in shared_mem_dict:
         shared_mem_dict["current_batch_rollout_pool_information"] = CurrentBatchRolloutPoolInformation()
 
+    # Initialize reward history storage for visualization
+    if "reward_history" not in shared_mem_dict:
+        shared_mem_dict["reward_history"] = []  # List of RewardHistoryEntry dicts
+
+    # Initialize reward accumulator for collecting rewards of current global step
+    if "current_rewards" not in shared_mem_dict:
+        shared_mem_dict["current_rewards"] = []  # [rewards...]
+
     # ------------------------------------------------------------------------------------------------
     # ------ Recycle claimed episodes that client failed to complete in (promised) time --------------
     # --------------------------------- claimed -> unclaimed ----------------------------------------
@@ -166,6 +178,35 @@ def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_l
         if episode_uuid in shared_mem_dict["unclaimed_episodes"]:
             shared_mem_dict["unclaimed_episodes"].remove(episode_uuid)
 
+    # --------------------------------------------------------------------------------------
+    # -------------------------- reward history management ---------------------------------
+    # --------------------------------------------------------------------------------------
+
+    def _finalize_reward_history_for_step(global_step, shared_mem_dict, shared_mem_dict_lock):
+        """Finalize reward statistics for a given global step and add to reward_history."""
+        import numpy as np
+
+        rewards = shared_mem_dict.get("current_rewards", [])
+        if rewards:
+            rewards = list(rewards)  # Convert proxy to list if needed
+            mean_reward = float(np.mean(rewards))
+            std_reward = float(np.std(rewards))
+
+            history = shared_mem_dict.get("reward_history", [])
+            history = list(history)  # Convert proxy to list if needed
+
+            entry = RewardHistoryEntry(
+                global_step=global_step,
+                mean_reward=mean_reward,
+                std_reward=std_reward,
+                timestamp=time.time(),
+            )
+            history.append(entry.model_dump())
+            shared_mem_dict["reward_history"] = history
+
+        # Clear current rewards for next step
+        shared_mem_dict["current_rewards"] = []
+
     # --------------------------------------------------------------------------------------
     # -------------------------- return workflow output ------------------------------------
     # --------------------------------------------------------------------------------------
@@ -272,6 +313,10 @@ def _clean_up_engine_status(shared_mem_dict_lock, shared_mem_dict):
             shared_mem_dict["unclaimed_episodes"] = []
             logger.info(f"[_clean_up_engine_status] Cleared {num_unclaimed} unclaimed episodes")
 
+        # clear reward tracking
+        shared_mem_dict["current_rewards"] = []
+        shared_mem_dict["reward_history"] = []
+
     # --------------------------------------------------------------------------------------
     # -------------------------- fastapi routes --------------------------------------------
     # --------------------------------------------------------------------------------------
@@ -446,7 +491,12 @@ async def update_engine_status(req: UpdateEngineStatusRequest):
         engine_status_detail = req.engine_status_detail
         global_step = req.global_step
         if global_step is not None:
+            previous_global_step = shared_mem_dict.get("global_step", None)
             shared_mem_dict["global_step"] = global_step
+            # When global_step changes, finalize reward statistics for the previous step
+            if previous_global_step is not None and previous_global_step != global_step:
+                _finalize_reward_history_for_step(previous_global_step, shared_mem_dict, shared_mem_dict_lock)
+
         if engine_status_detail is not None:
             shared_mem_dict["engine_status_detail"] = engine_status_detail
         logger.info(f"[update_engine_status] Engine status set to {req.engine_status}")
@@ -636,6 +686,21 @@ async def end_episode(req: EndEpisodeRequest):
                 shared_mem_dict_lock,
             )
 
+            # Record reward to current_rewards
+            if workflow_output.reward is not None:
+                reward_value = workflow_output.reward
+                # Handle both single reward and list of rewards
+                if isinstance(reward_value, list):
+                    rewards_to_record = reward_value
+                else:
+                    rewards_to_record = [reward_value]
+
+                with shared_mem_dict_lock:
+                    current_rewards = shared_mem_dict.get("current_rewards", [])
+                    current_rewards = list(current_rewards)  # Convert proxy to list if needed
+                    current_rewards.extend(rewards_to_record)
+                    shared_mem_dict["current_rewards"] = current_rewards
+
         elif episode_type == "eval":
             if engine_status in ["ENGINE.ROLLING"]:
                 await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock)
@@ -779,6 +844,20 @@ async def get_current_batch_rollout_pool_information():
             logger.error(f"Error getting current batch rollout pool information: {e}")
             return CurrentBatchRolloutPoolInformation()
 
+    # --------------------------------------------------------------------
+    # ------------ get reward history for visualization ------------------
+    # --------------------------------------------------------------------
+    @app.get("/get_reward_history", response_model=RewardHistoryResponse)
+    async def get_reward_history():
+        """Get the reward history for visualization (reward curves)."""
+        try:
+            history = shared_mem_dict.get("reward_history", [])
+            entries = [RewardHistoryEntry(**entry) for entry in history]
+            return RewardHistoryResponse(history=entries)
+        except Exception as e:
+            logger.error(f"Error getting reward history: {e}")
+            return RewardHistoryResponse(history=[])
+
     # --------------------------------------------------------------------
     # ------------ bring engine back to ENGINE.OFFLINE -------------------
     # --------------------------------------------------------------------
diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py
index 9e6e284..c13b6aa 100644
--- a/ajet/utils/config_utils.py
+++ b/ajet/utils/config_utils.py
@@ -98,7 +98,7 @@ def _dive_to_set_value(config, dotted_key, value):
         sub_config[keys[-1]] = value
 
 
-def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone):
+def align_parameters(from_config_fp, to_config_fp, convertion_json_fp, backbone):
     """Align configuration values based on a conversion map.
 
     Parameters
@@ -107,7 +107,7 @@ def align_parameters(from_config_fp, to_config_fp, convertion_json_fg, backbone)
         Source YAML path to read values from.
     to_config_fp : str
         Destination YAML path that is updated in place.
-    convertion_json_fg : str
+    convertion_json_fp : str
         JSON path mapping dotted keys between configs.
     backbone : str
         Backbone identifier used for framework-specific alignment.
@@ -121,7 +121,7 @@
     # read convertion json
     import json
 
-    with open(convertion_json_fg, "r", encoding="utf-8") as file:
+    with open(convertion_json_fp, "r", encoding="utf-8") as file:
         convertion_json = json.load(file)
 
     logger.success("----------------------------------------------------")
diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py
index f8003b1..ffbb7d5 100644
--- a/ajet/utils/swarm_overwatch.py
+++ b/ajet/utils/swarm_overwatch.py
@@ -17,7 +17,10 @@
 from rich.text import Text
 from loguru import logger
 
-from ajet.tuner_lib.experimental.swarm_overwatch_utils import CurrentBatchRolloutPoolInformation
+from ajet.tuner_lib.experimental.swarm_overwatch_utils import (
+    CurrentBatchRolloutPoolInformation,
+    RewardHistoryResponse,
+)
 
 
 class SwarmOverwatch:
@@ -56,6 +59,20 @@ def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]:
             # logger.error(f"Failed to fetch pool info: {e}")
             return None
 
+    def fetch_reward_history(self) -> Optional[RewardHistoryResponse]:
+        """Fetch reward history from server for visualization"""
+        try:
+            response = self._httpx_client.get(
+                f"{self.server_url}/get_reward_history",
+                timeout=5.0,
+            )
+            response.raise_for_status()
+            data = RewardHistoryResponse.model_validate(response.json())
+            return data
+        except Exception as e:
+            logger.error(f"Failed to fetch reward history: {e}")
+            return None
+
     def create_header(
         self, info: Optional[CurrentBatchRolloutPoolInformation] = None
    ) -> Panel:
@@ -450,6 +467,141 @@ def create_dashboard(
 
         return layout
 
+    def display_reward_curve(self):
+        """Display ASCII reward curve in terminal"""
+        self.console.clear()
+
+        # Fetch reward history
+        history = self.fetch_reward_history()
+        if history is None or not history.history:
+            self.console.print("[bold yellow]No reward history available yet.[/bold yellow]")
+            self.console.print("[dim]Reward history is recorded when training completes batches with rewards.[/dim]")
+            self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
+            input()
+            return
+
+        # Get terminal size
+        terminal_width = self.console.width or 80
+        terminal_height = self.console.height or 24
+
+        # Reserve space for header, labels, and footer
+        chart_width = min(terminal_width - 15, 120)  # Reserve space for y-axis labels
+        chart_height = min(terminal_height - 10, 30)  # Reserve space for header and x-axis
+
+        # Extract data
+        global_steps = [entry.global_step for entry in history.history]
+        mean_rewards = [entry.mean_reward for entry in history.history]
+
+        # Calculate y-axis range with padding
+        y_min = min(mean_rewards)
+        y_max = max(mean_rewards)
+        y_range = y_max - y_min
+        if y_range == 0:
+            y_range = 1.0  # Avoid division by zero
+            y_min -= 0.5
+            y_max += 0.5
+        else:
+            # Add 10% padding
+            y_min -= y_range * 0.1
+            y_max += y_range * 0.1
+            y_range = y_max - y_min
+
+        # Calculate x-axis range
+        x_min = min(global_steps)
+        x_max = max(global_steps)
+        x_range = x_max - x_min
+        if x_range == 0:
+            x_range = 1
+
+        # Create the chart grid
+        chart = [[' ' for _ in range(chart_width)] for _ in range(chart_height)]
+
+        # Plot the data points
+        for i, (step, reward) in enumerate(zip(global_steps, mean_rewards)):
+            # Map to chart coordinates
+            x = int((step - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
+            y = int((reward - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
+
+            # Invert y because terminal coordinates go top-down
+            y = chart_height - 1 - y
+
+            # Clamp to valid range
+            x = max(0, min(chart_width - 1, x))
+            y = max(0, min(chart_height - 1, y))
+
+            # Draw point
+            chart[y][x] = '*'
+
+        # Connect points with lines if there are multiple points
+        if len(global_steps) > 1:
+            for i in range(len(global_steps) - 1):
+                step1, reward1 = global_steps[i], mean_rewards[i]
+                step2, reward2 = global_steps[i + 1], mean_rewards[i + 1]
+
+                x1 = int((step1 - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
+                y1 = int((reward1 - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
+                x2 = int((step2 - x_min) / x_range * (chart_width - 1)) if x_range > 0 else 0
+                y2 = int((reward2 - y_min) / y_range * (chart_height - 1)) if y_range > 0 else 0
+
+                y1 = chart_height - 1 - y1
+                y2 = chart_height - 1 - y2
+
+                # Simple line drawing between points
+                steps_between = max(abs(x2 - x1), abs(y2 - y1))
+                if steps_between > 0:
+                    for s in range(1, steps_between):
+                        t = s / steps_between
+                        x = int(x1 + t * (x2 - x1))
+                        y = int(y1 + t * (y2 - y1))
+                        x = max(0, min(chart_width - 1, x))
+                        y = max(0, min(chart_height - 1, y))
+                        if chart[y][x] == ' ':
+                            chart[y][x] = '.'
+
+        # Build the output
+        output = Text()
+        output.append("\n Reward Curve (Mean Reward vs Global Step)\n", style="bold cyan")
+        output.append(f" Server: {self.server_url}\n", style="dim")
+        output.append(f" Data points: {len(global_steps)}\n\n", style="dim")
+
+        # Draw y-axis labels and chart
+        y_labels = []
+        for i in range(chart_height):
+            y_val = y_max - (i / (chart_height - 1)) * y_range if chart_height > 1 else y_max
+            y_labels.append(y_val)
+
+        for i, row in enumerate(chart):
+            # Y-axis label (only show a few)
+            if i == 0 or i == chart_height - 1 or i == chart_height // 2:
+                label = f"{y_labels[i]:8.3f} |"
+            else:
+                label = " |"
+            output.append(label, style="dim")
+            output.append(''.join(row), style="green")
+            output.append("\n")
+
+        # X-axis
+        output.append(" +" + "-" * chart_width + "\n", style="dim")
+
+        # X-axis labels
+        x_label_line = " "
+        x_label_line += f"{x_min:<{chart_width // 3}}"
+        mid_step = x_min + x_range // 2
+        x_label_line += f"{mid_step:^{chart_width // 3}}"
+        x_label_line += f"{x_max:>{chart_width // 3}}"
+        output.append(x_label_line[:chart_width + 10] + "\n", style="dim")
+        output.append(" " + " " * (chart_width // 2 - 5) + "Global Step\n", style="dim cyan")
+
+        # Statistics
+        output.append("\n Statistics:\n", style="bold yellow")
+        output.append(f" Latest Global Step: {global_steps[-1]}\n", style="green")
+        output.append(f" Latest Mean Reward: {mean_rewards[-1]:.4f}\n", style="green")
+        output.append(f" Min Mean Reward: {min(mean_rewards):.4f} (step {global_steps[mean_rewards.index(min(mean_rewards))]})\n", style="cyan")
+        output.append(f" Max Mean Reward: {max(mean_rewards):.4f} (step {global_steps[mean_rewards.index(max(mean_rewards))]})\n", style="cyan")
+
+        self.console.print(output)
+        self.console.print("\n[dim]Press Enter to return to menu...[/dim]")
+        input()
 
     def display_latest_llm_call(self):
         while True:
@@ -515,6 +667,7 @@ def choose_run(self) -> str:
             self.console.print("\n[bold]Choose action:[/bold]")
             self.console.print(" [bold cyan]o[/bold cyan] - Return to overwatch")
             self.console.print(" [bold cyan]t[/bold cyan] - Show replay_latest_llm_call")
+            self.console.print(" [bold cyan]c[/bold cyan] - Show reward curve")
             self.console.print(" [bold cyan]ctrl+c[/bold cyan] - Exit")
 
             choice = input("\n> ").strip().lower()
@@ -526,8 +679,12 @@
                 mode = "replay_latest_llm_call"
                 self.console.clear()
                 continue
+            elif choice == "c":
+                self.display_reward_curve()
+                self.console.clear()
+                continue
             else:
-                self.console.print("[yellow]Invalid choice. Please enter 'o' or 't'.[/yellow]")
+                self.console.print("[yellow]Invalid choice. Please enter 'o', 't', or 'c'.[/yellow]")
 
     def run(self):
         """Start the monitoring interface"""