-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
119 lines (86 loc) · 5.44 KB
/
main.py
File metadata and controls
119 lines (86 loc) · 5.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import json
import argparse
from tqdm import tqdm
from retrieve import load_data, retrieve_BM25, retrieve_BM25_segment, retrieve_BM25_Embedding, retrieve_BM25_Reranker, retrieve_BM25_Embedding_Reranker
from retrieve import retrieve_onlyEmbedding, retrieve_onlyReranker, retrieve_Embedding_Reranker
# Configuration read by new_main(): input/output file locations and the
# retrieval function whose results are written and evaluated.
question_path = "./dataset/preliminary/questions_example.json"  # question file to answer
source_path = "./reference/"  # directory holding the corpus_dict_*.json files loaded by new_main()
truth_path = "./dataset/preliminary/ground_truths_example.json"  # ground truths used for scoring
output_path = "./dataset/preliminary/pred_retrieve_embedding_reranker.json"  # where predictions are written
# Called as retrieve_algorithm(query, source, corpus_dict) for every question.
# NOTE(review): presumably the other imported retrieve_* functions share this
# signature and can be swapped in here — confirm against retrieve.py.
retrieve_algorithm = retrieve_Embedding_Reranker
def original_main() -> None:
    """Run BM25 retrieval for every question, using paths given on the command line.

    This is the baseline entry point. It:

    * parses ``--question_path``, ``--source_path`` and ``--output_path``;
    * loads the insurance and finance corpora with ``load_data``;
    * loads the FAQ corpus from ``<source_path>/faq/pid_map_content.json``;
    * retrieves one result per question with ``retrieve_BM25``;
    * writes ``{"answers": [{"qid": ..., "retrieve": ...}, ...]}`` as JSON to
      ``--output_path``.

    Raises:
        ValueError: if a question's category is not ``finance``,
            ``insurance`` or ``faq``.
    """
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='Process some paths and files.')
    # Path of the question file.
    parser.add_argument('--question_path', type=str, required=True, help='讀取發布題目路徑')
    # Root directory of the reference corpora.
    parser.add_argument('--source_path', type=str, required=True, help='讀取參考資料路徑')
    # Path of the answer file in the competition's submission format.
    parser.add_argument('--output_path', type=str, required=True, help='輸出符合參賽格式的答案路徑')
    args = parser.parse_args()

    with open(args.question_path, 'rb') as f:
        qs_ref = json.load(f)  # question file

    # Load the per-category reference corpora.
    corpus_dict_insurance = load_data(os.path.join(args.source_path, 'insurance'))
    corpus_dict_finance = load_data(os.path.join(args.source_path, 'finance'))

    with open(os.path.join(args.source_path, 'faq/pid_map_content.json'), 'rb') as f_s:
        key_to_source_dict = json.load(f_s)  # FAQ reference documents
    # JSON object keys are always strings. Convert them to int so they can
    # match the (presumably integer) ids in each question's "source" list.
    key_to_source_dict = {int(key): value for key, value in key_to_source_dict.items()}

    answer_dict = {"answers": []}
    for q_dict in qs_ref['questions']:
        category = q_dict['category']
        if category == 'finance':
            corpus_dict = corpus_dict_finance
        elif category == 'insurance':
            corpus_dict = corpus_dict_insurance
        elif category == 'faq':
            # Restrict the FAQ corpus to this question's candidates. A set gives
            # O(1) membership tests; iterating the full dict keeps corpus order.
            candidate_ids = set(q_dict['source'])
            corpus_dict = {
                key: str(value)
                for key, value in key_to_source_dict.items()
                if key in candidate_ids
            }
        else:
            raise ValueError(
                f"Unknown category {category!r} for qid {q_dict.get('qid')!r}"
            )
        retrieved = retrieve_BM25(q_dict['query'], q_dict['source'], corpus_dict)
        answer_dict['answers'].append({"qid": q_dict['qid'], "retrieve": retrieved})

    # Save the answers as JSON, keeping non-ASCII characters readable.
    with open(args.output_path, 'w', encoding='utf8') as f:
        json.dump(answer_dict, f, ensure_ascii=False, indent=4)

def _load_int_keyed_json(path: str) -> dict:
    """Load the JSON object at *path* and return it with its keys converted to ``int``.

    JSON object keys are always strings. The corpus ids they are compared
    against (each question's ``source`` list) are treated as ints elsewhere.
    """
    with open(path, 'rb') as f:
        data = json.load(f)
    return {int(key): value for key, value in data.items()}


def _evaluate(answers: list[dict], ground_truths: list[dict]) -> tuple[int, int]:
    """Count predictions whose ``retrieve`` value exactly equals the ground truth.

    Predictions are matched to ground truths by ``qid``. This replaces
    positional ``zip`` + ``assert``: ``assert`` is stripped under ``python -O``,
    and ``zip`` silently drops unmatched entries.

    Returns:
        ``(correct_count, total_count)``, where ``total_count`` is the number
        of predictions.

    Raises:
        ValueError: if a predicted qid has no ground-truth entry.
    """
    truth_by_qid = {item["qid"]: item["retrieve"] for item in ground_truths}
    correct_count = 0
    for item in answers:
        qid = item["qid"]
        if qid not in truth_by_qid:
            raise ValueError(f"No ground truth found for qid {qid!r}")
        if item["retrieve"] == truth_by_qid[qid]:
            correct_count += 1
    return correct_count, len(answers)


def new_main() -> None:
    """Run the configured retrieval algorithm on every question and score the results.

    Uses the module-level settings ``question_path``, ``source_path``,
    ``truth_path``, ``output_path`` and ``retrieve_algorithm``. It:

    * loads the insurance, finance and FAQ corpora from the JSON files under
      ``source_path``;
    * retrieves one result per question;
    * writes ``{"answers": [...]}`` to ``output_path``;
    * prints exact-match precision against the ground truths in ``truth_path``.

    Raises:
        ValueError: on an unknown question category, or on a predicted qid
            missing from the ground truth.
    """
    with open(question_path, 'rb') as f:
        qs_ref = json.load(f)  # question file

    # Pre-extracted reference corpora, keyed by integer document id.
    corpus_dict_insurance = _load_int_keyed_json(
        os.path.join(source_path, "corpus_dict_insurance_fitz_ocr.json"))
    corpus_dict_finance = _load_int_keyed_json(
        os.path.join(source_path, "corpus_dict_finance_fitz_ocr.json"))
    corpus_dict_faq_all = _load_int_keyed_json(
        os.path.join(source_path, "corpus_dict_faq.json"))

    def handle_category(category: str, query: str, source: list[int]) -> int:
        """Retrieve the result for *query* from the corpus matching *category*."""
        if category == "insurance":
            corpus_dict = corpus_dict_insurance
        elif category == "finance":
            corpus_dict = corpus_dict_finance
        elif category == "faq":
            # Keep only this question's candidates. Membership is tested
            # against a set; iterating the full dict preserves corpus order.
            candidate_ids = set(source)
            corpus_dict = {
                key: str(value)
                for key, value in corpus_dict_faq_all.items()
                if key in candidate_ids
            }
        else:
            raise ValueError(f"Invalid category: {category!r}")
        return retrieve_algorithm(query, source, corpus_dict)

    answer_dict = {"answers": []}
    for q_dict in tqdm(qs_ref["questions"]):
        retrieved = handle_category(q_dict["category"], q_dict["query"], q_dict["source"])
        answer_dict["answers"].append({"qid": q_dict["qid"], "retrieve": retrieved})

    # Save the answers as JSON, keeping non-ASCII characters readable.
    with open(output_path, 'w', encoding="utf8") as f:
        json.dump(answer_dict, f, ensure_ascii=False, indent=4)

    # Evaluation. Predictions are read back from disk so they are compared
    # after the same JSON round-trip as the ground truth
    # (e.g. tuples become lists).
    with open(output_path, 'rb') as f:
        pred = json.load(f)
    with open(truth_path, 'rb') as f:
        truth = json.load(f)
    correct_count, total_count = _evaluate(pred["answers"], truth["ground_truths"])
    # An empty question file would otherwise raise ZeroDivisionError.
    precision = correct_count / total_count if total_count else 0.0
    print(f"Correct count / Total count: {correct_count} / {total_count}")
    print(f"Precision: {precision:.7f}")

if __name__ == "__main__":
    # Script entry point: run the configurable retrieval + evaluation pipeline.
    new_main()