-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
280 lines (237 loc) · 8.95 KB
/
main.py
File metadata and controls
280 lines (237 loc) · 8.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
#!/usr/bin/env python3
"""
视频转文字应用
基于FunASR实现本地视频文件的语音识别转录
"""
import os
import sys
import argparse
import tempfile
from pathlib import Path
from video_processor import (
extract_audio_from_video,
validate_video_file,
validate_audio_file,
get_supported_video_formats
)
from asr_transcriber import ASRTranscriber
def print_banner():
"""打印应用横幅"""
banner = """
╔══════════════════════════════════════════════════════════════╗
║ 视频转文字工具 ║
║ 基于FunASR语音识别 ║
║ ║
║ 支持格式: MP4, AVI, MKV, MOV, WMV, FLV, WEBM, M4V ║
║ 输出格式: TXT, SRT, VTT, JSON ║
╚══════════════════════════════════════════════════════════════╝
"""
print(banner)
def setup_argparse():
"""设置命令行参数解析"""
parser = argparse.ArgumentParser(
description="视频转文字工具 - 基于FunASR语音识别",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
使用示例:
python main.py input.mp4 # 基础转录,输出文本
python main.py input.mp4 -o output.txt # 指定输出文件
python main.py input.mp4 -f srt -l zh # 输出SRT字幕,指定中文
python main.py input.mp4 -f json --timestamps # 输出JSON格式,包含时间戳
python main.py input.mp4 --model paraformer-zh # 使用指定模型
"""
)
parser.add_argument(
"input",
help="输入视频文件路径"
)
parser.add_argument(
"-o", "--output",
help="输出文件路径(可选,默认根据输入文件名生成)"
)
parser.add_argument(
"-f", "--format",
choices=["text", "srt", "vtt", "json"],
default="text",
help="输出格式 (默认: text)"
)
parser.add_argument(
"-l", "--language",
default="auto",
help="语言代码 (默认: auto,支持: zh, en, ja, ko, es, fr, de, it, pt, ru)"
)
parser.add_argument(
"--model",
default="iic/SenseVoiceSmall",
help="ASR模型名称 (默认: iic/SenseVoiceSmall)"
)
parser.add_argument(
"--vad-model",
default="fsmn-vad",
help="VAD模型名称 (默认: fsmn-vad)"
)
parser.add_argument(
"--device",
choices=["cpu", "cuda", "mps", "auto"],
default="auto",
help="计算设备 (默认: auto,自动检测最佳设备)"
)
parser.add_argument(
"--timestamps",
action="store_true",
help="输出包含时间戳信息"
)
parser.add_argument(
"--keep-audio",
action="store_true",
help="保留提取的音频文件"
)
parser.add_argument(
"--verbose",
action="store_true",
help="显示详细信息"
)
parser.add_argument(
"--max-length",
type=int,
default=5,
help="VAD片段合并最大长度(秒),用于控制字幕片段长度 (默认: 5秒)"
)
parser.add_argument(
"--batch-size",
type=int,
default=600,
help="批处理大小,GPU加速时可适当增大 (默认: 600)"
)
parser.add_argument(
"--vad-off",
action="store_true",
help="禁用VAD语音活动检测 (默认: 启用VAD)"
)
return parser
def generate_output_path(input_path, output_format):
"""生成输出文件路径"""
input_path = Path(input_path)
base_name = input_path.stem
format_extensions = {
"text": ".txt",
"srt": ".srt",
"vtt": ".vtt",
"json": ".json"
}
extension = format_extensions.get(output_format, ".txt")
return str(input_path.parent / f"{base_name}_transcription{extension}")
def main():
"""主函数"""
parser = setup_argparse()
args = parser.parse_args()
print_banner()
# 验证输入文件
if args.verbose:
print(f"验证输入文件: {args.input}")
video_validation = validate_video_file(args.input)
if not video_validation["valid"]:
print(f"❌ 输入文件验证失败: {video_validation['error']}")
return 1
if args.verbose:
print("✅ 视频文件验证通过")
print(f" 时长: {video_validation['duration']:.2f}秒")
print(f" 分辨率: {video_validation['size']}")
print(f" 帧率: {video_validation['fps']:.2f}fps")
# 提取音频
temp_audio_path = None
try:
print("🎵 正在从视频中提取音频...")
temp_audio_path = extract_audio_from_video(args.input)
# 验证音频文件
audio_validation = validate_audio_file(temp_audio_path)
if not audio_validation["valid"]:
print(f"❌ 音频提取失败: {audio_validation['error']}")
return 1
if args.verbose:
print("✅ 音频提取完成")
print(f" 时长: {audio_validation['duration']:.2f}秒")
print(f" 采样率: {audio_validation['sample_rate']}Hz")
print(f" 声道数: {audio_validation['channels']}")
# 初始化ASR转录器
print("🤖 正在初始化语音识别模型...")
transcriber = ASRTranscriber(
model_name=args.model,
vad_model=args.vad_model,
device=args.device,
enable_vad=not args.vad_off
)
# 执行转录
print("📝 正在进行语音识别转录...")
if args.timestamps:
result = transcriber.transcribe_with_timestamps(
temp_audio_path,
args.language,
max_length=args.max_length,
batch_size=args.batch_size
)
else:
result = transcriber.transcribe_audio(
temp_audio_path,
args.language,
max_length=args.max_length,
batch_size=args.batch_size
)
if not result["success"]:
print(f"❌ 转录失败: {result['error']}")
return 1
# 格式化输出
formatted_output = transcriber.format_transcription_output(result, args.format)
# 确定输出路径
if args.output:
output_path = args.output
else:
output_path = generate_output_path(args.input, args.format)
# 保存结果
try:
output_dir = os.path.dirname(output_path)
if output_dir: # 只有当输出路径包含目录时才创建
os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
f.write(formatted_output)
print(f"✅ 转录完成!")
print(f"📄 输出文件: {output_path}")
if args.verbose:
print(f" 语言: {result.get('language', '未知')}")
print(f" 文本长度: {len(result['text'])}字符")
if 'segments' in result:
print(f" 分段数: {len(result['segments'])}")
# 显示部分转录内容
preview_text = result['text'][:200]
if len(result['text']) > 200:
preview_text += "..."
print(f"\n📝 转录预览:\n{preview_text}")
except Exception as e:
print(f"❌ 保存输出文件失败: {str(e)}")
return 1
except Exception as e:
print(f"❌ 处理过程中发生错误: {str(e)}")
return 1
finally:
# 清理临时文件
if temp_audio_path and os.path.exists(temp_audio_path):
if args.keep_audio:
audio_output_path = generate_output_path(args.input, "wav").replace(".txt", ".wav")
os.rename(temp_audio_path, audio_output_path)
if args.verbose:
print(f"🎵 音频文件已保存: {audio_output_path}")
else:
os.unlink(temp_audio_path)
if args.verbose:
print("🗑️ 临时音频文件已清理")
return 0
if __name__ == "__main__":
try:
exit_code = main()
sys.exit(exit_code)
except KeyboardInterrupt:
print("\n⚠️ 用户中断操作")
sys.exit(1)
except Exception as e:
print(f"❌ 程序异常退出: {str(e)}")
sys.exit(1)