diff --git a/docs/en/example_train_multi_model.md b/docs/en/example_train_multi_model.md index a062ac0..e55d49d 100644 --- a/docs/en/example_train_multi_model.md +++ b/docs/en/example_train_multi_model.md @@ -90,6 +90,9 @@ graph TB C -->|end_episode + reward_14b| S2 ``` +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png) + + **Architecture Explanation**: - **Swarm Server 1 (Port 10086)**: Hosts the 7B model, responsible for Agent 1 and Agent 3's inference and training diff --git a/docs/en/example_train_multi_model.zh.md b/docs/en/example_train_multi_model.zh.md index 772a84f..8e74c6b 100644 --- a/docs/en/example_train_multi_model.zh.md +++ b/docs/en/example_train_multi_model.zh.md @@ -88,6 +88,9 @@ graph TB C -->|end_episode + reward_14b| S2 ``` +![alt text](https://img.alicdn.com/imgextra/i3/O1CN01vHfNt41LRcQeDMjE4_!!6000000001296-2-tps-1408-768.png) + + **架构说明**: - **Swarm Server 1 (端口 10086)**:承载 7B 模型,负责 Agent 1 和 Agent 3 的推理与训练 @@ -176,6 +179,8 @@ sequenceDiagram 4. 将各自的奖励汇报给对应的 Swarm Server 5. 
两个 Server 独立执行策略梯度更新 + + ## 训练曲线 ![alt text](https://img.alicdn.com/imgextra/i2/O1CN0161wtDk1zZwFmIX15x_!!6000000006729-2-tps-2978-1413.png) diff --git a/tutorial/opencode_build_openclaw_agent/cheatsheet.md b/tutorial/opencode_build_openclaw_agent/cheatsheet.md new file mode 100644 index 0000000..0d79b05 --- /dev/null +++ b/tutorial/opencode_build_openclaw_agent/cheatsheet.md @@ -0,0 +1,47 @@ +# OpenClaw Reward Cheatsheet + +## Run the test + +```bash +cd agentjet/tutorial/opencode_build_openclaw_agent + +# pointwise (default) +DASHSCOPE_API_KEY=your_key python test_reward.py + +# listwise +REWARD_MODE=listwise DASHSCOPE_API_KEY=your_key python test_reward.py +``` + +## Run the training endpoint + +```bash +# pointwise (default) +AJET_SWARM_URL=http://localhost:10086 \ +DASHSCOPE_API_KEY=your_key \ +REWARD_MODE=pointwise \ +python fake_vllm_endpoint.py + +# listwise +AJET_SWARM_URL=http://localhost:10086 \ +DASHSCOPE_API_KEY=your_key \ +REWARD_MODE=listwise \ +python fake_vllm_endpoint.py +``` + +## Reward modes + +| Mode | Description | +|------|-------------| +| `pointwise` | Each response scored independently (0.0–1.0) | +| `listwise` | All responses ranked together (best=1.0, worst=0.0) | + +## Environment variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `REWARD_MODE` | `pointwise` | `pointwise` or `listwise` | +| `DASHSCOPE_API_KEY` | — | DashScope API key (required) | +| `JUDGE_MODEL` | `qwen-plus` | Judge model name | +| `JUDGE_BASE_URL` | DashScope endpoint | Judge model base URL | +| `AJET_SWARM_URL` | `http://localhost:10086` | Swarm server URL | +| `NUM_REPEAT` | `4` | GRPO N (responses per query) | diff --git a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py index 0831cd2..e73cc80 100644 --- a/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py +++ b/tutorial/opencode_build_openclaw_agent/fake_vllm_endpoint.py @@ -25,7 +25,7 @@ 
import sys sys.path.insert(0, os.path.dirname(__file__)) -from on_user_submit_new_requests import on_user_submit_new_requests +from on_user_submit_new_requests import on_user_submit_new_requests, get_query_history from on_compute_relative_reward import on_compute_relative_reward # Configuration @@ -91,6 +91,14 @@ async def proxy_chat_completion(base_url: str, api_key: str, request: Request, i json_data = await request.json() json_data["stream"] = is_stream + # Remove fields not supported by vLLM to avoid warnings + UNSUPPORTED_FIELDS = {"strict", "store"} + for field in UNSUPPORTED_FIELDS: + json_data.pop(field, None) + # Also remove 'strict' from response_format if present + if "response_format" in json_data and isinstance(json_data["response_format"], dict): + json_data["response_format"].pop("strict", None) + async with httpx.AsyncClient(timeout=300.0) as client: resp = await client.post(f"{base_url}/chat/completions", json=json_data, headers=headers) resp.raise_for_status() @@ -200,7 +208,7 @@ async def handle_one2many_request(request: Request, request_id: str) -> Dict | L valid_results = await run_all_episodes(request, is_stream) all_answers = [extract_assistant_message(r.response) for r in valid_results] - rewards = await on_compute_relative_reward(valid_results, all_answers) + rewards = await on_compute_relative_reward(valid_results, all_answers, question=user_query) await finalize_episodes(task, valid_results, rewards) @@ -259,7 +267,7 @@ async def health_check(): @app.get("/requests") async def get_requests(): """Get all recorded user requests.""" - return {"requests": USER_REQUEST_RECORD} + return {"requests": get_query_history()} if __name__ == "__main__": diff --git a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py index 5bafd2f..53894a9 100644 --- a/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py +++ 
b/tutorial/opencode_build_openclaw_agent/on_compute_relative_reward.py @@ -1,26 +1,55 @@ # -*- coding: utf-8 -*- -"""Compute relative rewards based on extraversion personality alignment using OpenJudge.""" +"""Compute relative rewards based on extraversion, relevance, diversity, and repetition quality.""" import os +import collections from typing import List, Dict + +from loguru import logger from beast_logger import print_listofdict from openjudge.graders.base_grader import GraderMode, GraderScore, GraderRank from openjudge.graders.llm_grader import LLMGrader +from openjudge.graders.common.relevance import RelevanceGrader +from openjudge.graders.format.ngram_repetition_penalty import NgramRepetitionPenaltyGrader from openjudge.models import OpenAIChatModel +try: + from ajet.utils.compute_madness import has_repeat +except ImportError: + # Fallback: when running outside the full ajet package (e.g. tests), + # resolve relative to the repo root. + import sys as _sys + from pathlib import Path as _Path + _repo_root = str(_Path(__file__).resolve().parents[2]) + if _repo_root not in _sys.path: + _sys.path.insert(0, _repo_root) + from ajet.utils.compute_madness import has_repeat +# --------------------------------------------------------------------------- # Configuration -REWARD_MODE = os.getenv("REWARD_MODE", "pointwise") # Options: pointwise, listwise +# --------------------------------------------------------------------------- +REWARD_MODE = os.getenv("REWARD_MODE", "pointwise") # pointwise | listwise API_KEY = os.getenv("DASHSCOPE_API_KEY", "sk-xxx") BASE_URL = os.getenv("JUDGE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") JUDGE_MODEL = os.getenv("JUDGE_MODEL", "qwen-plus") -# OpenJudge grader setup +# Reward weights (must sum to 1.0) +W_EXTRAVERSION = float(os.getenv("W_EXTRAVERSION", "0.5")) +W_RELEVANCE = float(os.getenv("W_RELEVANCE", "0.3")) +W_DIVERSITY = float(os.getenv("W_DIVERSITY", "0.2")) + +# Cross-request history buffer size 
+HISTORY_MAX_SIZE = int(os.getenv("DIVERSITY_HISTORY_SIZE", "25")) + +# --------------------------------------------------------------------------- +# Shared model & graders +# --------------------------------------------------------------------------- judge_model = OpenAIChatModel( model=JUDGE_MODEL, api_key=API_KEY, base_url=BASE_URL, ) +# --- Extraversion grader (custom LLM prompt) --- EXTRAVERSION_PROMPT = """You are evaluating responses for extraversion personality traits. Extraversion characteristics include: @@ -41,6 +70,153 @@ - "score": float between 0.0 and 1.0 - "reason": brief explanation""" +pointwise_grader = LLMGrader( + name="extraversion_pointwise", + mode=GraderMode.POINTWISE, + description="Evaluate extraversion traits", + model=judge_model, + template=EXTRAVERSION_PROMPT, +) + +# --- Relevance grader (built-in OpenJudge) --- +relevance_grader = RelevanceGrader(model=judge_model) + +# --- Repetition penalty grader (deterministic, no LLM) --- +# Detects n-gram repetition within a single response. +# Returns score in [0, 1] where 1 = no repetition, 0 = heavily repetitive. 
+repetition_grader = NgramRepetitionPenaltyGrader( + n=4, # 4-gram detection + penalty_threshold=0.15, # trigger penalty when >15% of n-grams are repeated + use_soft_penalty=True, # gradual penalty rather than cliff + max_penalty=-1.0, # worst case: score becomes 0 + min_scaling=0.0, # at max penalty, multiplier goes to 0 +) + +# --------------------------------------------------------------------------- +# In-process history of recent responses (for cross-request diversity) +# --------------------------------------------------------------------------- +_response_history: List[str] = [] + + +def record_responses_to_history(contents: List[str]) -> None: + """Append new responses to the rolling history buffer.""" + _response_history.extend(contents) + # Trim to keep only the most recent entries + while len(_response_history) > HISTORY_MAX_SIZE: + _response_history.pop(0) + + +# --------------------------------------------------------------------------- +# Diversity: n-gram overlap (fast, deterministic, no LLM needed) +# --------------------------------------------------------------------------- +def _get_ngrams(text: str, n: int = 3) -> collections.Counter: + """Extract word-level n-grams from text (tokens are whitespace-split, lowercased words).""" + tokens = text.lower().split() + if len(tokens) < n: + return collections.Counter(tokens) + return collections.Counter( + tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1) + ) + + +def _ngram_overlap(text_a: str, text_b: str, n: int = 3) -> float: + """Compute Jaccard overlap of n-grams between two texts. 
Returns 0-1.""" + ngrams_a = _get_ngrams(text_a, n) + ngrams_b = _get_ngrams(text_b, n) + if not ngrams_a or not ngrams_b: + return 0.0 + intersection = sum((ngrams_a & ngrams_b).values()) + union = sum((ngrams_a | ngrams_b).values()) + return intersection / union if union > 0 else 0.0 + + +def compute_diversity_scores(contents: List[str], history: List[str]) -> List[float]: + """ + Compute a diversity score for each response (0 = duplicate, 1 = fully unique). + + Two components: + 1. Within-batch: max pairwise n-gram overlap with other responses in the batch + 2. Cross-request: max n-gram overlap with any response in the history buffer + + Final diversity_score = 1 - max(within_batch_overlap, cross_request_overlap) + """ + n = len(contents) + scores = [] + for i, content_i in enumerate(contents): + # Within-batch overlap: max overlap with any other response in this batch + if n > 1: + batch_overlaps = [ + _ngram_overlap(content_i, contents[j]) + for j in range(n) + if j != i + ] + within_batch = max(batch_overlaps) # worst-case overlap within batch + else: + within_batch = 0.0 + + # Cross-request overlap: max overlap with any historical response + if history: + cross_request = max(_ngram_overlap(content_i, h) for h in history) + else: + cross_request = 0.0 + + overlap = max(within_batch, cross_request) + scores.append(1.0 - overlap) + + return scores + + +# --------------------------------------------------------------------------- +# Quality gate: repetition & degeneration detection (deterministic) +# --------------------------------------------------------------------------- +async def compute_quality_scores(contents: List[str]) -> List[float]: + """ + Compute a quality multiplier for each response (0 = degenerate, 1 = clean). + + Combines two signals: + 1. NgramRepetitionPenaltyGrader — detects looping/repeated n-gram blocks + 2. 
has_repeat (from compute_madness) — catches word- and + character-level repetition, plus a manual special-token leak check + + Returns a score in [0, 1] that will be used as a *multiplier* on the + composite reward, so degenerate outputs get crushed to near-zero. + """ + scores = [] + for content in contents: + # --- Signal 1: n-gram repetition (OpenJudge) --- + try: + rep_result = await repetition_grader.aevaluate(response=content) + # NgramRepetitionPenaltyGrader returns penalty in [-1, 0]: + # 0 = no repetition, -1 = max repetition + # Convert to quality: add 1 → [0, 1] + ngram_penalty = rep_result.score if isinstance(rep_result, GraderScore) else 0.0 + ngram_score = 1.0 + ngram_penalty + except Exception as e: + logger.warning(f"NgramRepetitionPenaltyGrader failed: {e}") + ngram_score = 1.0 + + # --- Signal 2: string madness (char-level degeneration) --- + # Only check for word/char repetition and special token leaks. + # The non-ASCII check is deliberately omitted (accented + # characters like é are legitimate); repetition is checked manually. + madness_score = 1.0 # assume clean + if "<|im_start|>" in content: + madness_score = 0.0 + elif has_repeat(content.split(), remember_n_words=5, patience_max=10): + madness_score = 0.0 + elif has_repeat(content, remember_n_words=4, patience_max=200): + madness_score = 0.0 + + # Combined quality: take the minimum (strictest gate wins) + quality = max(0.0, min(1.0, min(ngram_score, madness_score))) + scores.append(quality) + + return scores + + +# --------------------------------------------------------------------------- +# Extraversion scoring (pointwise / listwise) +# --------------------------------------------------------------------------- def build_listwise_template(n: int) -> str: """Build a listwise prompt template for n responses.""" answers_block = "\n".join([f"{i+1}. 
{{answer_{i+1}}}" for i in range(n)]) @@ -62,33 +238,20 @@ def build_listwise_template(n: int) -> str: - "rank": list of integers (1-indexed) ordered from most to least extraverted, e.g. [2, 1, 3] - "reason": brief explanation of the ranking""" -pointwise_grader = LLMGrader( - name="extraversion_pointwise", - mode=GraderMode.POINTWISE, - description="Evaluate extraversion traits", - model=judge_model, - template=EXTRAVERSION_PROMPT, -) - -async def compute_pointwise_rewards(question: str, all_answers: List[Dict]) -> List[float]: - """Compute rewards using OpenJudge pointwise grading.""" +async def compute_pointwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]: + """Compute extraversion scores using pointwise grading.""" scores = [] for answer in all_answers: content = answer.get("content", "") result = await pointwise_grader.aevaluate(question=question, response=content) - if isinstance(result, GraderScore): - # score is already normalized 0-1 by OpenJudge - score = result.score - else: - score = 0.0 + score = result.score if isinstance(result, GraderScore) else 0.0 scores.append(score) - answer["reward"] = score return scores -async def compute_listwise_rewards(question: str, all_answers: List[Dict]) -> List[float]: - """Compute rewards using OpenJudge listwise ranking.""" +async def compute_listwise_extraversion(question: str, all_answers: List[Dict]) -> List[float]: + """Compute extraversion scores using listwise ranking.""" n = len(all_answers) template = build_listwise_template(n) grader = LLMGrader( @@ -106,24 +269,93 @@ async def compute_listwise_rewards(question: str, all_answers: List[Dict]) -> Li scores = [0.0] * n if isinstance(result, GraderRank): - # rank is a list of 1-indexed positions ordered best to worst - # convert to reward: rank 1 (best) -> 1.0, rank n (worst) -> 0.0 for position, idx in enumerate(result.rank): scores[idx - 1] = 1.0 - (position / (n - 1)) if n > 1 else 0.5 + return scores + - for answer, score in 
zip(all_answers, scores): - answer["reward"] = score +# --------------------------------------------------------------------------- +# Relevance scoring (built-in OpenJudge RelevanceGrader, score 1-5 → 0-1) +# --------------------------------------------------------------------------- +async def compute_relevance_scores(question: str, all_answers: List[Dict]) -> List[float]: + """Score how relevant each response is to the question. Returns 0-1.""" + scores = [] + for answer in all_answers: + content = answer.get("content", "") + result = await relevance_grader.aevaluate(query=question, response=content) + if isinstance(result, GraderScore): + # RelevanceGrader returns 1-5; normalise to 0-1 + score = (result.score - 1.0) / 4.0 + else: + score = 0.0 + scores.append(max(0.0, min(1.0, score))) return scores -async def on_compute_relative_reward(valid_results: List, all_answers: List[Dict]) -> List[float]: - """Compute relative rewards for extraversion alignment.""" - question = valid_results[0].get("question", "") if valid_results else "" +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- +async def on_compute_relative_reward( + valid_results: List, + all_answers: List[Dict], + question: str = "", +) -> List[float]: + """ + Compute composite rewards combining extraversion, relevance, diversity, + and a quality gate for repetition/degeneration. + + Final reward = quality * (W_EXTRAVERSION * extraversion + + W_RELEVANCE * relevance + + W_DIVERSITY * diversity) + The quality multiplier (0-1) acts as a hard gate: degenerate responses + (looping, repeated paragraphs, nonsense characters) get their reward + crushed toward zero regardless of other signal scores. + """ + contents = [a.get("content", "") for a in all_answers] + + # 0. Quality gate (deterministic — fast, runs first) + quality_scores = await compute_quality_scores(contents) + + # 1. 
Extraversion score (LLM-based) if REWARD_MODE == "listwise": - scores = await compute_listwise_rewards(question, all_answers) - else: # pointwise (default) - scores = await compute_pointwise_rewards(question, all_answers) + extraversion_scores = await compute_listwise_extraversion(question, all_answers) + else: + extraversion_scores = await compute_pointwise_extraversion(question, all_answers) - print_listofdict(all_answers, header=f"on_compute_relative_reward (mode={REWARD_MODE})") - return scores + # 2. Relevance score (LLM-based) + relevance_scores = await compute_relevance_scores(question, all_answers) + + # 3. Diversity score (deterministic, n-gram overlap) + diversity_scores = compute_diversity_scores(contents, _response_history) + + # Composite reward = quality * weighted_sum + final_scores = [] + for i in range(len(all_answers)): + weighted_sum = ( + W_EXTRAVERSION * extraversion_scores[i] + + W_RELEVANCE * relevance_scores[i] + + W_DIVERSITY * diversity_scores[i] + ) + composite = quality_scores[i] * weighted_sum + final_scores.append(round(composite, 4)) + + # Annotate the answer dict for logging + all_answers[i]["reward"] = final_scores[i] + all_answers[i]["quality"] = round(quality_scores[i], 4) + all_answers[i]["extraversion"] = round(extraversion_scores[i], 4) + all_answers[i]["relevance"] = round(relevance_scores[i], 4) + all_answers[i]["diversity"] = round(diversity_scores[i], 4) + + # Update history buffer with this batch's responses + record_responses_to_history(contents) + + print_listofdict( + all_answers, + header=( + f"on_compute_relative_reward (mode={REWARD_MODE}, " + f"w_ext={W_EXTRAVERSION}, w_rel={W_RELEVANCE}, w_div={W_DIVERSITY}, " + f"quality_gate=multiplicative)" + ), + ) + return final_scores diff --git a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py index 07f32a5..11b7932 100644 --- 
a/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py +++ b/tutorial/opencode_build_openclaw_agent/on_user_submit_new_requests.py @@ -1,8 +1,44 @@ # -*- coding: utf-8 -*- -"""Handle new user requests.""" +"""Handle new user requests and track query history for diversity awareness.""" +from typing import List, Dict +from loguru import logger from ajet.schema.task import Task +# Rolling buffer of recent queries — used to detect repeated / near-duplicate +# questions so the system can log warnings. The response-level diversity +# signal lives in on_compute_relative_reward._response_history. +_query_history: List[Dict] = [] +QUERY_HISTORY_MAX = 100 + + +def get_query_history() -> List[Dict]: + """Return the current query history (read-only copy).""" + return list(_query_history) + + async def on_user_submit_new_requests(request_id: str, task: Task) -> None: - """Store user request when submitted.""" - pass # No special processing needed for this use case + """ + Store user request metadata when submitted. + + This populates a lightweight in-process history so that: + 1. The /requests endpoint can expose recent queries for debugging. + 2. We can detect if the same question keeps appearing, which signals + a data distribution issue upstream rather than a model problem. 
+ """ + entry = { + "request_id": request_id, + "task_id": task.task_id, + "query": task.main_query, + } + _query_history.append(entry) + + # Trim oldest entries + while len(_query_history) > QUERY_HISTORY_MAX: + _query_history.pop(0) + + logger.info( + f"[on_user_submit] request_id={request_id} " + f"query_len={len(task.main_query)} " + f"history_size={len(_query_history)}" + ) diff --git a/tutorial/opencode_build_openclaw_agent/test_reward.py b/tutorial/opencode_build_openclaw_agent/test_reward.py index a731b25..8b65922 100644 --- a/tutorial/opencode_build_openclaw_agent/test_reward.py +++ b/tutorial/opencode_build_openclaw_agent/test_reward.py @@ -1,90 +1,301 @@ #!/usr/bin/env python3 -"""Test script for on_compute_relative_reward.py using real OpenJudge API.""" +"""Test script for on_compute_relative_reward.py using real OpenJudge API. + +Tests four reward dimensions: + 1. Extraversion — enthusiastic responses score higher + 2. Relevance — on-topic responses score higher than off-topic + 3. Diversity — unique responses score higher than near-duplicates + 4. 
Quality gate — repetitive/degenerate responses get crushed +""" import asyncio import sys import os sys.path.insert(0, os.path.dirname(__file__)) -os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx") +os.environ["DASHSCOPE_API_KEY"] = os.getenv("DASHSCOPE_API_KEY", "sk-xxx") -async def test_pointwise(): - """Test pointwise reward mode with real API.""" - print("\n=== Testing Pointwise Mode (real API) ===") +async def test_pointwise_composite(): + """Test pointwise composite reward (extraversion + relevance + diversity).""" + print("\n=== Testing Pointwise Composite Reward ===") os.environ["REWARD_MODE"] = "pointwise" import importlib import on_compute_relative_reward as mod importlib.reload(mod) + mod._response_history.clear() # fresh history for test isolation - valid_results = [{"question": "What are your thoughts on Paris?"}] + question = "What are your thoughts on Paris?" all_answers = [ - {"content": "I'm so excited about Paris! It's amazing and wonderful!"}, + {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking and the cafes are amazing!"}, {"content": "Paris is a city in France."}, - {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"}, + {"content": "I absolutely love Paris! 
The energy on the Champs-Élysées is fantastic and so vibrant!"}, ] try: - scores = await mod.on_compute_relative_reward(valid_results, all_answers) - print(f"Scores: {scores}") + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Composite scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}" assert all(isinstance(s, float) for s in scores), "All scores should be floats" - # extraverted responses should score higher than neutral - assert scores[0] > scores[1], f"Extraverted response should score higher than neutral: {scores}" - assert scores[2] > scores[1], f"Extraverted response should score higher than neutral: {scores}" - print("✓ Pointwise mode test passed") + # Extraverted + relevant responses should beat the flat neutral one + assert scores[0] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}" + assert scores[2] > scores[1], f"Enthusiastic on-topic should beat neutral: {scores}" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_relevance_penalty(): + """Off-topic answers should get lower composite scores than on-topic ones.""" + print("\n=== Testing Relevance Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "What is your favorite food?" + all_answers = [ + # On-topic, extraverted + {"content": "Oh my gosh, I absolutely LOVE sushi! The flavors are incredible and I get so excited every time!"}, + # Off-topic, extraverted (talks about space, not food) + {"content": "WOW space exploration is SO exciting! 
Rockets launching into the sky fills me with energy!!!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + # Both are extraverted, but on-topic should win because of relevance + assert scores[0] > scores[1], \ + f"On-topic extraverted should beat off-topic extraverted: {scores}" + print("PASSED") return True except Exception as e: - print(f"✗ Pointwise mode test failed: {e}") - import traceback - traceback.print_exc() + print(f"FAILED: {e}") + import traceback; traceback.print_exc() return False -async def test_listwise(): - """Test listwise reward mode with real API.""" - print("\n=== Testing Listwise Mode (real API) ===") +async def test_diversity_penalty(): + """Near-duplicate answers should get lower diversity scores.""" + print("\n=== Testing Diversity Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "Tell me about your hobbies." + all_answers = [ + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"}, + # Near-duplicate of answer 0 + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive and energized!"}, + # Unique answer + {"content": "Dancing is my absolute passion! 
Nothing beats the energy of moving to great music with friends!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + # The duplicate pair should have lower diversity than the unique one + div_duplicate = all_answers[0].get("diversity", 1.0) + div_unique = all_answers[2].get("diversity", 0.0) + assert div_unique > div_duplicate, \ + f"Unique response should have higher diversity ({div_unique}) than duplicate ({div_duplicate})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_cross_request_diversity(): + """Answers that repeat historical responses should be penalized.""" + print("\n=== Testing Cross-Request Diversity ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + # Simulate a prior request that produced a response + mod.record_responses_to_history([ + "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!" + ]) + + question = "What do you enjoy doing on weekends?" + all_answers = [ + # Repeats the historical response almost verbatim + {"content": "I love hiking in the mountains! The fresh air and stunning views make me feel so alive!"}, + # Fresh, unique response + {"content": "Weekends are for exploring new restaurants and trying exotic cuisines! 
I get so thrilled by new flavors!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + + div_stale = all_answers[0].get("diversity", 1.0) + div_fresh = all_answers[1].get("diversity", 0.0) + assert div_fresh > div_stale, \ + f"Fresh response should have higher diversity ({div_fresh}) than stale ({div_stale})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_repetition_penalty(): + """Degenerate looping responses should get near-zero reward.""" + print("\n=== Testing Repetition / Degeneration Penalty ===") + os.environ["REWARD_MODE"] = "pointwise" + + import importlib + import on_compute_relative_reward as mod + importlib.reload(mod) + mod._response_history.clear() + + question = "Tell me about Dunfermline." + + # Build a degenerate looping response (similar to the real failure case) + good_intro = "Hello! Dunfermline is a charming town in Fife, Scotland, with a rich history." + loop_block = ( + "\n\n---\n\n" + "If you have any specific questions or need more information, just " + "let me know! I'm here to assist you in making your visit to " + "Dunfermline a delightful experience.\n\n---\n\n" + "Looking forward to your wonderful Dunfermline adventures!\n\n---\n\n" + "Thank you for the opportunity to share my thoughts on Dunfermline. " + "If you have any more questions or need assistance, feel free to " + "reach out!" + ) + degenerate_response = good_intro + (loop_block * 15) # repeat the block many times + + all_answers = [ + # Degenerate looping response + {"content": degenerate_response}, + # Clean, concise, extraverted response + {"content": "Dunfermline is absolutely wonderful! 
The abbey ruins are breathtaking and the town has such vibrant energy. I love the mix of history and modern community spirit there!"}, + ] + + try: + scores = await mod.on_compute_relative_reward([], all_answers, question=question) + print(f"Scores: {scores}") + for a in all_answers: + print(f" quality={a.get('quality')}, ext={a.get('extraversion')}, " + f"rel={a.get('relevance')}, div={a.get('diversity')}, " + f"reward={a.get('reward')} " + f"content={a['content'][:60]}...") + + quality_degenerate = all_answers[0].get("quality", 1.0) + quality_clean = all_answers[1].get("quality", 0.0) + print(f" Quality scores: degenerate={quality_degenerate}, clean={quality_clean}") + + # The degenerate response should have much lower quality + assert quality_clean > quality_degenerate, \ + f"Clean response quality ({quality_clean}) should exceed degenerate ({quality_degenerate})" + # The clean response should win overall + assert scores[1] > scores[0], \ + f"Clean response ({scores[1]}) should beat degenerate ({scores[0]})" + print("PASSED") + return True + except Exception as e: + print(f"FAILED: {e}") + import traceback; traceback.print_exc() + return False + + +async def test_listwise_composite(): + """Listwise mode should also produce composite rewards.""" + print("\n=== Testing Listwise Composite Reward ===") os.environ["REWARD_MODE"] = "listwise" import importlib import on_compute_relative_reward as mod importlib.reload(mod) + mod._response_history.clear() - valid_results = [{"question": "What are your thoughts on Paris?"}] + question = "What are your thoughts on Paris?" all_answers = [ - {"content": "I'm so excited about Paris! It's amazing and wonderful!"}, + {"content": "I'm so excited about Paris! The Eiffel Tower at night is breathtaking!"}, {"content": "Paris is a city in France."}, - {"content": "I absolutely love Paris! The energy is fantastic and vibrant!"}, + {"content": "I absolutely love Paris! 
The Champs-Élysées energy is fantastic!"}, ] try: - scores = await mod.on_compute_relative_reward(valid_results, all_answers) + scores = await mod.on_compute_relative_reward([], all_answers, question=question) print(f"Scores: {scores}") + for a in all_answers: + print(f" ext={a.get('extraversion')}, rel={a.get('relevance')}, " + f"div={a.get('diversity')}, reward={a.get('reward')} " + f"content={a['content'][:50]}...") + assert len(scores) == 3, f"Expected 3 scores, got {len(scores)}" - assert all(isinstance(s, float) for s in scores), "All scores should be floats" - # neutral response should score lowest + # Neutral response should score lowest assert scores[1] < scores[0] or scores[1] < scores[2], \ f"Neutral response should score lower than at least one extraverted response: {scores}" - print("✓ Listwise mode test passed") + print("PASSED") return True except Exception as e: - print(f"✗ Listwise mode test failed: {e}") - import traceback - traceback.print_exc() + print(f"FAILED: {e}") + import traceback; traceback.print_exc() return False async def main(): - print("Testing on_compute_relative_reward.py (real API)") - print("=" * 50) + print("Testing on_compute_relative_reward.py — Composite Reward") + print("(extraversion + relevance + diversity + quality gate)") + print("=" * 60) results = [] - results.append(await test_pointwise()) - results.append(await test_listwise()) + results.append(await test_pointwise_composite()) + results.append(await test_relevance_penalty()) + results.append(await test_diversity_penalty()) + results.append(await test_cross_request_diversity()) + results.append(await test_repetition_penalty()) + results.append(await test_listwise_composite()) - print("\n" + "=" * 50) - print(f"Tests passed: {sum(results)}/{len(results)}") + print("\n" + "=" * 60) + passed = sum(results) + total = len(results) + print(f"Tests passed: {passed}/{total}") + if not all(results): + names = [ + "pointwise_composite", "relevance_penalty", 
"diversity_penalty", + "cross_request_diversity", "repetition_penalty", "listwise_composite", + ] + for name, ok in zip(names, results): + if not ok: + print(f" FAILED: {name}") return all(results)