-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: run_ner.py
More file actions
executable file
·87 lines (68 loc) · 2.77 KB
/
run_ner.py
File metadata and controls
executable file
·87 lines (68 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
import os
import pandas as pd
import spacy
from transformers import pipeline
from utils.utils import load_config
def generate_ner_filename(model_name, p_cfg, raw_data):
    """Build the output CSV filename encoding the processed canto/octave range.

    Args:
        model_name: Model identifier; anything before the last '/' is stripped
            (e.g. "osiria/bert-italian-cased-ner" -> "bert-italian-cased-ner").
        p_cfg: Pipeline config dict with 'min_canto', 'min_stanza' and optional
            'max_canto' / 'max_stanza'.
        raw_data: Ordered list of canto dicts, each with 'canto' and 'octaves'
            (octaves carry 'octave_number').

    Returns:
        A string like "<model>_C<c0>-O<s0>_to_C<c1>-O<s1>_NER.csv".
    """
    c_start, s_start = p_cfg['min_canto'], p_cfg['min_stanza']
    # Fall back to the last canto in the data when no explicit max is configured.
    c_end = p_cfg.get('max_canto') or raw_data[-1]['canto']
    s_end = p_cfg.get('max_stanza')
    if s_end is None:
        for canto_data in raw_data:
            if int(canto_data['canto']) == int(c_end):
                s_end = canto_data['octaves'][-1]['octave_number']
                break
        else:
            # Bug fix: previously, if c_end was not present in raw_data,
            # s_end stayed None and the filename contained the string "None".
            # Fall back to the last octave of the last available canto.
            s_end = raw_data[-1]['octaves'][-1]['octave_number']
    clean_model_name = model_name.split('/')[-1]
    return f"{clean_model_name}_C{c_start}-O{s_start}_to_C{c_end}-O{s_end}_NER.csv"
def main():
    """Run dual NER (spaCy + BERT) over the configured canto range and save CSVs.

    Reads the corpus JSON at p_cfg['data_path'], extracts PER entities per
    octave with both an Italian spaCy model and a transformer pipeline, and
    writes one CSV per model into output/NER/.
    """
    import torch  # local import: only needed for the device probe below

    MODE = "NER"
    cfg = load_config()
    p_cfg = cfg['pipeline']
    with open(p_cfg['data_path'], 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    # SpaCy
    nlp_spacy = spacy.load("it_core_news_lg")
    # BERT Transformer.
    # Bug fix: device was hard-coded to 0 (first GPU), which crashes on
    # CPU-only machines; fall back to CPU (-1) when CUDA is unavailable.
    nlp_trans = pipeline(
        "ner",
        model="osiria/bert-italian-cased-ner",
        aggregation_strategy="simple",
        device=0 if torch.cuda.is_available() else -1
    )
    spacy_rows = []
    transformer_rows = []
    # Bug fix: the original test `p_cfg.get('max_canto') and ...` treated a
    # falsy max_canto (e.g. 0) as "no limit"; use an explicit None check.
    max_canto = p_cfg.get('max_canto')
    # NOTE(review): min_stanza/max_stanza appear in the output filename but are
    # never applied here — every octave of each in-range canto is processed.
    # Preserved as-is; confirm whether stanza filtering was intended.
    for canto_data in raw_data:
        canto_num = int(canto_data['canto'])
        if canto_num < p_cfg['min_canto']:
            continue
        # Assumes raw_data is ordered by canto number, so we can stop early.
        if max_canto is not None and canto_num > max_canto:
            break
        octaves = canto_data['octaves']
        texts = [o['text'] for o in octaves]
        nums = [o['octave_number'] for o in octaves]
        # SpaCy: batch the octaves through the pipeline, keep PER entities.
        for i, doc in enumerate(nlp_spacy.pipe(texts)):
            names = {ent.text for ent in doc.ents if ent.label_ == "PER"}
            for name in names:
                spacy_rows.append({"canto": canto_num, "stanza": nums[i], "character_name": name})
        # BERT Transformer: same extraction from the aggregated pipeline output.
        trans_results = nlp_trans(texts)
        for i, entities in enumerate(trans_results):
            names = {ent['word'] for ent in entities if ent['entity_group'] == 'PER'}
            for name in names:
                transformer_rows.append({"canto": canto_num, "stanza": nums[i], "character_name": name})
    out_dir = os.path.join("output", MODE)
    os.makedirs(out_dir, exist_ok=True)
    # SpaCy results
    df_spacy = pd.DataFrame(spacy_rows)
    spacy_filename = generate_ner_filename("it_core_news_lg", p_cfg, raw_data)
    df_spacy.to_csv(os.path.join(out_dir, spacy_filename), index=False)
    # BERT Transformer results
    df_trans = pd.DataFrame(transformer_rows)
    trans_filename = generate_ner_filename("bert-italian-cased-ner", p_cfg, raw_data)
    df_trans.to_csv(os.path.join(out_dir, trans_filename), index=False)
    print(f"Files saved in {out_dir}:")
    print(f"- {spacy_filename}")
    print(f"- {trans_filename}")
# Script entry point: run the full NER pipeline when executed directly.
if __name__ == "__main__":
    main()