-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugment_api_data.py
More file actions
435 lines (360 loc) · 16.9 KB
/
augment_api_data.py
File metadata and controls
435 lines (360 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
API 代码补全数据增强脚本
目标:扩大需要补全的代码范围,使平均补全长度达到 14 个 token
"""
import json
import os
import re
import random
import argparse
from pathlib import Path
from typing import Dict, List, Tuple
try:
import tiktoken
except ModuleNotFoundError:
tiktoken = None
class APIDataAugmentor:
def __init__(self, target_avg_tokens: int = 14, seed: int | None = None):
"""
初始化数据增强器
:param target_avg_tokens: 目标平均 token 数
:param seed: 随机种子(用于可复现的左右扩展策略)
"""
self.target_avg_tokens = target_avg_tokens
self.rng = random.Random(seed) if seed is not None else random
# 使用 cl100k_base 编码器(类似 GPT-3.5/4 的编码方式)
try:
self.encoding = tiktoken.get_encoding("cl100k_base") if tiktoken else None
except:
self.encoding = None
def count_tokens(self, text: str) -> int:
"""
计算文本的 token 数量
"""
if self.encoding:
return len(self.encoding.encode(text))
else:
# 简单估算:按空格和标点符号分割
return len(re.findall(r'\w+|[^\w\s]', text))
def find_balanced_end(self, text: str, start_pos: int = 0) -> int:
"""
从指定位置开始,找到括号、花括号平衡的位置
"""
stack = []
in_string = False
escape_next = False
for i in range(start_pos, len(text)):
char = text[i]
# 处理转义字符
if escape_next:
escape_next = False
continue
if char == '\\':
escape_next = True
continue
# 处理字符串
if char in ('"', "'", '`'):
if not in_string:
in_string = char
elif in_string == char:
in_string = False
continue
if in_string:
continue
# 处理括号
if char in '({[':
stack.append(char)
elif char in ')}]':
if stack:
opening = stack[-1]
if (char == ')' and opening == '(') or \
(char == '}' and opening == '{') or \
(char == ']' and opening == '['):
stack.pop()
# 如果栈空了,找到了平衡点
if not stack:
# 继续找到语句结束(;)或者下一行
for j in range(i + 1, min(i + 50, len(text))):
if text[j] in (';\n', '\n'):
return j + 1
return i + 1
return -1
def _safe_join(self, a: str, b: str) -> str:
if not a:
return b
if not b:
return a
if a[-1].isspace() or b[0].isspace():
return a + b
return a + " " + b
def _find_left_suffix_start_by_tokens(self, left_context: str, tokens_needed: int) -> int:
"""
返回一个字符下标 start,使得 left_context[start:] 至少包含 tokens_needed 个 token。
用于把 left_context 的一部分“挪入” target_code。
"""
if tokens_needed <= 0:
return len(left_context)
start = len(left_context)
while start > 0:
start -= 1
if self.count_tokens(left_context[start:]) >= tokens_needed:
return start
return 0
def _take_right_prefix_balanced(
self,
base_code: str,
right_stripped: str,
min_added_tokens: int,
max_extra_chars: int = 200,
) -> Tuple[str, int]:
"""
从 right_stripped 取一段前缀追加到 base_code 后:
- 至少增加 min_added_tokens 个 token(尽量满足)
- 尽量在语句/括号边界停止(利用 end_markers / 平衡点)
返回 (appended_text, consumed_chars)
"""
if not right_stripped or min_added_tokens <= 0:
return "", 0
base_tokens = self.count_tokens(base_code)
end_markers = [';', '}\n', ')\n', '},\n', ');\n', '});\n']
# 逐字符添加直到达到 token 目标或超限
limit = min(len(right_stripped), max_extra_chars)
for i in range(limit):
candidate = right_stripped[: i + 1]
test_code = self._safe_join(base_code, candidate)
if self.count_tokens(test_code) - base_tokens >= min_added_tokens:
# 找到最近结束标记,避免扩太远
remaining = right_stripped[i:]
min_end = None
for marker in end_markers:
pos = remaining.find(marker)
if pos != -1:
end_pos = i + pos + len(marker)
if min_end is None or end_pos < min_end:
min_end = end_pos
if min_end is not None and min_end - (i + 1) < 100:
return right_stripped[:min_end], min_end
return candidate, i + 1
# 达不到 token 目标:尽量找一个较早的结束标记
remaining = right_stripped[:limit]
for marker in end_markers:
pos = remaining.find(marker)
if pos != -1:
end_pos = pos + len(marker)
return remaining[:end_pos], end_pos
return right_stripped[:limit], limit
def extract_more_context(
self,
left_context: str,
right_context: str,
original_api: str,
current_tokens: int
) -> Tuple[str, str, str]:
"""
从 left_context 和 right_context 中提取更多上下文,扩展 target_code。
返回扩展后的 (new_left_context, new_target_code, new_right_context)。
"""
# 计算需要增加的 token 数
tokens_needed = self.target_avg_tokens - current_tokens
if tokens_needed <= 2: # 至少需要增加2个token才值得增强
return None, None, None
# 策略:
# 1) 先处理括号未闭合:必须向右扩到平衡点(否则可能破坏语法)
# 2) 剩余需要补的 token:随机选择向左/向右/左右同时扩展(参考 generate_api_dataset.py 的策略)
original_api = original_api or ""
left_context = left_context or ""
right_context = right_context or ""
extended_code = original_api.rstrip()
new_left_context = left_context
new_right_context = right_context
# right_context: 保留原始前导空白,但消费从第一个非空白开始的部分
right_start = len(right_context) - len(right_context.lstrip())
right_stripped = right_context[right_start:]
# 检查 original_api 是否有未闭合的括号
open_parens = original_api.count('(') - original_api.count(')')
open_braces = original_api.count('{') - original_api.count('}')
open_brackets = original_api.count('[') - original_api.count(']')
# 如果有未闭合的括号/花括号,一定要找到闭合点
if open_parens > 0 or open_braces > 0 or open_brackets > 0:
# 从 right_context 找到平衡点
balance_end = self.find_balanced_end(
original_api + right_stripped,
len(original_api)
)
if balance_end > len(original_api):
appended = (original_api + right_stripped)[len(original_api):balance_end]
extended_code = self._safe_join(original_api, appended)
consumed_len = max(0, balance_end - len(original_api))
new_right_context = (right_context[:right_start] + right_stripped[consumed_len:]).lstrip()
new_tokens = self.count_tokens(extended_code)
# 如果扩展后的 token 数合理,先继续尝试(后续可能还需要向左补)
if new_tokens > self.target_avg_tokens * 1.8:
return None, None, None
else:
# 需要向右闭合但没有足够上下文,无法安全增强
return None, None, None
# 剩余 token:随机选择扩展方向
tokens_now = self.count_tokens(extended_code)
remaining_needed = max(0, self.target_avg_tokens - tokens_now)
if remaining_needed <= 2:
final_code = extended_code.strip()
if final_code and final_code != original_api.strip():
return new_left_context, final_code, new_right_context
return None, None, None
strategy = self.rng.randint(0, 2) # 0=左 1=右 2=左右
if strategy == 0:
tokens_before, tokens_after = remaining_needed, 0
elif strategy == 1:
tokens_before, tokens_after = 0, remaining_needed
else:
tokens_before = self.rng.randint(0, remaining_needed)
tokens_after = remaining_needed - tokens_before
# 向右扩展:按 tokens_after 取一段,并更新 right_context
if tokens_after > 0 and right_stripped:
appended, consumed = self._take_right_prefix_balanced(extended_code, right_stripped, tokens_after)
if appended.strip():
extended_code = self._safe_join(extended_code, appended)
new_right_context = (right_context[:right_start] + right_stripped[consumed:]).lstrip()
# 向左扩展:按 tokens_before 从 left_context 的末尾挪一段到 target 前,并更新 left_context
if tokens_before > 0 and left_context:
start_idx = self._find_left_suffix_start_by_tokens(left_context, tokens_before)
left_added = left_context[start_idx:]
if left_added.strip():
new_left_context = left_context[:start_idx].rstrip()
extended_code = self._safe_join(left_added, extended_code)
# 最终检查:确实变长且不过分膨胀
final_code = extended_code.strip()
if not final_code or final_code == original_api.strip():
return None, None, None
final_tokens = self.count_tokens(final_code)
if final_tokens <= current_tokens or final_tokens > self.target_avg_tokens * 1.8:
return None, None, None
return new_left_context, final_code, new_right_context
def augment_sample(self, sample_data: Dict) -> Dict:
"""
增强单个样本
"""
api_info = sample_data.get('api_info', {})
original_api = api_info.get('original_api', '')
current_tokens = api_info.get('token_count', 0)
# 如果当前 token 数已经接近目标,不需要增强
if current_tokens >= self.target_avg_tokens * 0.85:
return sample_data
left_context = sample_data.get('left_context', '')
right_context = sample_data.get('right_context', '')
# 尝试扩展
new_left, new_target, new_right = self.extract_more_context(
left_context, right_context, original_api, current_tokens
)
if new_target and new_target != original_api:
# 更新样本数据
augmented_sample = sample_data.copy()
# 计算新的 token 数
new_token_count = self.count_tokens(new_target)
# 只有当新的 token 数确实增加了才更新
if new_token_count > current_tokens:
# 更新 target_code
augmented_sample['target_code'] = new_target
if new_left is not None:
augmented_sample['left_context'] = new_left
if new_right is not None:
augmented_sample['right_context'] = new_right
# 更新 api_info
new_api_info = api_info.copy()
new_api_info['original_api'] = new_target
new_api_info['token_count'] = new_token_count
new_api_info['augmented'] = True # 标记为增强过的样本
augmented_sample['api_info'] = new_api_info
return augmented_sample
return sample_data
def process_file(self, input_file: str, output_file: str) -> Dict:
"""
处理单个 JSON 文件
"""
print(f"处理文件: {input_file}")
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
augmented_data = {}
total_samples = len(data)
augmented_count = 0
total_tokens_before = 0
total_tokens_after = 0
for key, sample in data.items():
original_tokens = sample.get('api_info', {}).get('token_count', 0)
total_tokens_before += original_tokens
augmented_sample = self.augment_sample(sample)
augmented_data[key] = augmented_sample
new_tokens = augmented_sample.get('api_info', {}).get('token_count', 0)
total_tokens_after += new_tokens
if new_tokens != original_tokens:
augmented_count += 1
# 保存增强后的数据
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(augmented_data, f, ensure_ascii=False, indent=2)
avg_before = total_tokens_before / total_samples if total_samples > 0 else 0
avg_after = total_tokens_after / total_samples if total_samples > 0 else 0
stats = {
'file': os.path.basename(input_file),
'total_samples': total_samples,
'augmented_samples': augmented_count,
'avg_tokens_before': round(avg_before, 2),
'avg_tokens_after': round(avg_after, 2)
}
print(f" - 总样本数: {total_samples}")
print(f" - 增强样本数: {augmented_count}")
print(f" - 平均 token 数 (增强前): {avg_before:.2f}")
print(f" - 平均 token 数 (增强后): {avg_after:.2f}")
return stats
def main():
    """Command-line entry point.

    Augments every ``*_api_code_data.json`` file in ``--input-dir`` with
    ``APIDataAugmentor``. Each result is written under the same name into
    ``--output-dir``, and per-file plus overall statistics are printed.
    """
    parser = argparse.ArgumentParser(description="API 代码补全数据增强脚本(支持左/右/左右随机扩展)")
    parser.add_argument("--input-dir", type=str, default="api_code_completion", help="输入目录(默认: api_code_completion)")
    parser.add_argument("--output-dir", type=str, default="api_code_completion_augmented", help="输出目录(默认: api_code_completion_augmented)")
    parser.add_argument("--target-avg-tokens", type=int, default=20, help="目标平均 token 数(默认: 20)")
    parser.add_argument("--seed", type=int, default=None, help="随机种子(默认: None)")
    args = parser.parse_args()

    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    if not input_dir.is_dir():
        # A mistyped path would otherwise silently produce an empty run.
        parser.error(f"输入目录不存在: {input_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)

    augmentor = APIDataAugmentor(target_avg_tokens=args.target_avg_tokens, seed=args.seed)

    # Sorted: the augmentor's RNG is shared across files, so a fixed file
    # order is required for --seed to give reproducible output.
    json_files = sorted(input_dir.glob("*_api_code_data.json"))
    print(f"找到 {len(json_files)} 个文件\n")

    all_stats = []
    for json_file in json_files:
        output_file = output_dir / json_file.name
        stats = augmentor.process_file(str(json_file), str(output_file))
        all_stats.append(stats)
        print()

    # Overall summary table.
    print("=" * 80)
    print("数据增强汇总统计:")
    print("=" * 80)
    print(f"{'文件名':<50} {'样本数':>8} {'增强数':>8} {'增强前':>10} {'增强后':>10}")
    print("-" * 80)
    total_samples = 0
    total_augmented = 0
    total_tokens_before = 0
    total_tokens_after = 0
    for stat in all_stats:
        print(f"{stat['file']:<50} {stat['total_samples']:>8} {stat['augmented_samples']:>8} "
              f"{stat['avg_tokens_before']:>10.2f} {stat['avg_tokens_after']:>10.2f}")
        total_samples += stat['total_samples']
        total_augmented += stat['augmented_samples']
        # Per-file averages are rounded, so these totals are approximate.
        total_tokens_before += stat['total_samples'] * stat['avg_tokens_before']
        total_tokens_after += stat['total_samples'] * stat['avg_tokens_after']
    print("-" * 80)
    avg_before = total_tokens_before / total_samples if total_samples > 0 else 0
    avg_after = total_tokens_after / total_samples if total_samples > 0 else 0
    print(f"{'总计':<50} {total_samples:>8} {total_augmented:>8} "
          f"{avg_before:>10.2f} {avg_after:>10.2f}")
    print("=" * 80)
    print(f"\n增强后的数据已保存到: {output_dir}")


if __name__ == "__main__":
    main()