-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreward_function.py
More file actions
217 lines (186 loc) · 8.1 KB
/
reward_function.py
File metadata and controls
217 lines (186 loc) · 8.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import re
from typing import List
from collections import Counter
import os
import re
import time
import json # Import json library
from pathlib import Path
from typing import Optional, List, Dict, Any
import requests # Import requests library
# Removed openai and asyncio libraries
from prompts import judge_system_prompt, judge_prompt_template
from verl.utils.reward_score import default_compute_score
# Server configuration extracted from get_judge_client function and curl example
# Base URL of the judge service; compute_score appends "/evaluate" to it.
JUDGE_SERVER_URL = "http://localhost:8124"
# Model name extracted from curl example
# NOTE(review): not referenced by the code visible in this file; presumably
# kept for reference to the judge deployment - confirm before removing.
MODEL_NAME = "/ceph/home/muhan01/huggingfacemodels/Qwen2.5-32B-Instruct"
# Parameters extracted from the original JudgeRewardModel class
# NOTE(review): TEMPERATURE and MAX_TOKENS are not referenced by the code
# visible in this file; only TIMEOUT is passed to requests.post.
TEMPERATURE = 0.1
MAX_TOKENS = 10
TIMEOUT = 60
# Retry settings for the judge request
MAX_RETRIES = 3 # Maximum number of retries after the initial attempt
INITIAL_RETRY_DELAY = 2 # Initial wait before retrying (seconds)
BACKOFF_FACTOR = 2 # The wait time is multiplied by this after every retry (doubles)
def extract_evidence_tag(text: str) -> Optional[str]:
    """Return the stripped content of the first ``<evidence>...</evidence>`` span.

    The tag match is case-insensitive and the content may span several lines.

    Args:
        text: String to search.

    Returns:
        The text between the first opening tag and the nearest closing tag
        with surrounding whitespace removed, or ``None`` when no complete
        tag pair is found.
    """
    match = re.search(r'<evidence>(.*?)</evidence>', text, re.DOTALL | re.IGNORECASE)
    if match is None:
        return None
    return match.group(1).strip()
def _parse_judge_response(response: str) -> float:
"""
Parses the judge model's response into a numerical reward.
"""
response = response.strip().lower()
if 'true' in response[:10]:
return 1.0
elif 'false' in response[:10]:
return 0.0
else:
return 0.0
def _compute_evidence_format_reward(solution_str: str) -> float:
"""
Computes the reward for the evidence format.
"""
num_left_evidence = solution_str.count('<evidence>')
num_right_evidence = solution_str.count('</evidence>')
if num_left_evidence == num_right_evidence == 1:
return 0.0
else:
return -0.2
def _extract_answer(solution_str: str) -> str:
"""
Extracts the answer from the solution string.
"""
# Extract everything after </evidence>
m = re.search(r'</evidence>(.*)', solution_str, flags=re.DOTALL | re.IGNORECASE)
answer = m.group(1) if m else ''
return answer.strip()
def split_into_sentences(text: str) -> List[str]:
    """Break ``text`` into sentences at runs of ``.``, ``?`` or ``!``.

    Each piece is stripped of surrounding whitespace; pieces that are empty
    after stripping are dropped.
    """
    fragments = re.split(r'[.?!]+', text)
    return list(filter(None, map(str.strip, fragments)))
def normalize_sentence(s: str) -> str:
    """Return ``s`` with every whitespace character removed.

    Used so that sentences can be compared regardless of spacing or line
    breaks.
    """
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub('', s)
def calculate_token_f1_evidence(evidence_pred: str, evidence_gt: str) -> float:
    """Token-level F1 between predicted and ground-truth evidence.

    Both strings are tokenised on whitespace. The overlap counts each token
    at most as often as it appears in both inputs (multiset intersection).
    Returns ``0.0`` when either side is empty or nothing overlaps.
    """
    predicted = evidence_pred.split()
    reference = evidence_gt.split()

    # Counter "&" keeps the minimum count of every token present in both.
    overlap = sum((Counter(predicted) & Counter(reference)).values())

    precision = overlap / len(predicted) if predicted else 0.0
    recall = overlap / len(reference) if reference else 0.0

    denominator = precision + recall
    if denominator <= 0:
        return 0.0
    return 2 * (precision * recall) / denominator
def calculate_f1_score_evidence(evidence_pred: str, evidence_gt: str) -> float:
    """Sentence-level F1 between predicted and ground-truth evidence.

    Both inputs are split with ``split_into_sentences`` and every sentence
    is passed through ``normalize_sentence`` before comparison, so two
    sentences match when they are equal after that normalisation. Repeated
    sentences are matched at most as often as they occur on both sides.
    Returns ``0.0`` when either side has no sentences or nothing matches.
    """
    reference = Counter(normalize_sentence(s) for s in split_into_sentences(evidence_gt))
    predicted = Counter(normalize_sentence(s) for s in split_into_sentences(evidence_pred))

    # Multiset intersection: min(count in prediction, count in reference).
    overlap = sum((predicted & reference).values())

    # Totals equal the number of sentences on each side (duplicates included).
    predicted_total = sum(predicted.values())
    reference_total = sum(reference.values())

    precision = overlap / predicted_total if predicted_total else 0.0
    recall = overlap / reference_total if reference_total else 0.0

    denominator = precision + recall
    if denominator <= 0:
        return 0.0
    return 2 * (precision * recall) / denominator
def compute_score(
    data_source: str,
    solution_str: str,
    ground_truth: str,
    extra_info: Optional[Dict[str, Any]] = None
) -> float:
    """Compute the reward for one model response.

    Two scoring modes are selected by ``data_source``:

    * ``"train_longtext_qa"``: the reward is the mean of the sentence-level
      and token-level F1 between the ``<evidence>`` span of ``solution_str``
      and ``extra_info['evidence']``. A response without a non-empty
      ``<evidence>`` span receives ``-0.2``. ``ground_truth`` is not used in
      this mode.
    * any other value: the question, reference answer and full response are
      POSTed to ``{JUDGE_SERVER_URL}/evaluate`` and the returned verdict is
      mapped to a reward by ``_parse_judge_response``. Transport and HTTP
      errors are retried up to ``MAX_RETRIES`` times with exponential
      backoff; a malformed response body is not retried. Any failure leaves
      the judge reward at ``0.0``.

    Args:
        data_source: Name of the dataset the sample comes from.
        solution_str: The model's response text.
        ground_truth: Reference answer sent to the judge.
        extra_info: Per-sample metadata. Must contain ``'evidence'`` in the
            evidence mode and ``'question_raw'`` in the judge mode; although
            the default is ``None``, both modes index into it.

    Returns:
        The reward as a float.

    Raises:
        AssertionError: If the required ``extra_info`` value is ``None`` /
            missing (``'evidence'`` is looked up by key, so an absent key
            raises ``KeyError`` instead).
    """
    # Evidence-extraction samples are scored locally with F1.
    if data_source == "train_longtext_qa":
        pred_evidence = extract_evidence_tag(solution_str)
        # The reference evidence comes from extra_info rather than from an
        # <evidence> tag inside ground_truth.
        gt_evidence = extra_info['evidence']
        assert gt_evidence is not None, "Ground truth evidence is None, please check the training data"
        if not pred_evidence:
            print("Missing <evidence> tag, applying penalty.")
            return -0.2
        sentence_f1 = calculate_f1_score_evidence(pred_evidence, gt_evidence)
        token_f1 = calculate_token_f1_evidence(pred_evidence, gt_evidence)
        score = 0.5 * sentence_f1 + 0.5 * token_f1
        print(f"Sentence-level F1: {sentence_f1}, Token-level F1: {token_f1}, Final score: {score}")
        return score

    # 1. Format reward.
    # NOTE: the evidence-format penalty (_compute_evidence_format_reward) is
    # currently disabled, so this term is always zero.
    format_reward = 0.0

    # 2. Extract necessary information.
    assert extra_info.get('question_raw', None) is not None, "extra_info is missing 'question_raw', please check the training data"
    question = extra_info['question_raw']

    # 3. Prepare the API call.
    judge_reward = 0.0  # Default judge reward, kept on any failure
    api_url = f"{JUDGE_SERVER_URL.rstrip('/')}/evaluate"
    headers = {"Content-Type": "application/json"}
    # NOTE: the whole response is judged; stripping the evidence part with
    # _extract_answer is currently disabled.
    answer = solution_str
    payload = {
        "question": question,
        "reference": ground_truth,
        "pred": answer
    }

    # 4. Execute the API call with retries.
    current_delay = INITIAL_RETRY_DELAY
    for attempt in range(MAX_RETRIES + 1):
        # 4a. Transport phase: only failures here are worth retrying.
        try:
            response = requests.post(
                api_url,
                headers=headers,
                json=payload,  # Let requests serialise the body
                timeout=TIMEOUT
            )
            response.raise_for_status()  # Trigger exception for 4xx/5xx codes
        except (requests.exceptions.RequestException, ConnectionError, TimeoutError) as e:
            # Network-related error: retry with exponential backoff.
            if attempt < MAX_RETRIES:
                print(f"[WARNING] Request failed (Attempt {attempt+1}/{MAX_RETRIES}). Retrying in {current_delay}s... Error: {e}")
                time.sleep(current_delay)
                current_delay *= BACKOFF_FACTOR
                continue
            # Out of retries: judge_reward stays at 0.0.
            print(f"[ERROR] All {MAX_RETRIES} retries failed for question '{question[:20]}...'. Defaulting reward to 0.0. Final Error: {e}")
            break
        except Exception as e:
            # Any other unknown error: give up immediately.
            print(f"[ERROR] Unexpected error for question '{question[:20]}...': {e}")
            break

        # 4b. Parsing phase: expected body is {"result": "True" or "False"}.
        # This is kept outside the transport ``try`` because the JSON decode
        # error raised by ``response.json()`` is itself a RequestException in
        # recent requests versions and would otherwise be retried, although
        # retrying cannot fix a malformed body.
        try:
            response_data = response.json()
            judge_response_content = response_data.get('result', 'False')
            judge_reward = _parse_judge_response(judge_response_content)
        except (KeyError, IndexError, ValueError) as e:
            # ValueError covers every JSON decode error flavour.
            print(f"[ERROR] Response parsing failed for question '{question[:20]}...': {e}")
        except Exception as e:
            print(f"[ERROR] Unexpected error for question '{question[:20]}...': {e}")
        # A response was received (parsed or not): stop retrying.
        break

    # 5. Return total reward.
    total_reward = format_reward + judge_reward
    return total_reward