-
Notifications
You must be signed in to change notification settings - Fork 28
Hosted RL Entrypoint #256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Hosted RL Entrypoint #256
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
d4229a6
Implement commands for hosted RL
manveerxyz 65b8ad4
Hosted RL
manveerxyz 7b3b945
Allow for user to use just
manveerxyz 89079df
Support tomls on prime rl cmd
manveerxyz 7e0b4e1
Minor fix
manveerxyz deeb088
Cleanup references to RFT
manveerxyz 63b2182
Minor improvements
manveerxyz a3e1cd9
Fix ruff
manveerxyz 1dc8a75
Match post rft run schema to new backend
manveerxyz 2cdfe27
Refactor delete_run method to remove return value and simplify succes…
manveerxyz 5ab66bd
Fix/prime rl list (#267)
JohannesHa 084b563
Add support for run_config
manveerxyz f553271
feat: add eval_config support to RL API client (#271)
JannikSt 92a9956
prime registry support (#215)
kcoopermiller 27be637
Chore/bump version 0.5.8 (#270)
JannikSt 894d04b
Fix: Update eval sample field (#265)
d42me cdefef5
Fix: Remove trailing comma from API token URL (#273)
samsja 84abaf0
resolve conflicts
JannikSt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| """Hosted RL (Reinforcement Learning) API client.""" | ||
|
|
||
| from datetime import datetime | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field | ||
|
|
||
| from prime_cli.core import APIClient, APIError | ||
|
|
||
|
|
||
| class RLModel(BaseModel): | ||
| """Model available for RL training.""" | ||
|
|
||
| name: str = Field(..., description="Model name") | ||
|
|
||
| model_config = ConfigDict(populate_by_name=True) | ||
|
|
||
|
|
||
| class RLRun(BaseModel): | ||
| """RL Training Run.""" | ||
|
|
||
| id: str = Field(..., description="Run ID") | ||
| name: Optional[str] = Field(None, description="Run name") | ||
| user_id: str = Field(..., alias="userId") | ||
| team_id: Optional[str] = Field(None, alias="teamId") | ||
| cluster_id: str = Field(..., alias="rftClusterId") | ||
| status: str = Field(..., description="Run status") | ||
|
|
||
| # Training configuration | ||
| rollouts_per_example: int = Field(..., alias="rolloutsPerExample") | ||
| seq_len: int = Field(..., alias="seqLen") | ||
| max_steps: int = Field(..., alias="maxSteps") | ||
| base_model: str = Field(..., alias="baseModel") | ||
| environments: List[Dict[str, Any]] = Field(default_factory=list) | ||
| run_config: Optional[Dict[str, Any]] = Field(None, alias="runConfig") | ||
| eval_config: Optional[Dict[str, Any]] = Field(None, alias="evalConfig") | ||
|
|
||
| # Monitoring | ||
| wandb_entity: Optional[str] = Field(None, alias="wandbEntity") | ||
| wandb_project: Optional[str] = Field(None, alias="wandbProject") | ||
| wandb_run_name: Optional[str] = Field(None, alias="wandbRunName") | ||
|
|
||
| # Timestamps | ||
| started_at: Optional[datetime] = Field(None, alias="startedAt") | ||
| completed_at: Optional[datetime] = Field(None, alias="completedAt") | ||
| error_message: Optional[str] = Field(None, alias="errorMessage") | ||
| created_at: datetime = Field(..., alias="createdAt") | ||
| updated_at: datetime = Field(..., alias="updatedAt") | ||
|
|
||
| model_config = ConfigDict(populate_by_name=True) | ||
|
|
||
|
|
||
| class RLClient: | ||
| """Client for hosted RL API.""" | ||
|
|
||
| def __init__(self, client: APIClient) -> None: | ||
| self.client = client | ||
|
|
||
| def list_models(self) -> List[RLModel]: | ||
| """List available models for RL training.""" | ||
| try: | ||
| response = self.client.get("/rft/models") | ||
| models_data = response.get("models", []) | ||
| return [RLModel.model_validate(model) for model in models_data] | ||
| except Exception as e: | ||
| if hasattr(e, "response") and hasattr(e.response, "text"): | ||
| raise APIError(f"Failed to list RL models: {e.response.text}") | ||
| raise APIError(f"Failed to list RL models: {str(e)}") | ||
|
|
||
| def list_runs(self, team_id: Optional[str] = None) -> List[RLRun]: | ||
| """List RL training runs for the authenticated user.""" | ||
| try: | ||
| params = {} | ||
| if team_id: | ||
| params["team_id"] = team_id | ||
| response = self.client.get("/rft/runs", params=params if params else None) | ||
| runs_data = response.get("runs", []) | ||
| return [RLRun.model_validate(run) for run in runs_data] | ||
| except Exception as e: | ||
| if hasattr(e, "response") and hasattr(e.response, "text"): | ||
| raise APIError(f"Failed to list RL runs: {e.response.text}") | ||
| raise APIError(f"Failed to list RL runs: {str(e)}") | ||
|
|
||
| def create_run( | ||
| self, | ||
| model_name: str, | ||
| environments: List[Dict[str, Any]], | ||
| rollouts_per_example: int = 8, | ||
| seq_len: int = 4096, | ||
| max_steps: int = 100, | ||
| name: Optional[str] = None, | ||
| wandb_entity: Optional[str] = None, | ||
| wandb_project: Optional[str] = None, | ||
| wandb_run_name: Optional[str] = None, | ||
| wandb_api_key: Optional[str] = None, | ||
| team_id: Optional[str] = None, | ||
| run_config: Optional[Dict[str, Any]] = None, | ||
| eval_config: Optional[Dict[str, Any]] = None, | ||
| ) -> RLRun: | ||
| """Create a new RL training run.""" | ||
| try: | ||
| secrets: List[Dict[str, str]] = [] | ||
|
|
||
| # Add W&B API key as a secret if provided | ||
| if wandb_api_key: | ||
| secrets.append({"key": "WANDB_API_KEY", "value": wandb_api_key}) | ||
|
|
||
| payload: Dict[str, Any] = { | ||
| "model": {"name": model_name}, | ||
| "environments": environments, | ||
| "rollouts_per_example": rollouts_per_example, | ||
| "seq_len": seq_len, | ||
| "max_steps": max_steps, | ||
| "secrets": secrets, | ||
| } | ||
|
|
||
| if name: | ||
| payload["name"] = name | ||
|
|
||
| # Add monitoring config if W&B is specified | ||
| if wandb_entity or wandb_project: | ||
| payload["monitoring"] = { | ||
| "wandb": { | ||
| "entity": wandb_entity, | ||
| "project": wandb_project, | ||
| "name": wandb_run_name, | ||
| } | ||
| } | ||
|
|
||
| if team_id: | ||
| payload["team_id"] = team_id | ||
|
|
||
| if run_config: | ||
| payload["run_config"] = run_config | ||
|
|
||
| if eval_config: | ||
| payload["eval"] = eval_config | ||
|
|
||
| response = self.client.post("/rft/runs", json=payload) | ||
| return RLRun.model_validate(response.get("run")) | ||
| except Exception as e: | ||
| if hasattr(e, "response") and hasattr(e.response, "text"): | ||
| raise APIError(f"Failed to create RL run: {e.response.text}") | ||
| raise APIError(f"Failed to create RL run: {str(e)}") | ||
|
|
||
| def stop_run(self, run_id: str) -> RLRun: | ||
| """Stop a running RL training run.""" | ||
| try: | ||
| response = self.client.request("PUT", f"/rft/runs/{run_id}/stop") | ||
| return RLRun.model_validate(response.get("run")) | ||
| except Exception as e: | ||
| if hasattr(e, "response") and hasattr(e.response, "text"): | ||
| raise APIError(f"Failed to stop RL run: {e.response.text}") | ||
| raise APIError(f"Failed to stop RL run: {str(e)}") | ||
|
|
||
| def delete_run(self, run_id: str) -> None: | ||
| """Delete an RL run.""" | ||
| try: | ||
| self.client.delete(f"/rft/runs/{run_id}") | ||
| except Exception as e: | ||
| if hasattr(e, "response") and hasattr(e.response, "text"): | ||
| raise APIError(f"Failed to delete RL run: {e.response.text}") | ||
| raise APIError(f"Failed to delete RL run: {str(e)}") | ||
|
|
||
| def get_logs(self, run_id: str, tail_lines: int = 1000) -> str: | ||
| """Get logs for an RL run.""" | ||
| try: | ||
| response = self.client.get( | ||
| f"/rft/runs/{run_id}/logs", params={"tail_lines": tail_lines} | ||
| ) | ||
| return response.get("logs", "") | ||
| except Exception as e: | ||
| if hasattr(e, "response") and hasattr(e.response, "text"): | ||
| raise APIError(f"Failed to get RL run logs: {e.response.text}") | ||
| raise APIError(f"Failed to get RL run logs: {str(e)}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.