-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_excution.py
More file actions
454 lines (387 loc) · 22.3 KB
/
run_excution.py
File metadata and controls
454 lines (387 loc) · 22.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
import json
import os
import subprocess
import logging
import csv
import re
from io import StringIO
# --- CONFIGURATION ---
# Please modify your actual paths here
# --------------------------------------------------------------------------
REPOS_BASE_DIR = "./main_repo"  # Parent directory holding one local clone per repository ("owner_repo")
TASK_FILE = "task_instance.json"  # Task definitions: repository info plus per-task analysis results
RESULTS_FILE = "./result/rag_safety_result/bm25/qwen3-235b-a22b_rag_bm25_top5_k1_results.jsonl"  # Model-generated candidate patches (JSONL, one object per line)
K_VALUE = 1  # Number of candidate patches evaluated per task (the "k" in Pass@k)
ENABLE_SECURITY_CHECK = True  # When True, re-run CodeQL after patching to verify the vulnerability is gone
CODEQL_SUITE_PATH = "./codeql/java/ql/src/codeql-suites/java-security-extended.qls"  # CodeQL query suite used for every scan
CODEQL_DB_NAME = "codeql-db"  # Name of the prebuilt CodeQL database directory inside each repository
NEW_CSV_REPORT_NAME = "new_java-security-extended-results.csv"  # Temporary CSV report written (and deleted) per scan
# --- NEW CONFIGURATION: Please add the paths to your Java versions ---
# Example:
# JAVA_VERSION_PATHS = {
# "8": "/path/to/jdk8",
# "11": "/path/to/jdk11",
# "17": "/path/to/jdk17"
# }
# Maps a task's "java_version" string to the JAVA_HOME used for its test run.
JAVA_VERSION_PATHS = {
    "8.0.452-amzn": "/usr/lib/jvm/8.0.452-amzn",
    "11.0.27-amzn": "/usr/lib/jvm/11.0.27-amzn",
    "17.0.15-amzn": "/usr/lib/jvm/17.0.15-amzn"
}
def parse_codeql_csv(file_path):
    """Parse a CodeQL CSV report into a list of vulnerability dicts.

    Short rows are right-padded with None so every dict carries the full
    key set. Returns an empty list when the file is absent or unreadable.
    """
    if not os.path.exists(file_path):
        return []
    columns = ["name", "description", "severity", "message", "path",
               "start_line", "start_column", "end_line", "end_column"]
    findings = []
    try:
        with open(file_path, 'r', encoding='utf-8') as fh:
            for record in csv.reader(fh):
                padding = [None] * (len(columns) - len(record))
                findings.append(dict(zip(columns, record + padding)))
    except Exception as e:
        logging.error(f"Error parsing CodeQL CSV file {file_path}: {e}")
    return findings
def extract_code_snippet(file_path, start_line, end_line):
    """
    Extract a snippet of code from *file_path* between 1-based start/end lines.

    Returns the stripped snippet text on success, or a human-readable
    "Error: ..." string on any failure (missing file, bad line numbers,
    out-of-range lines).
    """
    try:
        start_line_int = int(start_line)
        end_line_int = int(end_line)
        if start_line_int <= 0 or end_line_int < start_line_int:
            return "Error: Invalid line numbers provided in CodeQL report."
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # BUG FIX: list slicing never raises IndexError, so the original
        # `except IndexError` handler was dead code and an out-of-range
        # start line silently produced an empty snippet. Check bounds
        # explicitly so the documented error message is actually returned.
        if start_line_int > len(lines):
            return "Error: Line numbers from CodeQL report are out of range for the file."
        snippet_lines = lines[start_line_int - 1 : end_line_int]
        return "".join(snippet_lines).strip()
    except FileNotFoundError:
        return f"Error: Could not find source file to extract snippet: {file_path}"
    except (ValueError, TypeError):
        return "Error: Could not parse line numbers from CodeQL report."
    except Exception as e:
        return f"An unexpected error occurred while reading snippet: {str(e)}"
def find_start_line(full_file_content, substring):
    """Return the 1-based line number where *substring* first appears, or None."""
    try:
        position = full_file_content.find(substring)
        if position == -1:
            return None
        # Newlines before the match position give the 0-based line index.
        return full_file_content[:position].count('\n') + 1
    except Exception:
        return None
def setup_logging(log_file_path):
    """Configure the root logger to send INFO+ records to *log_file_path* and the console."""
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    # Drop any handlers left over from a previous run.
    if root.hasHandlers():
        root.handlers.clear()
    fmt = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    # File handler first, console handler second (same order as before).
    for handler in (logging.FileHandler(log_file_path, mode='w', encoding='utf-8'),
                    logging.StreamHandler()):
        handler.setFormatter(fmt)
        root.addHandler(handler)
def save_state(data, file_path):
    """Serialize *data* as indented JSON to *file_path*; log (never raise) on failure."""
    try:
        with open(file_path, 'w', encoding='utf-8') as out:
            json.dump(data, out, indent=4)
    except Exception as e:
        logging.error(f"Error saving state to {file_path}: {e}")
def update_summary(summary_data, test_result, security_result):
    """Increment the per-status counters in *summary_data* for one attempt's outcomes."""
    pairs = ((test_result, 'test_result_stats'),
             (security_result, 'security_result_stats'))
    for outcome, bucket_name in pairs:
        # Null results (e.g. skipped security checks) leave the stats untouched.
        if outcome and 'status' in outcome:
            bucket = summary_data[bucket_name]
            status = outcome['status']
            bucket[status] = bucket.get(status, 0) + 1
def calculate_pass_at_k(results_data):
    """Return (passed, total, rate%) for Pass@k.

    A task passes when any of its attempts has test_result status 'pass'.
    """
    total_tasks = len(results_data)
    if total_tasks == 0:
        return 0, 0, 0.0
    passed_tasks = sum(
        1
        for attempts in results_data.values()
        if any(attempt.get("test_result", {}).get("status") == "pass"
               for attempt in attempts)
    )
    return passed_tasks, total_tasks, (passed_tasks / total_tasks) * 100
def calculate_security_fix_at_k(results_data):
    """
    Compute the SecurityFix@k metric.

    A task counts as fixed when at least one attempt carries a non-null
    security_result whose status is 'fixed'; null security_result values
    are tolerated. Returns (fixed, total, percentage).
    """
    total_tasks = len(results_data)
    if total_tasks == 0:
        return 0, 0, 0.0
    fixed_tasks = 0
    for attempts in results_data.values():
        for attempt in attempts:
            sec = attempt.get("security_result")
            if sec and sec.get("status") == "fixed":
                fixed_tasks += 1
                break
    return fixed_tasks, total_tasks, (fixed_tasks / total_tasks) * 100
def calculate_combined_pass_at_k(results_data):
    """
    Compute the CombinedPass@k metric.

    A task is perfectly solved when one single attempt both passes the unit
    tests and has security_result status 'fixed'; null security_result
    values are tolerated. Returns (solved, total, percentage).
    """
    total_tasks = len(results_data)
    if total_tasks == 0:
        return 0, 0, 0.0
    perfectly_solved_tasks = 0
    for attempts in results_data.values():
        for attempt in attempts:
            tests_pass = attempt.get("test_result", {}).get("status") == "pass"
            sec = attempt.get("security_result")
            if tests_pass and sec and sec.get("status") == "fixed":
                perfectly_solved_tasks += 1
                break
    return perfectly_solved_tasks, total_tasks, (perfectly_solved_tasks / total_tasks) * 100
def load_and_prepare_data(task_file, results_file):
    """Load task definitions and model outputs.

    Returns (tasks_by_id, generated_codes_by_id), or (None, None) when
    either input file is missing. Invalid JSONL lines are skipped with a
    warning.
    """
    logging.info(f"Starting to load task definitions from {task_file}...")
    try:
        with open(task_file, 'r', encoding='utf-8') as f:
            tasks_data = json.load(f)
    except FileNotFoundError:
        logging.error(f"Task file not found: {task_file}")
        return None, None
    tasks = {}
    for repo_task in tasks_data:
        repo_info = repo_task.get("repository_info")
        if not repo_info:
            continue
        for task_detail in repo_task.get("analysis_results") or []:
            task_id = task_detail.get("id")
            if task_id:
                tasks[task_id] = {"repo_info": repo_info, "task_detail": task_detail}
    logging.info(f"Successfully loaded {len(tasks)} tasks.")
    logging.info(f"Starting to load model generated results from {results_file}...")
    generated_codes = {}
    try:
        with open(results_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    result = json.loads(line)
                except json.JSONDecodeError:
                    logging.warning(f"Skipping invalid JSON line in results file: {line.strip()}")
                    continue
                task_id, codes = result.get("id"), result.get("generated_codes")
                if task_id and codes:
                    generated_codes.setdefault(task_id, []).extend(codes)
    except FileNotFoundError:
        logging.error(f"Results file not found: {results_file}")
        return None, None
    logging.info(f"Successfully loaded generated results for {len(generated_codes)} IDs.")
    return tasks, generated_codes
def patch_code(repo_path, task_detail, new_code):
    """
    Replace the vulnerable function body in the task's source file with *new_code*.

    Args:
        repo_path: Root of the local repository clone.
        task_detail: Task dict; its primary_analysis must provide
            source_location.file and full_function_body.
        new_code: The candidate replacement code.

    Returns:
        (full_file_path, original_content, old_code_body) on success, or
        (None, None, None) on any failure so the caller can skip the attempt.
    """
    primary_analysis = task_detail.get("primary_analysis", {})
    source_location = primary_analysis.get("source_location", {})
    if not source_location or not source_location.get("file"):
        logging.error(f"Missing 'source_location' in information for task {task_detail.get('id')}.")
        return None, None, None
    relative_file_path = source_location["file"]
    full_file_path = os.path.join(repo_path, relative_file_path)
    old_code_body = primary_analysis.get("full_function_body")
    if not old_code_body:
        logging.error(f"Missing 'full_function_body' in information for task {task_detail.get('id')}.")
        return None, None, None
    try:
        with open(full_file_path, 'r', encoding='utf-8') as f:
            original_content = f.read()
        if old_code_body.strip() not in original_content:
            logging.warning(f"Could not find the old function body to replace in {full_file_path}.")
            return None, None, None
        # BUG FIX: limit the replacement to the first occurrence. The
        # original unbounded replace() rewrote EVERY occurrence of the
        # snippet, clobbering unrelated duplicates elsewhere in the file.
        patched_content = original_content.replace(old_code_body.strip(), new_code.strip(), 1)
        with open(full_file_path, 'w', encoding='utf-8') as f:
            f.write(patched_content)
        logging.info(f"Successfully applied patch to: {full_file_path}")
        return full_file_path, original_content, old_code_body
    except Exception as e:
        logging.error(f"An unexpected error occurred while applying the patch: {e}")
        return None, None, None
def restore_code(file_path, original_content):
    """Write *original_content* back to *file_path*; a no-op when either argument is missing."""
    if not file_path or original_content is None:
        return
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(original_content)
        logging.info(f"Successfully restored original file: {file_path}")
    except Exception as e:
        logging.error(f"Fatal error: Failed to restore file {file_path}: {e}")
def extract_maven_failure_reason(stdout, stderr):
    """Distill a short failure reason from Maven output.

    Prefers the compiler diagnostic when a compilation failure is present;
    otherwise returns the surefire test-failure summary, falling back to a
    generic 'Compilation failure' marker.
    """
    combined = stdout + "\n" + stderr
    if "compilation failure" in combined.lower():
        detail = re.search(r'Compilation failure:(.*?)\s*\[ERROR\] -> \[Help 1\]',
                           combined, re.DOTALL | re.IGNORECASE)
        if detail:
            reason = detail.group(1).strip()
        else:
            reason = "Compilation failure detected in logs, but detailed information could not be extracted."
        return f"Compilation failure: {reason}"
    summary = re.search(r'\[ERROR\] Tests run: .*?, Failures: \d+, Errors: \d+', stdout)
    if summary:
        begin = summary.start()
        # Include everything up to the BUILD FAILURE banner for context.
        end = stdout.find('[INFO] BUILD FAILURE', begin)
        if end != -1:
            return stdout[begin:end].strip()
        return summary.group(0)
    return "Compilation failure"
def run_tests(repo_path, task_detail, java_version=None):
    """
    Run the task's unit-test command inside *repo_path*.

    When *java_version* matches an entry in JAVA_VERSION_PATHS, JAVA_HOME
    and PATH point at that JDK for the subprocess only. Returns a dict
    whose 'status' is one of: pass, fail, timeout, error, execution_error.
    """
    command = task_detail.get("primary_analysis", {}).get("corresponding_tests", {}).get("command")
    if not command:
        return {"status": "error", "details": "No command found."}
    # Work on a copy so the main script's environment stays untouched.
    env = os.environ.copy()
    if java_version and java_version in JAVA_VERSION_PATHS:
        java_home_path = JAVA_VERSION_PATHS[java_version]
        env['JAVA_HOME'] = java_home_path
        # Prepend the JDK's bin so java/javac resolve to the chosen version.
        env['PATH'] = f"{os.path.join(java_home_path, 'bin')}:{env.get('PATH', '')}"
        logging.info(f"Setting JAVA_HOME='{java_home_path}' for this test run")
    else:
        logging.warning(f"Java version '{java_version}' specified for the task not found or no version provided. Using the system's default version.")
    logging.info(f"Executing unit tests in directory '{repo_path}'...")
    try:
        proc = subprocess.run(
            command,
            shell=True,
            cwd=repo_path,
            capture_output=True,
            text=True,
            timeout=600,
            env=env
        )
    except subprocess.TimeoutExpired as timeout_exc:
        logging.warning("Unit test timed out!")
        return {"status": "timeout", "details": "Test execution exceeded 600 seconds.",
                "stdout": timeout_exc.stdout, "stderr": timeout_exc.stderr}
    except Exception as e:
        logging.error(f"A critical error occurred while executing unit tests: {e}")
        return {"status": "execution_error", "details": str(e)}
    if proc.returncode == 0:
        logging.info("Unit tests passed successfully!")
        return {"status": "pass"}
    logging.error("Unit tests failed. Detailed reason has been recorded in the JSON output.")
    failure_reason = extract_maven_failure_reason(proc.stdout, proc.stderr)
    return {"status": "fail", "returncode": proc.returncode,
            "stdout": proc.stdout, "stderr": proc.stderr,
            "failure_reason": failure_reason}
def run_security_check(repo_path, original_csv_row, patch_start_line, patch_end_line):
    """
    Runs a CodeQL scan, verifies the fix, and includes the vulnerable code snippet.
    This version correctly handles absolute-style paths from CodeQL reports.

    Args:
        repo_path: Local checkout containing the prebuilt CodeQL database.
        original_csv_row: CSV row (list) describing the original vulnerability.
        patch_start_line: First line (1-based) of the patched region.
        patch_end_line: Last line (1-based) of the patched region.

    Returns:
        A dict whose 'status' is one of 'fixed', 'security_fail',
        'scan_error', or 'scan_timeout'; failure statuses carry 'details'.
    """
    # Locations of the prebuilt database and the CSV the scan will emit.
    db_path = os.path.join(repo_path, CODEQL_DB_NAME)
    output_csv_path = os.path.join(repo_path, NEW_CSV_REPORT_NAME)
    if not os.path.isdir(db_path):
        logging.error(f"CodeQL database not found in {repo_path}."); return {"status": "scan_error", "details": "CodeQL database not found."}
    command = ["codeql", "database", "analyze", db_path, CODEQL_SUITE_PATH, "--format=csv", f"--output={output_csv_path}"]
    logging.info("Starting CodeQL security scan...")
    try:
        # NOTE(review): 30-minute ceiling per scan; no cwd is set, so the
        # CodeQL CLI must be resolvable on the current PATH.
        result = subprocess.run(command, capture_output=True, text=True, timeout=1800)
        if result.returncode != 0:
            logging.error(f"CodeQL scan failed. Return code: {result.returncode}.")
            return {"status": "scan_error", "details": result.stderr}
        logging.info(f"CodeQL scan completed. Verifying fix within the precise range [{patch_start_line}-{patch_end_line}]...")
        new_vulnerabilities = parse_codeql_csv(output_csv_path)
        headers = ["name", "description", "severity", "message", "path", "start_line", "start_column", "end_line", "end_column"]
        original_vuln_details = dict(zip(headers, original_csv_row))
        found_vulnerabilities_in_scope = []
        for vuln in new_vulnerabilities:
            # A finding counts as "the original vulnerability" only when the
            # rule name, the file path, AND the start line all fall within
            # the patched region.
            is_same_type = vuln.get('name') == original_vuln_details.get('name')
            is_same_file = vuln.get('path') == original_vuln_details.get('path')
            try:
                vuln_start_line = int(vuln.get('start_line', 0))
                is_in_scope = patch_start_line <= vuln_start_line <= patch_end_line
            except (ValueError, TypeError):
                # Unparseable line numbers are treated as outside the patch.
                is_in_scope = False
            if is_same_type and is_same_file and is_in_scope:
                # CodeQL reports repo-rooted, absolute-style paths ("/src/...");
                # strip the leading '/' before joining onto repo_path.
                relative_vuln_path = vuln['path'].lstrip('/')
                full_vuln_path = os.path.join(repo_path, relative_vuln_path)
                snippet = extract_code_snippet(full_vuln_path, vuln.get('start_line'), vuln.get('end_line'))
                vuln['vulnerable_code_snippet'] = snippet
                found_vulnerabilities_in_scope.append(vuln)
        if found_vulnerabilities_in_scope:
            logging.warning(f"Vulnerability not fixed. Found {len(found_vulnerabilities_in_scope)} vulnerabilities of the same type within the patch scope.")
            return {"status": "security_fail", "details": found_vulnerabilities_in_scope}
        else:
            logging.info("Original vulnerability not found within the precise scope of the patch, determined as fixed!")
            return {"status": "fixed"}
    except subprocess.TimeoutExpired:
        logging.warning("CodeQL scan timed out!"); return {"status": "scan_timeout"}
    except Exception as e:
        logging.error(f"An unexpected error occurred while executing the security scan: {e}"); return {"status": "scan_error", "details": str(e)}
    finally:
        # Always delete the temporary CSV so the next attempt starts clean.
        if os.path.exists(output_csv_path): os.remove(output_csv_path); logging.info(f"Cleaned up temporary report: {output_csv_path}")
def main():
    """
    Main execution function.

    Drives the full evaluation: derives output/log paths from RESULTS_FILE,
    loads tasks and generated patches, then for each matching task applies up
    to K_VALUE candidate patches, runs the unit tests (and optionally a
    CodeQL re-scan), restores the repository, saves incremental results, and
    finally logs the Pass@k report.
    """
    # Derive per-run output locations next to the results file so different
    # result files never overwrite each other's evaluation artifacts.
    base_dir = os.path.dirname(RESULTS_FILE)
    results_filename_with_ext = os.path.basename(RESULTS_FILE)
    results_filename_base, _ = os.path.splitext(results_filename_with_ext)
    output_dir = os.path.join(base_dir, "evaluation_result")
    os.makedirs(output_dir, exist_ok=True)
    dynamic_output_file = os.path.join(output_dir, f"{results_filename_base}.json")
    dynamic_log_file = os.path.join(output_dir, f"{results_filename_base}.log")
    setup_logging(dynamic_log_file)
    logging.info(f"Dynamically generated output file path: {dynamic_output_file}")
    logging.info(f"Dynamically generated log file path: {dynamic_log_file}")
    all_tasks, generated_codes = load_and_prepare_data(TASK_FILE, RESULTS_FILE)
    if all_tasks is None or generated_codes is None:
        logging.error("Data loading failed, terminating script."); return
    evaluation_data = {"summary": {"test_result_stats": {}, "security_result_stats": {}}, "results": {}}
    # Only evaluate tasks for which the model actually produced candidates.
    tasks_to_run = {tid: tinfo for tid, tinfo in all_tasks.items() if tid in generated_codes}
    total_tasks = len(tasks_to_run)
    logging.info(f"Found a total of {total_tasks} matching tasks to evaluate (k={K_VALUE}).")
    for i, (task_id, combined_info) in enumerate(tasks_to_run.items(), start=1):
        repo_info, task_detail = combined_info["repo_info"], combined_info["task_detail"]
        repo_full_name = repo_info.get("full_name")
        if not repo_full_name: logging.warning(f"Task {task_id} is missing 'full_name', skipping."); continue
        # Local clones are stored as "<owner>_<repo>" under REPOS_BASE_DIR.
        local_repo_name = repo_full_name.replace('/', '_')
        repo_path = os.path.join(REPOS_BASE_DIR, local_repo_name)
        logging.info(f"--- [Starting task {i}/{total_tasks}: {task_id}] | [Repository: {repo_path}] ---")
        if not os.path.isdir(repo_path):
            logging.error(f"Local repository directory not found: {repo_path}, skipping this task.")
            evaluation_data["results"][task_id] = [{"attempt": 0, "status": "config_error", "details": f"Repository path not found: {repo_path}"}]
            continue
        candidate_codes = generated_codes.get(task_id, [])
        # Evaluate at most K_VALUE candidates per task (the pass@k budget).
        num_attempts = min(K_VALUE, len(candidate_codes))
        task_results, original_csv_row = [], task_detail.get("csv_row_data")
        evaluation_data["results"][task_id] = task_results
        for j in range(num_attempts):
            logging.info(f"--- [Task {i}/{total_tasks}: {task_id}] >> [Attempt {j + 1}/{num_attempts}] ---")
            generated_code = candidate_codes[j].strip()
            file_to_patch, original_content, old_code_body = None, None, None
            try:
                file_to_patch, original_content, old_code_body = patch_code(repo_path, task_detail, generated_code)
                if file_to_patch and original_content is not None:
                    # Use the task's configured JDK (if any) for the test run.
                    java_version = repo_info.get("java_version")
                    test_result = run_tests(repo_path, task_detail, java_version)
                    security_result = None
                    # A patch that doesn't compile can't be scanned meaningfully.
                    is_compilation_failure = test_result.get("failure_reason", "").startswith("Compilation failure")
                    if ENABLE_SECURITY_CHECK and not is_compilation_failure:
                        if original_csv_row and old_code_body:
                            # Restrict the CodeQL verification to the exact line
                            # range now occupied by the generated replacement.
                            patch_start_line = find_start_line(original_content, old_code_body.strip())
                            if patch_start_line is not None:
                                new_code_line_count = len(generated_code.splitlines())
                                patch_end_line = patch_start_line + new_code_line_count - 1
                                security_result = run_security_check(repo_path, original_csv_row, patch_start_line, patch_end_line)
                            else:
                                logging.error("Could not locate the old function body in the original file to calculate the patch range.")
                                security_result = {"status": "skipped", "details": "Could not locate old function body to determine patch boundaries."}
                        else:
                            security_result = {"status": "skipped", "details": "csv_row_data or old_code_body not found."}
                    current_attempt_result = {"attempt": j + 1, "test_result": test_result, "security_result": security_result}
                    task_results.append(current_attempt_result)
                    update_summary(evaluation_data["summary"], test_result, security_result)
                else:
                    patch_fail_result = {"attempt": j + 1, "test_result": {"status": "patch_failed"}, "security_result": None}
                    task_results.append(patch_fail_result)
                    update_summary(evaluation_data["summary"], patch_fail_result["test_result"], None)
            finally:
                # Always undo the patch so the next attempt starts from a clean tree.
                if file_to_patch and original_content is not None: restore_code(file_to_patch, original_content)
        logging.info(f"--- Task {i}/{total_tasks}: {task_id} all attempts processed ---")
        logging.info(f"Saving current progress to {dynamic_output_file}...")
        # Persist after every task so progress survives an interruption.
        save_state(evaluation_data, dynamic_output_file)
    logging.info("All tasks processed. Calculating final metrics...")
    logging.info("\n" + "="*70 + "\n--- Final Benchmark Metrics Report ---\n" + "="*70)
    results_for_analysis = evaluation_data.get("results", {})
    passed_count, total_count, pass_rate = calculate_pass_at_k(results_for_analysis)
    logging.info(f"Functional Test Pass@{K_VALUE}:")
    logging.info(f" - Number of passed tasks: {passed_count} / {total_count}")
    logging.info(f" - Pass Rate: {pass_rate:.2f}%")
    logging.info("-" * 30)
    logging.info("\n" + "="*70)
    logging.info(f"Evaluation complete. Final detailed results saved to: {dynamic_output_file}")
# Script entry point: run the full evaluation only when executed directly.
if __name__ == "__main__":
    main()