generate.py
"""Generate model responses for an evaluation dataset and save them as JSON."""
import os
import sys
import json
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_file", type=str, default="data/charthal.json")
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--mode", type=str, default="transformers", choices=["transformers", "vllm"])
    # NOTE: --batch_size is parsed but not currently forwarded to the backends.
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--save_dir", type=str, default="results")
    args = parser.parse_args()

    if args.save_dir is None:
        args.save_dir = "."

    # Save responses under <save_dir>/<model_name>/response.json, where
    # model_name is the last path segment of the model ID.
    model_name = args.model_id.split("/")[-1]
    resp_dir = os.path.join(args.save_dir, model_name)
    os.makedirs(resp_dir, exist_ok=True)
    save_file = os.path.join(resp_dir, "response.json")

    # Load the evaluation dataset.
    with open(args.data_file, "r", encoding="utf-8") as f:
        eval_data = json.load(f)

    # Dispatch to the model-specific generation backend.
    if args.model_type == "qwen25vl":
        if args.mode == "vllm":
            from models.qwen25vl import generate_resp_qwen25vl_vllm
            eval_results = generate_resp_qwen25vl_vllm(args.model_id, eval_data)
        else:
            from models.qwen25vl import generate_resp_qwen25vl
            eval_results = generate_resp_qwen25vl(args.model_id, eval_data)
    elif args.model_type == "qwen3vl":
        if args.mode == "vllm":
            from models.qwen3vl import generate_resp_qwen3vl_vllm
            eval_results = generate_resp_qwen3vl_vllm(args.model_id, eval_data)
        else:
            from models.qwen3vl import generate_resp_qwen3vl
            eval_results = generate_resp_qwen3vl(args.model_id, eval_data)
    elif args.model_type == "internvl25":
        if args.mode == "vllm":
            print("WARNING: InternVL2.5 vLLM generation is not implemented yet; falling back to transformers mode.")
        from models.internvl25 import generate_resp_internvl25
        eval_results = generate_resp_internvl25(args.model_id, eval_data)
    elif args.model_type == "llama32v":
        if args.mode == "vllm":
            print("WARNING: Llama 3.2 Vision vLLM generation is not implemented yet; falling back to transformers mode.")
        from models.llama32v import generate_resp_llama32v
        eval_results = generate_resp_llama32v(args.model_id, eval_data)
    else:
        print(f"ERROR: Unknown model type: {args.model_type}")
        sys.exit(1)

    # Save the result file.
    with open(save_file, "w", encoding="utf-8") as f:
        json.dump(eval_results, f, ensure_ascii=False, indent=4)
    print(f"Results saved to {save_file}")