# retriever.py
from multiprocessing import Pool

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from datasets import CodeBlock

def batch_tokenize(tokenizer, texts, max_workers=64):
    """Tokenize a list of texts in parallel worker processes."""
    with Pool(processes=max_workers) as pool:
        return list(pool.map(tokenizer.tokenize, texts))
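
# Note: Pool.map pickles the tokenizer into every worker process, which can
# dominate runtime when texts are short. With a HuggingFace fast tokenizer,
# a single-process sketch with the same output shape would simply be:
#   tokenized = [tokenizer.tokenize(t) for t in texts]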

def tokenize(text, tokenizer, max_length, is_query, extracted_import='', tokenized_token=None):
    """
    Converts text to a list of token ids.
    :param text: The text to be converted
    :param tokenizer: The tokenizer to use
    :param max_length: The maximum input length
    :param is_query: A flag indicating whether the text is a query
    :param extracted_import: Optional import statements to prepend (currently disabled)
    :param tokenized_token: Optional pre-tokenized text, used to skip re-tokenization
    :return: A list of token ids
    """
    # Prepending import context is currently disabled; restore the line below
    # to re-enable it:
    # import_tokens = tokenizer.tokenize(extracted_import)[-127:] + [tokenizer.sep_token]
    import_tokens = []
    # Reuse pre-tokenized tokens when available to avoid tokenizing twice.
    if tokenized_token:
        tokens = tokenized_token
    else:
        tokens = tokenizer.tokenize(text)
    # Reserve 4 positions for the special tokens added below. Queries keep
    # their last tokens; candidates keep their first tokens.
    if is_query:
        tokens = tokens[-(max_length - len(import_tokens)) + 4:]
    else:
        tokens = tokens[:(max_length - len(import_tokens)) - 4]
    tokens = [tokenizer.cls_token, "<encoder-only>", tokenizer.sep_token] + import_tokens + tokens + [tokenizer.sep_token]
    tokens_id = tokenizer.convert_tokens_to_ids(tokens)
    # Right-pad every sequence to max_length.
    padding_length = max_length - len(tokens_id)
    tokens_id += [tokenizer.pad_token_id] * padding_length
    return tokens_id
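
# The assembled sequence follows this layout (a UniXcoder-style
# "<encoder-only>" input, assuming the checkpoint uses that convention):
#   [CLS] <encoder-only> [SEP] <import tokens, if enabled> <text tokens> [SEP] [PAD] ...
# Queries keep their trailing tokens (the code nearest the completion point),
# while candidates keep their leading tokens.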

class CustomDataset(Dataset):
    """
    Custom dataset class for handling code blocks and queries.
    :param max_length: The maximum input length
    :param tokenizer: The tokenizer to use
    :param examples: The samples in the dataset
    :param query: A flag indicating whether the samples are queries
    :param extracted_imports: Optional import statements, one per example
    """
    def __init__(self, max_length, tokenizer, examples, query=False, extracted_imports=None):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.examples = examples
        self.query = query
        self.extracted_imports = extracted_imports
        # Tokenize all examples up front so __getitem__ only truncates,
        # adds special tokens, and pads.
        self.tokenized_tokens = batch_tokenize(self.tokenizer, [str(example) for example in self.examples])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        text = str(self.examples[idx])
        extracted_import = str(self.extracted_imports[idx]) if self.extracted_imports else ''
        tokenized_token = self.tokenized_tokens[idx]
        tokens_id = tokenize(text, self.tokenizer, self.max_length, self.query, extracted_import, tokenized_token)
        return torch.tensor(tokens_id, dtype=torch.long)
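
# A minimal usage sketch (hypothetical values), mirroring how retrieve()
# builds its candidate-side dataset below; `blocks` is any sequence whose
# elements stringify to code text (e.g. CodeBlock instances):
#   dataset = CustomDataset(512, tokenizer, blocks, query=False)
#   loader = DataLoader(dataset, batch_size=32, shuffle=False)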

class Retriever(nn.Module):
    """
    Retriever model, used to compute sentence embeddings and retrieve similar code blocks.
    :param args: A namespace containing configuration parameters
    """
    def __init__(self, args):
        super(Retriever, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_path)
        self.args = args
        if not self.args.disable_retriever:
            self.model = AutoModel.from_pretrained(args.retriever_model_path)
            self.model = torch.nn.DataParallel(self.model).cuda()
            self.model.eval()

    def forward(self, source_ids):
        """
        Forward propagation function, used to generate the embedding representation of the input.
        :param source_ids: The sequence of input token ids
        :return: The normalized sentence embeddings
        """
        # Mean-pool the token embeddings over non-padding positions, then
        # L2-normalize so that dot products equal cosine similarities.
        mask = source_ids.ne(self.tokenizer.pad_token_id)
        token_embeddings = self.model(source_ids, attention_mask=mask)[0]
        sentence_embeddings = (token_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
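
    # In symbols: for token embeddings e_1..e_n and padding mask m (1 for
    # real tokens, 0 for padding), the sentence embedding is
    #   s = sum_i(m_i * e_i) / sum_i(m_i),  then  s <- s / ||s||_2,
    # so the matrix product in retrieve() below yields cosine similarities.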
    def retrieve(self, queries, candidate_codeblocks, topk, extracted_imports=None):
        """
        Retrieves the most relevant code blocks from each query's candidate list.
        :param queries: A list of queries
        :param candidate_codeblocks: One list of candidate code blocks per query
        :param topk: The number of top-k code blocks to return for each query
        :param extracted_imports: Optional import statements, one per query
        :return: A list of top-k code blocks for each query
        """
        query_dataset = CustomDataset(self.args.retriever_query_context_length, self.tokenizer,
                                      queries, query=True, extracted_imports=extracted_imports)
        query_dataloader = DataLoader(query_dataset, batch_size=self.args.retriever_batch_size,
                                      shuffle=False, num_workers=self.args.num_workers)
        query_dataloader = tqdm(query_dataloader, desc="Encoding Query Blocks") if self.args.enable_tqdm else query_dataloader
        # Flatten the per-query candidate lists, remembering how many
        # candidates belong to each query.
        candidate_numbers = [len(x) for x in candidate_codeblocks]
        candidate_codeblocks = [x for y in candidate_codeblocks for x in y]
        code_dataset = CustomDataset(self.args.retriever_candidate_context_length, self.tokenizer,
                                     candidate_codeblocks, query=False)
        code_dataloader = DataLoader(code_dataset, batch_size=self.args.retriever_batch_size,
                                     shuffle=False, num_workers=self.args.num_workers)
        code_dataloader = tqdm(code_dataloader, desc="Encoding Code Blocks") if self.args.enable_tqdm else code_dataloader
        query_embeddings, code_embeddings = [], []
        with torch.no_grad():
            for batch in query_dataloader:
                query_embeddings.append(self.forward(batch.cuda()))
            for batch in code_dataloader:
                code_embeddings.append(self.forward(batch.cuda()))
        query_embeddings = torch.cat(query_embeddings, dim=0)
        code_embeddings = torch.cat(code_embeddings, dim=0)
        # Embeddings are L2-normalized, so this matrix product computes the
        # cosine similarity between every query and every candidate.
        scores = torch.mm(query_embeddings, code_embeddings.t())
        scores = scores.cpu().numpy()
        topk_codeblocks = []  # Stores the top-k code blocks for each query
        start_idx = 0
        for i, num_candidates in enumerate(candidate_numbers):
            if num_candidates == 0:
                topk_codeblocks.append([])  # No candidates for this query
                continue
            # Rank only this query's own slice of the flattened candidates.
            query_scores = scores[i][start_idx:start_idx + num_candidates]
            topk_indices_query = query_scores.argsort()[-topk:][::-1]
            topk_codeblocks_query = [candidate_codeblocks[start_idx + idx] for idx in topk_indices_query]
            # Pad with placeholder blocks when fewer than topk candidates exist.
            if len(topk_codeblocks_query) < topk:
                topk_codeblocks_query += [CodeBlock("", "Don't need cross file context to completion", "", topk_codeblocks_query[0].language, '')] * (topk - len(topk_codeblocks_query))
            topk_codeblocks.append(topk_codeblocks_query)
            start_idx += num_candidates
        return topk_codeblocks
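
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a CUDA
# device (the model and batches are moved with .cuda()), that
# args.retriever_model_path points to a bi-encoder checkpoint such as
# microsoft/unixcoder-base (assumed here because of the "<encoder-only>"
# prefix), and that the Namespace carries exactly the fields the class reads.
# The CodeBlock positional arguments mirror the placeholder construction in
# retrieve(); their precise meanings depend on datasets.CodeBlock.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        retriever_model_path="microsoft/unixcoder-base",  # assumed checkpoint
        disable_retriever=False,
        retriever_query_context_length=256,
        retriever_candidate_context_length=512,
        retriever_batch_size=32,
        num_workers=0,
        enable_tqdm=True,
    )
    retriever = Retriever(args)

    queries = ["def read_json(path):"]
    # One candidate list per query; argument order follows the placeholder
    # CodeBlock(...) call in retrieve() above.
    candidates = [[
        CodeBlock("utils.py", "import json\ndef load(path): ...", "", "python", ""),
        CodeBlock("io_utils.py", "def read_file(path): ...", "", "python", ""),
    ]]
    top1 = retriever.retrieve(queries, candidates, topk=1)
    print(str(top1[0][0]))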