-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathmodel.py
More file actions
executable file
·153 lines (112 loc) · 7.24 KB
/
model.py
File metadata and controls
executable file
·153 lines (112 loc) · 7.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# add parent directory to sys.path
import sys
sys.path.append('.')
import logging
import torch
# = = = = = = = = = = = Logging Setup = = = = = = = = = = = = =
# Module-level logger for this file.  NOTE: basicConfig runs at import time,
# so merely importing this module configures root logging (timestamped,
# INFO-level) for the whole process — intentional for a script-style repo,
# but worth knowing if this module is imported by a larger application.
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = =
class Model(object):
    """Thin dispatcher around the audio-LLM backends under ``model_src``.

    Every supported backend lives in its own module ``model_src.<stem>`` and,
    by convention (visible in the original per-branch imports), exposes:

    * ``<stem>_model_loader(model)``            -- attaches weights/processors
      to *model* (this object) as a side effect.
    * ``<stem>_model_generation(model, input)`` -- runs inference and returns
      the generation result.

    Modules are imported lazily, at lookup time, so only the dependencies of
    the selected backend need to be installed (same behavior as the original
    function-scope imports).
    """

    # model_name -> module stem under ``model_src``.  The stem is also the
    # prefix of the loader/generation function names inside that module.
    # Keeping the mapping in ONE place prevents load_model() and generate()
    # from drifting apart (previously two hand-maintained if/elif chains).
    _MODEL_STEMS = {
        "cascade_whisper_large_v3_llama_3_8b_instruct": "whisper_large_v3_with_llama_3_8b_instruct",
        "cascade_whisper_large_v2_gemma2_9b_cpt_sea_lionv3_instruct": "whisper_large_v2_gemma2_9b_cpt_sea_lionv3_instruct",
        "Qwen2-Audio-7B-Instruct": "qwen2_audio_7b_instruct",
        "SALMONN_7B": "salmonn_7b",
        "WavLLM_fairseq": "wavllm_fairseq",
        "Qwen-Audio-Chat": "qwen_audio_chat",
        "MERaLiON-AudioLLM-Whisper-SEA-LION": "meralion_audiollm_whisper_sea_lion",
        "gemini-1.5-flash": "gemini_1_5_flash",
        "gemini-2-flash": "gemini_2_flash",
        "whisper_large_v3": "whisper_large_v3",
        "whisper_large_v2": "whisper_large_v2",
        "gpt-4o-audio": "gpt_4o_audio",
        "phi_4_multimodal_instruct": "phi_4_multimodal_instruct",
        "seallms_audio_7b": "seallms_audio_7b",
    }

    def __init__(self, model_name_or_path):
        """Select a backend by name and immediately load it onto this object.

        :param model_name_or_path: key into ``_MODEL_STEMS`` (despite the
            name, only registry keys are accepted — an unknown value raises
            ``NotImplementedError`` via ``load_model``).
        """
        self.dataset_name = None
        self.model_name = model_name_or_path
        # Prefer GPU when available; backend loaders read ``self.device``.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()
        logger.info("Loaded model: {}".format(self.model_name))
        logger.info("= = " * 20)

    def _lookup(self, suffix):
        """Resolve ``<stem><suffix>`` from ``model_src.<stem>`` lazily.

        :param suffix: ``"_model_loader"`` or ``"_model_generation"``.
        :raises NotImplementedError: if ``self.model_name`` is not registered
            (same message and type as the original if/elif chains).
        """
        # Local import keeps module import side effects minimal and mirrors
        # the file's existing lazy-import style.
        import importlib
        try:
            stem = self._MODEL_STEMS[self.model_name]
        except KeyError:
            raise NotImplementedError("Model {} not implemented yet".format(self.model_name))
        module = importlib.import_module("model_src." + stem)
        return getattr(module, stem + suffix)

    def load_model(self):
        """Load the selected backend; the loader mutates ``self`` in place."""
        self._lookup("_model_loader")(self)

    def generate(self, input):
        """Run inference with the selected backend.

        :param input: backend-specific input payload (name kept for backward
            compatibility with keyword callers; it shadows the builtin).
        :returns: whatever the backend's ``*_model_generation`` returns.
        :raises NotImplementedError: for an unregistered model name.
        """
        # Inference only — no gradients needed.
        with torch.no_grad():
            return self._lookup("_model_generation")(self, input)