-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathzero_shot.py
More file actions
33 lines (27 loc) · 1.12 KB
/
zero_shot.py
File metadata and controls
33 lines (27 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import sys
sys.path.append(os.path.abspath('GHS-Net_scanslice_pos_v5'))
sys.path.append(os.path.abspath(''))
import logging
from eval.zeroshot_metadata_ct_rate import PROMPTS
from eval.zeroshot_ct_rate import zero_shot as run_ct_rate
def zero_shot_eval(model, data, epoch, args, tokenizer):
    """Run the CT-RATE zero-shot evaluation when it is scheduled for this epoch.

    Args:
        model: Model to evaluate. When ``args.distributed`` is set and
            ``args.horovod`` is not, ``model.module`` is evaluated instead.
        data: Mapping of dataset names to dataset holders; only the
            ``'zeroshot-ct-rate'`` entry (via its ``dataloader`` attribute)
            is used.
        epoch: Current epoch number.
        args: Run configuration. Reads ``zeroshot_frequency``, ``epochs``,
            ``distributed``, ``horovod`` and ``zeroshot_template``.
        tokenizer: Tokenizer forwarded unchanged to the CT-RATE evaluator.

    Returns:
        An empty dict when evaluation is disabled, not due at this epoch, or
        the ``'zeroshot-ct-rate'`` dataset is absent. Otherwise a dict with
        keys ``'18_classes'`` and ``'16_classes'`` holding the ``'* mean'``
        entries of the evaluator's ``'results_18'`` and ``'results_16'``.
    """
    frequency = args.zeroshot_frequency
    # The zero check must come first so the modulo below never divides by
    # zero; evaluation also always runs on the final epoch.
    if frequency == 0 or (epoch % frequency != 0 and epoch != args.epochs):
        return {}

    if args.distributed and not args.horovod:
        # Evaluate the wrapped model rather than the distributed wrapper.
        model = model.module

    if args.zeroshot_template != 'organ':
        # Non-organ templates override these two prompt pairs in the shared
        # PROMPTS table (modified in place).
        prompt_overrides = (
            ("Lung nodule", ("Not lung nodule", "Lung nodule")),
            ("Lung opacity", ("Not lung opacity", "Lung opacity")),
        )
        for finding, prompt_pair in prompt_overrides:
            PROMPTS[finding] = prompt_pair

    if 'zeroshot-ct-rate' not in data:
        return {}

    logging.info('Starting Zero-Shot CT-RATE.')
    loader = data['zeroshot-ct-rate'].dataloader
    result = run_ct_rate(model, tokenizer, loader, args)
    logging.info('Finished Zero-Shot CT-RATE.')

    # Return the results for both the 18-class and the 16-class settings.
    mean_18 = result['results_18']['* mean']
    mean_16 = result['results_16']['* mean']
    return {
        '18_classes': mean_18,
        '16_classes': mean_16,
    }