-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathgemini_2_flash.py
More file actions
executable file
·111 lines (78 loc) · 3.17 KB
/
gemini_2_flash.py
File metadata and controls
executable file
·111 lines (78 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python
# -*- coding:utf-8 -*-
###
# Created Date: Friday, April 19th 2024, 11:17:41 am
# Author: Bin Wang
# -----
# Copyright (c) Bin Wang @ bwang28c@gmail.com
#
# -----
# HISTORY:
# Date&Time By Comments
# ---------- --- ----------------------------------------------------------
###
import os
import re
# add parent directory to sys.path
import sys
sys.path.append('.')
sys.path.append('../')
import logging
import numpy as np
import torch
from tqdm import tqdm
import pathlib
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import google.generativeai as genai
import tempfile
# ---------------------------------------------------------------------------
# Logging: module-level logger plus a one-time root configuration so every
# record carries a timestamp, severity, and the emitting module's name.
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
def gemini_2_flash_model_loader(self):
    """Attach a Gemini 2.0 Flash generative model to *self* as ``self.model``.

    NOTE(review): assumes ``genai`` has already been configured with an API
    key elsewhere (e.g. ``genai.configure(api_key=...)``) — confirm against
    the caller.
    """
    model_name = 'models/gemini-2.0-flash-exp'
    self.model = genai.GenerativeModel(model_name)
    logger.info("Model loaded")
def do_sample_inference(self, audio_array, instruction, sampling_rate=16000):
    """Send one audio clip plus a text instruction to the Gemini model.

    Writes the audio to a temporary WAV file, uploads its bytes inline with
    the instruction, and returns the model's text response.

    Args:
        audio_array: 1-D audio samples (anything ``soundfile.write`` accepts).
        instruction: Text prompt sent alongside the audio.
        sampling_rate: Sample rate (Hz) to encode the WAV with. Default 16000.

    Returns:
        str: ``response.text`` from the Gemini API.
    """
    # delete=False so the path survives the handle on all platforms; we
    # close the handle ourselves before soundfile reopens the path (required
    # on Windows), and unlink in ``finally`` to avoid leaking one temp .wav
    # per call — the original never removed these files.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", prefix="audio_", delete=False)
    try:
        tmp.close()
        sf.write(tmp.name, audio_array, sampling_rate)
        audio_bytes = pathlib.Path(tmp.name).read_bytes()
    finally:
        os.unlink(tmp.name)
    response = self.model.generate_content([
        instruction,
        {
            "mime_type": "audio/wav",
            "data": audio_bytes
        }
    ])
    return response.text
def gemini_2_flash_model_generation(self, input):
    """Run Gemini 2.0 Flash inference for one dataset sample.

    Expects ``input`` to be a dict with ``input["audio"]["array"]``,
    ``input["audio"]["sampling_rate"]``, ``input["instruction"]`` and
    ``input["task_type"]``.

    Long audio handling:
      * ASR task and > 30 s  -> split into 30 s chunks, infer each, join with spaces.
      * other task and > 30 s -> truncate to the first 30 s.
      * < 1 s                -> pad with one second of silence (some APIs
        reject very short clips).

    Returns:
        str: The model's text output.
    """
    audio_array = input["audio"]["array"]
    sampling_rate = input["audio"]["sampling_rate"]
    audio_duration = len(audio_array) / sampling_rate
    instruction = input["instruction"]
    os.makedirs('tmp', exist_ok=True)
    # BUG FIX: the original never forwarded ``sampling_rate`` to
    # do_sample_inference, so any non-16 kHz sample was encoded to WAV at the
    # wrong rate (pitch/speed-shifted audio sent to the model). Every call
    # below now passes the sample's true rate.
    chunk_len = 30 * sampling_rate
    if audio_duration > 30 and input['task_type'] == 'ASR':
        logger.info('Audio duration is more than 30 seconds. Chunking and inferring separately.')
        audio_chunks = [audio_array[i:i + chunk_len] for i in range(0, len(audio_array), chunk_len)]
        model_predictions = [
            do_sample_inference(self, chunk_array, instruction, sampling_rate)
            for chunk_array in tqdm(audio_chunks)
        ]
        output = ' '.join(model_predictions)
    elif audio_duration > 30:
        logger.info('Audio duration is more than 30 seconds. Taking first 30 seconds.')
        output = do_sample_inference(self, audio_array[:chunk_len], instruction, sampling_rate)
    else:
        if audio_duration < 1:
            logger.info('Audio duration is less than 1 second. Padding the audio to 1 second.')
            # Appends a full second of zeros; result is original length + 1 s.
            audio_array = np.pad(audio_array, (0, sampling_rate), 'constant')
        output = do_sample_inference(self, audio_array, instruction, sampling_rate)
    return output