-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearch_code.py
More file actions
144 lines (130 loc) · 7 KB
/
search_code.py
File metadata and controls
144 lines (130 loc) · 7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from concurrent.futures import as_completed, ProcessPoolExecutor
import copy
import functools
import json
import os
import re

import numpy as np
import scipy
import scipy.spatial  # bare "import scipy" does not load scipy.spatial.distance
import tqdm

from utils import Tools, FilePathBuilder, CONSTANTS
class SimilarityScore:
    """Similarity metrics used to rank repository contexts against a query."""

    @staticmethod
    def cosine_similarity(embedding_vec1, embedding_vec2):
        """Return the cosine similarity (1 - cosine distance) of two
        equal-length numeric vectors."""
        return 1 - scipy.spatial.distance.cosine(embedding_vec1, embedding_vec2)

    @staticmethod
    def jaccard_similarity(list1, list2):
        """Return the Jaccard similarity |A∩B| / |A∪B| of two token lists.

        Duplicates are ignored (inputs are converted to sets). Returns 0.0
        when both inputs are empty instead of raising ZeroDivisionError.
        """
        set1 = set(list1)
        set2 = set(list2)
        union = len(set1 | set2)
        if union == 0:
            # both token lists empty: define similarity as 0.0 rather than divide by zero
            return 0.0
        intersection = len(set1 & set2)
        return intersection / union
class CodeSearchWorker:
    """Scores repository window embeddings against query embeddings and
    pickles the top-k most similar contexts for each query line."""

    def __init__(self, repo_embedding_lines, query_embedding_lines, output_path, sim_scorer, max_top_k, log_message):
        self.repo_embedding_lines = repo_embedding_lines  # list of repo window dicts ('data', 'metadata', 'context')
        self.query_embedding_lines = query_embedding_lines  # list of query dicts ('data', 'metadata')
        self.max_top_k = max_top_k  # number of candidate contexts retained per query
        self.sim_scorer = sim_scorer  # callable(query_vec, repo_vec) -> similarity score
        self.output_path = output_path  # destination pickle path for retrieval results
        self.log_message = log_message  # human-readable tag describing this job

    def _is_context_after_hole(self, repo_embedding_line, query_line):
        """Return True when the repo line must be excluded because it would
        leak code located at/after the hole in the hole's own file.

        NOTE(review): with `not any(...)` a repo line is excluded only when
        EVERY one of its merged windows is after the hole; a single
        before-the-hole window keeps the whole line. Confirm `all` was not
        intended here.
        """
        hole_fpath = "/".join(query_line['metadata']['fpath_tuple'])
        context_is_not_after_hole = []
        for metadata in repo_embedding_line['metadata']:
            if not "/".join(metadata['fpath_tuple']).endswith(hole_fpath):
                # window is in a different file: it can never be "after the hole"
                context_is_not_after_hole.append(True)
                continue
            # same file as the hole: safe only if it ends at/before the context start
            if metadata['end_line_no'] <= query_line['metadata']['context_start_lineno']:
                context_is_not_after_hole.append(True)
                continue
            context_is_not_after_hole.append(False)
        return not any(context_is_not_after_hole)

    def _is_context_containing_query(self, repo_embedding_line, query_line):
        """Return True if the repo context textually contains the query's
        signature or docstring (whitespace-normalized, case-insensitive).

        Expects query_line['context'] to contain a '''-delimited docstring
        (raises IndexError otherwise). Currently unused: the call site in
        _find_top_k_context is commented out.
        """
        parts = query_line["context"].split("'''")
        signature, docstr = parts[0].strip(), parts[1].strip()
        docstr = re.sub(r"\s+", " ", docstr).lower()
        signature = re.sub(r"\s+", " ", signature).lower()
        context = re.sub(r"\s+", " ", repo_embedding_line["context"]).lower()
        return signature in context or docstr in context

    def _find_top_k_context(self, query_line):
        """Score every eligible repo window against the query embedding.

        Returns up to self.max_top_k (context, score) pairs sorted by
        ascending score, so the best match is LAST — downstream prompt
        construction relies on this ordering.
        """
        top_k_context = []
        query_embedding = np.array(query_line['data'][0]['embedding'])
        for repo_embedding_line in self.repo_embedding_lines:
            # skip contexts that would leak code located after the hole
            if self._is_context_after_hole(repo_embedding_line, query_line):
                continue
            repo_line_embedding = np.array(repo_embedding_line['data'][0]['embedding'])
            similarity_score = self.sim_scorer(query_embedding, repo_line_embedding)
            top_k_context.append((repo_embedding_line, similarity_score))
        top_k_context = sorted(top_k_context, key=lambda x: x[1], reverse=False)[-self.max_top_k:]
        return top_k_context

    def run(self):
        """Attach 'top_k_context' to a deep copy of each query line and dump
        the enriched list to self.output_path as a pickle."""
        query_lines_with_retrieved_results = []
        for query_line in self.query_embedding_lines:
            new_line = copy.deepcopy(query_line)
            top_k_context = self._find_top_k_context(new_line)
            new_line['top_k_context'] = top_k_context
            query_lines_with_retrieved_results.append(new_line)
        Tools.dump_pickle(query_lines_with_retrieved_results, self.output_path)
class CodeSearchWrapper:
    """Builds one CodeSearchWorker job per (window_size, slice_size, repo)
    combination and executes them in a process pool."""

    def __init__(self, vectorizer, benchmark, repos, window_sizes, slice_sizes):
        """
        :param vectorizer: 'one-gram' (Jaccard over token bags) or 'ada002'
            (cosine over embedding vectors)
        :param benchmark: benchmark identifier forwarded to path builders
        :param repos: repository names to search over
        :param window_sizes: context window sizes to iterate
        :param slice_sizes: slice sizes to iterate
        :raises ValueError: if vectorizer is not a supported value
        """
        self.vectorizer = vectorizer
        if vectorizer == 'one-gram':
            self.sim_scorer = SimilarityScore.jaccard_similarity
            self.vector_path_builder = FilePathBuilder.one_gram_vector_path
        elif vectorizer == 'ada002':
            self.sim_scorer = SimilarityScore.cosine_similarity
            self.vector_path_builder = FilePathBuilder.ada002_vector_path
        else:
            # fail fast: the original silently left sim_scorer/vector_path_builder
            # unset, yielding an AttributeError deep inside _run_parallel
            raise ValueError(f"unknown vectorizer: {vectorizer!r} (expected 'one-gram' or 'ada002')")
        self.max_top_k = 20  # store 20 top-k contexts; prompt construction uses the top 10
        self.repos = repos
        self.window_sizes = window_sizes
        self.slice_sizes = slice_sizes
        self.benchmark = benchmark

    def _run_parallel(self, query_window_path_builder, prediction_path_template=None):
        """Load embeddings for every (window, slice, repo) combination, build
        the workers, then run them all in a ProcessPoolExecutor."""
        workers = []
        for window_size in self.window_sizes:
            for slice_size in self.slice_sizes:
                for repo in self.repos:
                    if prediction_path_template:
                        # query windows derived from a previous generation's predictions
                        query_window_path = query_window_path_builder(
                            prediction_path_template.format(window_size=window_size, slice_size=slice_size),
                            repo, window_size
                        )
                    else:
                        query_window_path = query_window_path_builder(repo, window_size)
                    query_line_path = self.vector_path_builder(query_window_path)
                    repo_window_path = FilePathBuilder.repo_windows_path(repo, window_size, slice_size)
                    repo_embedding_path = self.vector_path_builder(repo_window_path)
                    output_path = FilePathBuilder.retrieval_results_path(query_line_path, repo_embedding_path, self.max_top_k)
                    repo_embedding_lines = Tools.load_pickle(repo_embedding_path)
                    query_embedding_lines = Tools.load_pickle(query_line_path)
                    log_message = f'repo: {repo}, window: {window_size}, slice: {slice_size} {self.vectorizer}, max_top_k: {self.max_top_k}'
                    worker = CodeSearchWorker(repo_embedding_lines, query_embedding_lines, output_path, self.sim_scorer, self.max_top_k, log_message)
                    workers.append(worker)
        # run every worker in its own process; result() re-raises any worker exception
        with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
            futures = {executor.submit(worker.run) for worker in workers}
            for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
                future.result()

    def search_baseline_and_ground(self):
        """Run retrieval for the baseline (CONSTANTS.rg) query windows."""
        query_line_path_temp = functools.partial(FilePathBuilder.search_first_window_path, self.benchmark, CONSTANTS.rg)
        self._run_parallel(query_line_path_temp)

    def search_prediction(self, mode, prediction_path_template):
        """Run retrieval for query windows built from earlier predictions."""
        query_line_path_temp = functools.partial(FilePathBuilder.gen_first_window_path, self.benchmark, mode)
        self._run_parallel(query_line_path_temp, prediction_path_template)