-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_origin_codegen.py
More file actions
332 lines (271 loc) · 12.7 KB
/
run_origin_codegen.py
File metadata and controls
332 lines (271 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
```python
import json
import requests
import re
import os
import logging
import time # <-- NEW: Import the time module for delays
# ==============================================================================
# --- User Configuration (Please modify the following variables for your environment) ---
# ==============================================================================
# 1. Your Model API Key.
#    Read from the MODEL_API_KEY environment variable so the secret does not have to
#    be committed together with the script. The fallback is the placeholder that
#    main() checks for, so an unconfigured run stops early with a clear error
#    instead of sending an invalid key to the API.
MODEL_API_KEY = os.environ.get("MODEL_API_KEY") or "YOUR_MODEL_API_KEY_HERE"
# 2. The model name you want to use
MODEL_NAME = "gpt-4.1-mini"
# 3. Number of requests per task (for pass@k metric)
# For example, setting this to 5 will ensure 5 code generations for each task_instance
NUM_REQUESTS_PER_TASK = 5  # k
# 4. Retry delay in seconds after each failed API call
# This helps to avoid rate limiting by the API server due to rapid successive requests
RETRY_DELAY_SECONDS = 5
# 5. Maximum total attempts (API calls) for a single task
# This is a safeguard to prevent infinite retries on a single task.
# For example, if k=5 and this is set to 20, it means a maximum of 15 failures will be tolerated.
MAX_ATTEMPTS_PER_TASK = 20
# 6. Path to the source JSON file containing 'task_instance'
SOURCE_JSON_FILE = 'task_instance.json'
# 7. Specify the directory path to save results and logs
OUTPUT_DIRECTORY = './result/origin_result'
# Model identifiers such as "org/model" contain path separators; replace them so
# the result and log files are created directly inside OUTPUT_DIRECTORY instead of
# in a (non-existent) sub-directory. Names without separators are unchanged.
_MODEL_FILE_TAG = MODEL_NAME.replace("/", "_").replace("\\", "_")
# 8. Output filename is generated from the model name and k
OUTPUT_JSONL_FILE = os.path.join(OUTPUT_DIRECTORY, f"{_MODEL_FILE_TAG}_k{NUM_REQUESTS_PER_TASK}_results.jsonl")
# 9. Log filename is also tied to the model name, k, and the output directory
LOG_FILE = os.path.join(OUTPUT_DIRECTORY, f"{_MODEL_FILE_TAG}_k{NUM_REQUESTS_PER_TASK}_run.log")
# Model API Endpoint
API_URL = "https://api.openai.com/v1/chat/completions"
# ==============================================================================
# --- Script Body: helper functions ---
# ==============================================================================
def setup_logging(log_file):
    """Send INFO-and-above log records to both ``log_file`` and the console.

    Handlers already attached to the root logger are removed first, so calling
    this more than once does not duplicate every log line.

    Args:
        log_file: Path of the log file. It is opened in append mode with UTF-8
            encoding, so earlier runs' logs are kept.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.hasHandlers():
        root.handlers.clear()

    shared_format = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    # File first, console second -- both get the same level and format.
    for handler in (
        logging.FileHandler(log_file, mode='a', encoding='utf-8'),
        logging.StreamHandler(),
    ):
        handler.setLevel(logging.INFO)
        handler.setFormatter(shared_format)
        root.addHandler(handler)
def load_json_file(filepath):
    """Read ``filepath`` as UTF-8 JSON and return the decoded object.

    Returns:
        The parsed JSON value, or None (after logging an error) when the file
        does not exist or does not contain valid JSON. Other errors propagate.
    """
    logging.info(f"Loading source file: {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as source:
            parsed = json.load(source)
    except FileNotFoundError:
        logging.error(f"File not found '{filepath}'. The script will terminate.")
        parsed = None
    except json.JSONDecodeError:
        logging.error(f"The JSON file '{filepath}' is invalid. The script will terminate.", exc_info=True)
        parsed = None
    return parsed
def get_processed_ids(filepath):
    """Collect the task IDs already written to the output JSONL file.

    Used to resume from a checkpoint: main() skips every task whose ``id``
    appears in the returned set.

    Args:
        filepath: Path to the JSONL results file.

    Returns:
        A set of the ``id`` values found. The set is empty if the file does not
        exist. Blank lines are ignored silently. Lines that are not valid JSON,
        or that are valid JSON but not an object, are skipped (the former with a
        warning).
    """
    processed_ids = set()
    if not os.path.exists(filepath):
        return processed_ids
    logging.info(f"Checking for processed entries in: {filepath}...")
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # An empty line (e.g. a trailing newline) carries no record and is not corruption.
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                logging.warning(f"Found a malformed line in '{filepath}', skipping.")
                continue
            # json.loads also accepts numbers/arrays/strings; `'id' in 5` would raise
            # TypeError and abort the whole resume scan, so only objects are inspected.
            if isinstance(data, dict) and 'id' in data:
                processed_ids.add(data['id'])
    logging.info(f"Found {len(processed_ids)} processed entries.")
    return processed_ids
def build_prompt(task_instance):
    """
    [One-Shot Version] Build the code-generation prompt for one task.

    The prompt contains fixed instructions, a one-shot example (a Javadoc'd
    ``isBlank`` signature followed by its complete, ```java-fenced
    implementation), and finally the task's own specification.

    Args:
        task_instance: The function specification to implement (for example a
            Javadoc comment plus signature). It is inserted verbatim except that
            trailing whitespace is removed.

    Returns:
        The prompt string to send as the user message.
    """
    # --- One-shot example: specification and the expected complete function ---
    example_input = """/**
 * Checks if a given string is null, empty, or consists only of white-space characters.
 *
 * @param str the String to check, may be null
 * @return {@code true} if the String is null, empty, or whitespace-only
 */
public static boolean isBlank(String str)"""
    example_body = """ {
    if (str == null || str.isEmpty()) {
        return true;
    }
    for (int i = 0; i < str.length(); i++) {
        if (!Character.isWhitespace(str.charAt(i))) {
            return false;
        }
    }
    return true;
}"""
    # The expected output repeats the spec verbatim, then adds the body.
    example_output = example_input + example_body

    # Strip trailing whitespace so the closing fence always lands on its own line;
    # previously it was appended directly after the task text (e.g. "foo()```").
    task_spec = str(task_instance).rstrip()

    # The example output is fenced exactly as instruction 4 demands, so the
    # example and the instructions no longer contradict each other.
    prompt = f"""
You are an expert Java programmer acting as a code generation engine. Your task is to implement the body of a single Java function based on the provided specification.
### INSTRUCTIONS:
1. Your output MUST strictly follow the format and structure of the example below.
2. Generate ONLY ONE complete function block.
3. **DO NOT** define any helper methods, private functions, inner classes, or a `main` method.
4. Your response must be ONLY the Java code, wrapped in ```java. Do not add any explanation.
---
### EXAMPLE
#### Function to Implement (Example):
```java
{example_input}
```
#### Expected Output (Example):
```java
{example_output}
```
---
### YOUR TASK
#### Function to Implement (Your Task):
```java
{task_spec}
```
Now, generate the output for YOUR TASK.
"""
    return prompt
def call_model_api(prompt, model_name):
    """Send ``prompt`` as a single user message to the chat-completions endpoint.

    Args:
        prompt: Text of the user message.
        model_name: Model identifier placed in the request payload.

    Returns:
        The text content of the first returned choice, or None when the HTTP
        request fails (connection error, timeout, non-2xx status), the body is
        not JSON, or the body does not have the expected
        ``choices[0].message.content`` shape. Failures are logged.
    """
    headers = {
        "Authorization": f"Bearer {MODEL_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model_name,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7,  # For multiple generations, you can slightly increase the temperature to enhance diversity
        "max_tokens": 4096,
        "stream": False,
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
    except requests.exceptions.RequestException:
        logging.error("API request failed.", exc_info=True)
        return None
    try:
        response_json = response.json()
        # `choices` is a JSON array; the payload does not ask for multiple
        # choices, so the completion is the first element. Indexing it with a
        # string (as before) raised an uncaught TypeError on every success.
        return response_json['choices'][0]['message']['content']
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers an undecodable JSON body; the others cover a body
        # whose structure differs from the expected shape.
        logging.error(f"API response format is incorrect. Response content: {response.text}", exc_info=True)
        return None
def extract_code_from_response(response_text):
    """
    Extract the Java code from the model's response text.

    Strategy priority:
    1. Return the contents of the first Markdown code fence (three backticks).
       An optional ``java`` language tag is dropped, in any letter case, with
       optional surrounding spaces/tabs and either an LF or CRLF line ending.
    2. Otherwise, if the whole stripped response starts like Java source (a
       comment/Javadoc, an annotation, an access or other modifier, ``String``
       or ``{``), return it unchanged.

    Returns:
        The extracted code with surrounding whitespace stripped, or None when
        the response is empty or no code can be identified.
    """
    if not response_text:
        return None
    # Strategy 1: fenced code. IGNORECASE handles "```Java"; the optional group
    # also consumes trailing spaces and "\r\n" so the tag never leaks into the code.
    fenced = re.search(
        r"```(?:[ \t]*java[ \t]*\r?\n)?(.*?)```",
        response_text,
        re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        return fenced.group(1).strip()
    # Strategy 2: the model returned bare code. The one-shot example starts with
    # a Javadoc, so comment/annotation/modifier openings are accepted as well.
    cleaned_response = response_text.strip()
    java_openings = (
        '{', '/*', '//', '@',
        'public', 'protected', 'private', 'static', 'final', 'String',
    )
    if cleaned_response.startswith(java_openings):
        return cleaned_response
    # Nothing that looks like code was found.
    return None
def append_to_jsonl(filepath, data_dict):
    """Append ``data_dict`` to ``filepath`` as a single JSON line.

    Non-ASCII characters are written as-is in UTF-8. An OSError while opening
    or writing the file is logged rather than raised, so one failed write does
    not stop the run.
    """
    try:
        with open(filepath, 'a', encoding='utf-8') as sink:
            sink.write(f"{json.dumps(data_dict, ensure_ascii=False)}\n")
    except OSError:  # IOError is an alias of OSError on Python 3
        logging.error(f"Failed to write to file '{filepath}'.", exc_info=True)
# ==============================================================================
# --- Main Logic ---
# ==============================================================================
def _generate_codes_for_task(result_id, prompt):
    """Query the model until NUM_REQUESTS_PER_TASK codes are extracted or the attempt budget is spent.

    Every API call counts as one attempt, whether it fails or succeeds, and at most
    MAX_ATTEMPTS_PER_TASK calls are made. After a failed call or a response without
    extractable code, the function waits RETRY_DELAY_SECONDS before trying again.

    Returns:
        The list of extracted code strings. It is shorter than
        NUM_REQUESTS_PER_TASK when the attempt budget ran out.
    """
    generated_codes_list = []
    total_attempts = 0  # Counter for a safe exit
    while len(generated_codes_list) < NUM_REQUESTS_PER_TASK:
        # Safety check to prevent infinite loops; checked before the call so the
        # logged count equals the number of calls actually made.
        if total_attempts >= MAX_ATTEMPTS_PER_TASK:
            logging.critical(f"ID: {result_id} - [ABORTING] Total attempts for the task ({total_attempts}) have reached the maximum limit ({MAX_ATTEMPTS_PER_TASK}). This could be due to persistent API issues.")
            break
        total_attempts += 1
        current_progress = len(generated_codes_list) + 1
        logging.info(f"ID: {result_id} - Getting result {current_progress}/{NUM_REQUESTS_PER_TASK} (Total attempts: {total_attempts})...")
        model_response = call_model_api(prompt, MODEL_NAME)
        if not model_response:
            logging.error(f"ID: {result_id} - API call failed. Retrying in {RETRY_DELAY_SECONDS} seconds...")
            time.sleep(RETRY_DELAY_SECONDS)
            continue
        generated_code = extract_code_from_response(model_response)
        if generated_code:
            generated_codes_list.append(generated_code)
            logging.info(f"ID: {result_id} - Successfully obtained result {current_progress}/{NUM_REQUESTS_PER_TASK}.")
        else:
            logging.warning(f"ID: {result_id} - Could not find code block in the response. Retrying in {RETRY_DELAY_SECONDS} seconds...")
            time.sleep(RETRY_DELAY_SECONDS)
    return generated_codes_list


def main():
    """Main execution function.

    Loads the tasks from SOURCE_JSON_FILE and skips those whose ID is already in
    OUTPUT_JSONL_FILE, which lets an interrupted run resume. For each remaining
    task, it appends one record holding NUM_REQUESTS_PER_TASK generated codes.
    A task that cannot reach that count within MAX_ATTEMPTS_PER_TASK API calls
    is logged as a failure and not saved, so a later run retries it.
    """
    try:
        os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
    except OSError as e:
        # Logging is not set up yet (the log file lives in this directory), so print.
        print(f"Fatal Error: Failed to create directory '{OUTPUT_DIRECTORY}': {e}")
        return
    setup_logging(LOG_FILE)
    # "KEY" is the value shipped in the configuration section, and
    # "YOUR_MODEL_API_KEY_HERE" is the conventional placeholder. Comparing against
    # only the latter never caught an unconfigured run.
    if not MODEL_API_KEY or MODEL_API_KEY in ("YOUR_MODEL_API_KEY_HERE", "KEY"):
        logging.error("Please configure your MODEL_API_KEY at the top of the script.")
        return
    logging.info(f"Logs will be recorded to: {LOG_FILE}")
    source_data = load_json_file(SOURCE_JSON_FILE)
    if not source_data:
        return
    # Assumes the source is a list of repo blocks, each a dict with an optional
    # 'analysis_results' list -- NOTE(review): confirm against the data producer.
    tasks = [
        analysis_result
        for repo_block in source_data
        for analysis_result in repo_block.get('analysis_results', [])
    ]
    total_tasks = len(tasks)
    logging.info(f"Found a total of {total_tasks} tasks to process.")
    logging.info(f"Will ensure {NUM_REQUESTS_PER_TASK} results are generated for each task.")
    logging.info(f"Results will be saved to: {OUTPUT_JSONL_FILE}")
    processed_ids = get_processed_ids(OUTPUT_JSONL_FILE)
    for i, analysis_result in enumerate(tasks, 1):
        result_id = analysis_result.get('id')
        # `or {}` also covers an explicit JSON null, on which `.get` would fail.
        primary_analysis = analysis_result.get('primary_analysis') or {}
        task_instance = primary_analysis.get('task_instance')
        logging.info(f"--- Processing task {i}/{total_tasks} (ID: {result_id}) ---")
        if not result_id or not task_instance:
            logging.warning(f"ID: {result_id} - [SKIPPING] Missing 'id' or 'task_instance'.")
            continue
        if result_id in processed_ids:
            logging.info(f"ID: {result_id} - [SKIPPING] This task has already been processed.")
            continue
        prompt = build_prompt(task_instance)
        generated_codes_list = _generate_codes_for_task(result_id, prompt)
        # Save only when the full pass@k sample set was obtained.
        if len(generated_codes_list) == NUM_REQUESTS_PER_TASK:
            output_record = {
                "id": result_id,
                "model": MODEL_NAME,
                "generated_codes": generated_codes_list,  # Save the list of codes
            }
            append_to_jsonl(OUTPUT_JSONL_FILE, output_record)
            # Record the ID so a duplicate entry later in this same run is skipped
            # instead of being regenerated and appended a second time.
            processed_ids.add(result_id)
            logging.info(f"ID: {result_id} - [SUCCESS] Task complete. Successfully saved {len(generated_codes_list)} generated codes.")
        else:
            # This usually means the maximum number of attempts was reached.
            logging.error(f"ID: {result_id} - [FAILURE] Task terminated. Ultimately only obtained {len(generated_codes_list)}/{NUM_REQUESTS_PER_TASK} results.")
    logging.info("All tasks processed!")


if __name__ == "__main__":
    main()
```