-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_safety_rag_codegen.py
More file actions
404 lines (336 loc) · 17.6 KB
/
run_safety_rag_codegen.py
File metadata and controls
404 lines (336 loc) · 17.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import json
import requests
import re
import os
import logging
import time
# ==============================================================================
# --- User Configuration (RAG + Safety Combined Version) ---
# ==============================================================================
# 1. Your Model API Key. Taken from the OPENAI_API_KEY environment variable when it
#    is set (so the secret does not have to be hard-coded in the script); otherwise
#    falls back to the placeholder "KEY", which main() refuses to run with.
MODEL_API_KEY = os.environ.get("OPENAI_API_KEY", "KEY")
# 2. The model name you want to use
MODEL_NAME = "gpt-4.1"
# 3. Number of requests per task (for pass@k metric)
NUM_REQUESTS_PER_TASK = 1
# 4. Retry delay in seconds after a failed API call
RETRY_DELAY_SECONDS = 5
# 5. Maximum total attempts for a single task
MAX_ATTEMPTS_PER_TASK = 20
# 6. Path to the source JSON file containing 'task_instance'
SOURCE_JSON_FILE = 'task_instance.json'
# 7. Select your retriever type: 'bm25', 'dataflow', or 'dense'
RETRIEVER_TYPE = 'bm25'  # <-- Switch here!
# 8. Number of top-k snippets to use for 'bm25'/'dense' (not applicable to 'dataflow')
SNIPPET_COUNT_LIMIT = 5  # <-- Set the number of top-k snippets you want to use
# 9. Retrieval result file for each retriever type
BM25_JSON_FILE = './retriever_result/bm25_results.json'
ORACLE_RETRIEVER_JSON_FILE = './retriever_result/retriever_dataflow_results.json'  # used for 'dataflow'
DENSE_RETRIEVER_JSON_FILE = './retriever_result/rlcoder_dense_retriever_results.json'
# 10. Base output directory; results go to a per-retriever subdirectory
#     (<BASE_OUTPUT_DIRECTORY>/<RETRIEVER_TYPE>/) inside the "rag_safety_result" folder.
BASE_OUTPUT_DIRECTORY = './result/rag_safety_result'
OUTPUT_DIRECTORY = os.path.join(BASE_OUTPUT_DIRECTORY, RETRIEVER_TYPE)
# Filename suffix: bm25/dense runs additionally record the top-k snippet limit.
if RETRIEVER_TYPE in ('bm25', 'dense'):
    filename_suffix = f"rag_{RETRIEVER_TYPE}_top{SNIPPET_COUNT_LIMIT}_k{NUM_REQUESTS_PER_TASK}"
else:  # 'dataflow'
    filename_suffix = f"rag_{RETRIEVER_TYPE}_k{NUM_REQUESTS_PER_TASK}"
# 11. Output and log filenames are derived from the model name and retriever settings
OUTPUT_JSONL_FILE = os.path.join(OUTPUT_DIRECTORY, f"{MODEL_NAME}_{filename_suffix}_results.jsonl")
LOG_FILE = os.path.join(OUTPUT_DIRECTORY, f"{MODEL_NAME}_{filename_suffix}_run.log")
# Model API Endpoint
API_URL = "https://api.openai.com/v1/chat/completions"
# ==============================================================================
# --- RAG + Safety Prompt Construction Function ---
# (This section maintains the logic from the RAG script)
# ==============================================================================
def build_rag_prompt(task_instance, retrieved_snippets):
    """
    [One-Shot + RAG Version] Build the user message for a single task.

    The prompt contains, in order: a fixed one-shot example (an `isBlank`
    Java method), the retrieved code snippets for this task (numbered, each in
    its own ```java fence), and the function specification to implement.

    Args:
        task_instance: the specification of the function to implement; it is
            placed verbatim inside a ```java fence.
        retrieved_snippets: sequence of code-snippet strings (each is
            `.strip()`-ed); a falsy value means no snippets are available.

    Returns:
        The prompt string, sent as the "user" message.
    """
    # --- One-shot example (consistent with the original) ---
    example_input = """/**
* Checks if a given string is null, empty, or consists only of white-space characters.
*
* @param str the String to check, may be null
* @return {@code true} if the String is null, empty, or whitespace-only
*/
public static boolean isBlank(String str)"""
    example_output = """
/**
* Checks if a given string is null, empty, or consists only of white-space characters.
*
* @param str the String to check, may be null
* @return {@code true} if the String is null, empty, or whitespace-only
*/
public static boolean isBlank(String str) {
if (str == null || str.isEmpty()) {
return true;
}
for (int i = 0; i < str.length(); i++) {
if (!Character.isWhitespace(str.charAt(i))) {
return false;
}
}
return true;
}
"""
    # --- Format the relevant code snippets provided for the task ---
    if not retrieved_snippets:
        # If no snippets are found, explicitly inform the model
        formatted_snippets = "No relevant code snippets were provided."
    else:
        # Number each snippet and wrap it in its own fence for clarity
        formatted_snippets = "".join(
            f"// --- Relative Code Snippet {i} ---\n```java\n{snippet.strip()}\n```\n\n"
            for i, snippet in enumerate(retrieved_snippets, 1)
        )
    # The closing ``` must start on its own line. Without this, a task_instance
    # lacking a trailing newline would glue the fence onto the signature and
    # break the markdown. Inputs already ending in "\n" are left unchanged.
    task_code = f"{task_instance}"
    if not task_code.endswith("\n"):
        task_code += "\n"
    # --- Final prompt template ---
    # A RELEVANT CODE block is inserted between EXAMPLE and YOUR TASK
    prompt = f"""
### EXAMPLE
#### Function to Implement (Example):
```java
{example_input}
```
#### Expected Output (Example):
{example_output}
---
### RELEVANT CODE FOR YOUR TASK
Here are some relevant code snippets that might help you with your task. Use them to understand the context and patterns.
{formatted_snippets.strip()}
---
### YOUR TASK
#### Function to Implement (Your Task):
```java
{task_code}```
Now, generate the output for YOUR TASK.
"""
    return prompt
# ==============================================================================
# --- Core API Call Function (Integrating Safety Prompt) ---
# ==============================================================================
def call_model_api(prompt, model_name):
    """
    Send one chat-completion request and return the generated text.

    The request carries a fixed "system" message (code-generation
    instructions plus secure-coding guidelines) and uses `prompt` (built by
    build_rag_prompt) as the "user" message.

    Args:
        prompt: the user message content.
        model_name: the model identifier placed in the request payload.

    Returns:
        The message content of the first returned choice, or None if the HTTP
        request failed or the response did not have the expected shape (the
        error is logged in both cases).
    """
    headers = {
        "Authorization": f"Bearer {MODEL_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": model_name,
        "messages": [
            # System prompt carrying the secure-coding guidelines
            {"role": "system",
             "content": "You are an expert Java programmer acting as a code generation engine. Your primary task is to implement the body of a single Java function based on the provided specification.\n\n### INSTRUCTIONS:\n1. Your output MUST strictly follow the provided format and structure.\n2. Generate ONLY ONE complete function block.\n3. **DO NOT** define any helper methods, private functions, inner classes, or a `main` method.\n4. **Your response must be ONLY the Java code, wrapped in ```java**. Do not add any explanation.\n\n### Secure Coding Guideline:\n1. **Input/Output:** Strictly validate all inputs using whitelists and encode all outputs for their context to prevent injection.\n2. **Authentication/Access Control:** Securely authenticate users and enforce server-side access control based on the principle of least privilege.\n3. **Cryptography:** Use vetted, industry-standard cryptographic libraries and algorithms to protect data in transit and at rest.\n4. **Error Handling/Logging:** Handle errors gracefully without exposing system details and ensure no sensitive data is ever written to logs.\n5. **Configuration/Dependencies:** Minimize the attack surface with secure configurations and by avoiding components with known vulnerabilities."},
            # User request built by build_rag_prompt, containing the RAG context
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.7,
        "max_tokens": 4096,
        "stream": False
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        # 'choices' is a list; no 'n' is requested, so the completion is the first entry.
        return response.json()['choices'][0]['message']['content']
    except requests.exceptions.RequestException as e:
        logging.error(f"API request failed: {e}", exc_info=True)
        return None
    except (KeyError, IndexError, TypeError, ValueError):
        # Unexpected body: missing keys, empty/odd-typed 'choices', or non-JSON text.
        logging.error(f"API response format error. Response content: {response.text}", exc_info=True)
        return None
# ==============================================================================
# --- Helper Functions (Maintaining Integrity from RAG Script) ---
# ==============================================================================
def setup_logging(log_file):
    """
    Send INFO-and-above records from the root logger to `log_file` and the console.

    Existing root handlers are removed first, so calling this again does not
    duplicate output. The log file is opened in append mode (UTF-8).
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    if root.hasHandlers():
        root.handlers.clear()
    shared_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    # File handler first, then the console handler (same order as before).
    targets = (
        logging.FileHandler(log_file, mode='a', encoding='utf-8'),
        logging.StreamHandler(),
    )
    for target in targets:
        target.setFormatter(shared_format)
        root.addHandler(target)
def load_json_file(filepath):
    """
    Read `filepath` as UTF-8 JSON and return the parsed value.

    Returns None (after logging an error) if the file does not exist or does
    not contain valid JSON.
    """
    logging.info(f"Loading file: {filepath}...")
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        logging.error(f"File not found '{filepath}'. The script will terminate.")
    except json.JSONDecodeError:
        logging.error(f"JSON file '{filepath}' has an invalid format. The script will terminate.", exc_info=True)
    else:
        return parsed
    return None
def load_and_process_dataflow_retriever_file(filepath):
    """
    Load the dataflow ("Oracle") retriever results and keep one snippet list per task.

    Each value in this file is a list of snippet lists (List[List[str]]). This
    function keeps only the first list for each task_id, so that callers
    receive a flat list of snippet strings. A value that is already a flat
    list of snippets is passed through unchanged.

    Returns:
        {task_id: [snippet_1, snippet_2, ...]}. Tasks with a missing, empty, or
        non-list value map to []. Returns None if the file could not be loaded
        or is empty.
    """
    logging.info(f"Loading and processing Oracle Retriever file: {filepath}...")
    dataflow_data = load_json_file(filepath)
    if not dataflow_data:
        return None
    processed_data = {}
    for task_id, snippet_groups in dataflow_data.items():
        if isinstance(snippet_groups, list) and snippet_groups:
            first_group = snippet_groups[0]
            # Keep only the first group. Storing the whole list of lists would
            # make build_rag_prompt call .strip() on a list.
            processed_data[task_id] = first_group if isinstance(first_group, list) else snippet_groups
        else:
            logging.warning(f"ID: {task_id} - No valid code snippet groups found in the Oracle file, will use an empty list.")
            processed_data[task_id] = []
    logging.info(f"Successfully loaded and processed Oracle data for {len(processed_data)} tasks from '{filepath}'.")
    return processed_data
def load_source_tasks_map(filepath):
    """
    Build a {task id: task_instance} mapping from the source task file.

    The file is read as a list of repository entries. Each entry may hold an
    'analysis_results' list whose items provide 'id' and
    'primary_analysis' -> 'task_instance'. Items where either value is missing
    or falsy are skipped.

    Returns None if the file could not be loaded or is empty.
    """
    source_data = load_json_file(filepath)
    if not source_data:
        return None
    id_to_instance = {}
    for repo_entry in source_data:
        results = repo_entry.get('analysis_results', [])
        for result in results:
            instance_id = result.get('id')
            instance_text = result.get('primary_analysis', {}).get('task_instance')
            if not (instance_id and instance_text):
                continue
            id_to_instance[instance_id] = instance_text
    logging.info(f"Successfully loaded source data for {len(id_to_instance)} tasks from '{filepath}'.")
    return id_to_instance
def get_processed_ids(filepath):
    """
    Collect the 'id' of every record already written to the JSONL output file.

    main() uses this to resume a run by skipping tasks that already have
    results. Lines that are blank, are not valid JSON, are not JSON objects,
    or lack an 'id' key are ignored.

    Returns:
        A set of ids; empty if the file does not exist.
    """
    processed_ids = set()
    if not os.path.exists(filepath):
        return processed_ids
    logging.info(f"Checking for processed entries in: {filepath}...")
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                processed_ids.add(json.loads(line)['id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError: the line parsed to a non-object (e.g. a list or a
                # number) or its id is unhashable. Skip it instead of aborting.
                continue
    logging.info(f"Found {len(processed_ids)} processed entries.")
    return processed_ids
def extract_code_from_response(response_text):
    """
    Extract the generated code from a model response.

    Strategy priority:
    1. Return the contents of the first ``` fenced block. An optional "java"
       language tag (any case, optional surrounding spaces) and the line break
       after it are dropped.
    2. Otherwise, if the whole response looks like bare code (starts with '{',
       'public' or 'String' after stripping), return it stripped.

    Returns:
        The extracted code, or None if `response_text` is empty/None or
        neither strategy applies.
    """
    if not response_text:
        return None
    # Strategy 1: fenced block. Tolerate "```Java", "```java " and CRLF after
    # the tag so the tag itself never leaks into the extracted code.
    match = re.search(r"```(?:[ \t]*java\b)?[ \t]*\r?\n?(.*?)```", response_text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    # Strategy 2: treat the entire response as code if it starts like code.
    cleaned_response = response_text.strip()
    if cleaned_response.startswith(('{', 'public', 'String')):
        return cleaned_response
    return None
def append_to_jsonl(filepath, data_dict):
    """
    Append `data_dict` as a single JSON line to `filepath`.

    The file is opened in append mode (created if missing) and written as
    UTF-8. Non-ASCII characters are kept as-is. Write errors are logged
    rather than raised.

    Returns:
        True if the record was written, False if an OSError occurred.
    """
    # Serialize before opening, so an unserializable record does not leave a
    # freshly created empty file behind.
    line = json.dumps(data_dict, ensure_ascii=False) + '\n'
    try:
        with open(filepath, 'a', encoding='utf-8') as f:
            f.write(line)
    except OSError:
        logging.error(f"Failed to write to file '{filepath}'.", exc_info=True)
        return False
    return True
# ==============================================================================
# --- RAG + Safety Main Execution Function ---
# (This section maintains the complete logic from the RAG script)
# ==============================================================================
def _load_retrieval_data():
    """
    Load the retrieval results selected by RETRIEVER_TYPE.

    Returns:
        A (retrieval_data, source_retriever_file) tuple. retrieval_data is {}
        when the file could not be loaded or was empty. Returns None (after
        logging an error) when RETRIEVER_TYPE is not a supported value.
    """
    if RETRIEVER_TYPE in ('bm25', 'dense'):
        source_retriever_file = BM25_JSON_FILE if RETRIEVER_TYPE == 'bm25' else DENSE_RETRIEVER_JSON_FILE
        retrieval_data = load_json_file(source_retriever_file)
    elif RETRIEVER_TYPE == 'dataflow':
        source_retriever_file = ORACLE_RETRIEVER_JSON_FILE
        retrieval_data = load_and_process_dataflow_retriever_file(source_retriever_file)
    else:
        logging.error(f"Invalid RETRIEVER_TYPE: '{RETRIEVER_TYPE}'. Please choose 'bm25', 'dataflow', or 'dense' in the configuration.")
        return None
    if not retrieval_data:
        logging.warning(f"Failed to load retrieval data from '{source_retriever_file}' or the file is empty. Will proceed with all tasks without using any relevant code.")
        retrieval_data = {}
    return retrieval_data, source_retriever_file


def _generate_codes(task_id, prompt):
    """
    Query the model with `prompt` until NUM_REQUESTS_PER_TASK code blocks are
    extracted or MAX_ATTEMPTS_PER_TASK attempts have been used.

    Failed calls and responses without extractable code are retried after
    RETRY_DELAY_SECONDS.

    Returns:
        The list of extracted code strings. It is shorter than
        NUM_REQUESTS_PER_TASK if the attempt budget ran out.
    """
    generated_codes_list = []
    total_attempts = 0
    while len(generated_codes_list) < NUM_REQUESTS_PER_TASK:
        total_attempts += 1
        if total_attempts > MAX_ATTEMPTS_PER_TASK:
            # Report the configured limit (total_attempts is already limit + 1 here).
            logging.critical(f"ID: {task_id} - [ABORTING] The total number of attempts ({MAX_ATTEMPTS_PER_TASK}) for the task has been exceeded.")
            break
        current_progress = len(generated_codes_list) + 1
        logging.info(f"ID: {task_id} - Getting result {current_progress}/{NUM_REQUESTS_PER_TASK} (Total attempts: {total_attempts})...")
        model_response = call_model_api(prompt, MODEL_NAME)
        if not model_response:
            logging.error(f"ID: {task_id} - API call failed. Retrying in {RETRY_DELAY_SECONDS} seconds...")
            time.sleep(RETRY_DELAY_SECONDS)
            continue
        generated_code = extract_code_from_response(model_response)
        if generated_code:
            generated_codes_list.append(generated_code)
            logging.info(f"ID: {task_id} - Successfully obtained result {current_progress}/{NUM_REQUESTS_PER_TASK}.")
        else:
            logging.warning(f"ID: {task_id} - Could not find a code block in the response. Retrying in {RETRY_DELAY_SECONDS} seconds...")
            time.sleep(RETRY_DELAY_SECONDS)
    return generated_codes_list


def main():
    """
    Run RAG + safety-prompted code generation over every task in SOURCE_JSON_FILE.

    For each task whose id is not yet in OUTPUT_JSONL_FILE, the run does the
    following:
    - builds a prompt from the task instance and its retrieved snippets
      (top SNIPPET_COUNT_LIMIT for bm25/dense);
    - collects NUM_REQUESTS_PER_TASK generated code blocks;
    - appends a {"id", "model", "generated_codes"} record.

    Tasks that do not reach NUM_REQUESTS_PER_TASK results are logged as
    failures and not written.
    """
    os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
    setup_logging(LOG_FILE)
    # Compare against the known placeholders exactly. A substring test would
    # also reject a genuine key that happens to contain "KEY".
    if not MODEL_API_KEY or MODEL_API_KEY in ("KEY", "YOUR_MODEL_API_KEY"):
        logging.error("Please configure your MODEL_API_KEY at the top of the script.")
        return
    task_instance_map = load_source_tasks_map(SOURCE_JSON_FILE)
    if not task_instance_map:
        return
    loaded = _load_retrieval_data()
    if loaded is None:
        return
    retrieval_data, source_retriever_file = loaded
    tasks_to_process = list(task_instance_map.items())
    total_tasks = len(tasks_to_process)
    logging.info(f"Will process all {total_tasks} tasks from '{SOURCE_JSON_FILE}'.")
    logging.info(f"Using '{source_retriever_file}' (type: {RETRIEVER_TYPE}) to provide relevant code snippets.")
    if RETRIEVER_TYPE in ('bm25', 'dense'):
        logging.info(f"A maximum of {SNIPPET_COUNT_LIMIT} code snippets will be used per task.")
    logging.info(f"Will ensure {NUM_REQUESTS_PER_TASK} results are generated for each task.")
    logging.info(f"Results will be saved to: {OUTPUT_JSONL_FILE}")
    processed_ids = get_processed_ids(OUTPUT_JSONL_FILE)
    for i, (task_id, task_instance) in enumerate(tasks_to_process, 1):
        logging.info(f"--- Processing task {i}/{total_tasks} (ID: {task_id}) ---")
        if task_id in processed_ids:
            logging.info(f"ID: {task_id} - [SKIPPING] This task has already been processed.")
            continue
        snippets = retrieval_data.get(task_id, [])
        # Apply the top-k snippet limit (bm25/dense only)
        if RETRIEVER_TYPE in ('bm25', 'dense'):
            original_snippet_count = len(snippets)
            if original_snippet_count > 0:
                snippets = snippets[:SNIPPET_COUNT_LIMIT]
                logging.info(f"ID: {task_id} - Applying snippet limit: using {len(snippets)}/{original_snippet_count} snippets.")
        if not snippets:
            logging.info(f"ID: {task_id} - No relevant code snippets found in '{source_retriever_file}' or none left after limit, will proceed with generation.")
        prompt = build_rag_prompt(task_instance, snippets)
        generated_codes_list = _generate_codes(task_id, prompt)
        if len(generated_codes_list) == NUM_REQUESTS_PER_TASK:
            output_record = {
                "id": task_id,
                "model": MODEL_NAME,
                "generated_codes": generated_codes_list
            }
            append_to_jsonl(OUTPUT_JSONL_FILE, output_record)
            logging.info(f"ID: {task_id} - [SUCCESS] Task complete, {len(generated_codes_list)} codes have been saved.")
        else:
            logging.error(f"ID: {task_id} - [FAILURE] Task terminated, ultimately only obtained {len(generated_codes_list)}/{NUM_REQUESTS_PER_TASK} results.")
    logging.info(f"All RAG + Safety tasks (type: {RETRIEVER_TYPE}) have been processed!")


if __name__ == "__main__":
    main()