Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ speech = [
# all includes all functional dependencies excluding the ones from the "dev" extra
all = [
"accelerate>=1.7.0",
"av>=14.0.0",
"azure-ai-ml>=1.27.1",
"azure-cognitiveservices-speech>=1.44.0",
"azureml-mlflow>=1.60.0",
Expand Down
152 changes: 83 additions & 69 deletions pyrit/score/audio_transcript_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABC
from typing import Optional

import av

from pyrit.memory import CentralMemory
from pyrit.models import MessagePiece, Score
from pyrit.prompt_converter import AzureSpeechAudioToTextConverter
Expand All @@ -16,6 +18,76 @@
logger = logging.getLogger(__name__)


def _is_compliant_wav(input_path: str, *, sample_rate: int, channels: int) -> bool:
"""
Check if the audio file is already a compliant WAV with the target format.

Args:
input_path (str): Path to the audio file.
sample_rate (int): Expected sample rate in Hz.
channels (int): Expected number of channels.

Returns:
bool: True if the file is already compliant, False otherwise.
"""
try:
with av.open(input_path) as container:
if not container.streams.audio:
return False
stream = container.streams.audio[0]
codec_name = stream.codec_context.name
is_pcm_s16 = codec_name == "pcm_s16le"
is_correct_rate = stream.rate == sample_rate
is_correct_channels = stream.channels == channels
return is_pcm_s16 and is_correct_rate and is_correct_channels
except Exception:
return False


def _audio_to_wav(input_path: str, *, sample_rate: int, channels: int) -> str:
"""
Convert any audio or video file to a normalised PCM WAV using PyAV.

If the input is already a compliant WAV (correct sample rate, channels, and codec),
returns the original path without re-encoding.

Args:
input_path (str): Source audio or video file.
sample_rate (int): Target sample rate in Hz.
channels (int): Target number of channels (1 = mono).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to check this first, but likely not installed. @romanlutz lmk what you think, but imo we should use something like PyAV as an "all" dependency. It's relatively big, but we'll likely need a lot of audio/video editing and that seems like a good approach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I suspect consumers may be less excited but having a million classes with a check on whether or not av is installed is silly. I imagine we'll have A LOT of converters with this over time.

Returns:
str: Path to the WAV file (original if compliant, otherwise a temporary file).
"""
# Skip conversion if already compliant
if _is_compliant_wav(input_path, sample_rate=sample_rate, channels=channels):
logger.debug(f"Audio file already compliant, skipping conversion: {input_path}")
return input_path

layout = "mono" if channels == 1 else "stereo"
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
output_path = tmp.name

with av.open(input_path) as in_container:
with av.open(output_path, "w", format="wav") as out_container:
out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate, layout=layout)
resampler = av.AudioResampler(format="s16", layout=layout, rate=sample_rate)

for frame in in_container.decode(audio=0):
for out_frame in resampler.resample(frame):
for packet in out_stream.encode(out_frame):
out_container.mux(packet)

for out_frame in resampler.resample(None):
for packet in out_stream.encode(out_frame):
out_container.mux(packet)

for packet in out_stream.encode(None):
out_container.mux(packet)

return output_path


class AudioTranscriptHelper(ABC): # noqa: B024
"""
Abstract base class for audio scorers that process audio by transcribing and scoring the text.
Expand All @@ -29,7 +101,6 @@ class AudioTranscriptHelper(ABC): # noqa: B024
_DEFAULT_SAMPLE_RATE = 16000 # 16kHz - Azure Speech optimal rate
_DEFAULT_CHANNELS = 1 # Mono - Azure Speech prefers mono
_DEFAULT_SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample)
_DEFAULT_EXPORT_PARAMS = ["-acodec", "pcm_s16le"] # 16-bit PCM for best compatibility

def __init__(
self,
Expand Down Expand Up @@ -149,7 +220,7 @@ async def _transcribe_audio_async(self, audio_path: str) -> str:
logger.info(f"Audio transcription: WAV file size = {file_size} bytes")

try:
converter = AzureSpeechAudioToTextConverter()
converter = AzureSpeechAudioToTextConverter(use_entra_auth=True)
logger.info("Audio transcription: Starting Azure Speech transcription...")
result = await converter.convert_async(prompt=wav_path, input_type="audio_path")
logger.info(f"Audio transcription: Result = '{result.output_text}'")
Expand All @@ -171,25 +242,12 @@ def _ensure_wav_format(self, audio_path: str) -> str:

Returns:
str: Path to WAV file (original if already WAV, or converted temporary file).

Raises:
ModuleNotFoundError: If pydub is not installed.
"""
try:
from pydub import AudioSegment
except ModuleNotFoundError as e:
logger.error("Could not import pydub. Install it via 'pip install pydub'")
raise e

audio = AudioSegment.from_file(audio_path)
audio = (
audio.set_frame_rate(self._DEFAULT_SAMPLE_RATE)
.set_channels(self._DEFAULT_CHANNELS)
.set_sample_width(self._DEFAULT_SAMPLE_WIDTH)
return _audio_to_wav(
audio_path,
sample_rate=self._DEFAULT_SAMPLE_RATE,
channels=self._DEFAULT_CHANNELS,
)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
audio.export(temp_wav.name, format="wav")
return temp_wav.name

def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
"""
Expand All @@ -201,9 +259,6 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
Returns:
str: a path to the extracted audio file (WAV format)
or returns None if extraction fails.

Raises:
ModuleNotFoundError: If pydub/ffmpeg is not installed.
"""
return AudioTranscriptHelper.extract_audio_from_video(video_path)

Expand All @@ -218,57 +273,16 @@ def extract_audio_from_video(video_path: str) -> Optional[str]:
Returns:
str: a path to the extracted audio file (WAV format)
or returns None if extraction fails.

Raises:
ModuleNotFoundError: If pydub/ffmpeg is not installed.
"""
try:
from pydub import AudioSegment
except ModuleNotFoundError as e:
logger.error("Could not import pydub. Install it via 'pip install pydub'")
raise e

try:
# Extract audio from video using pydub (requires ffmpeg)
logger.info(f"Extracting audio from video: {video_path}")
audio = AudioSegment.from_file(video_path)
logger.info(
f"Audio extracted: duration={len(audio)}ms, channels={audio.channels}, "
f"sample_width={audio.sample_width}, frame_rate={audio.frame_rate}"
output_path = _audio_to_wav(
video_path,
sample_rate=AudioTranscriptHelper._DEFAULT_SAMPLE_RATE,
channels=AudioTranscriptHelper._DEFAULT_CHANNELS,
)

# Optimize for Azure Speech recognition:
# Azure Speech works best with 16kHz mono audio (same as Azure TTS output)
if audio.frame_rate != AudioTranscriptHelper._DEFAULT_SAMPLE_RATE:
logger.info(
f"Resampling audio from {audio.frame_rate}Hz to {AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz"
)
audio = audio.set_frame_rate(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE)

# Ensure 16-bit audio
if audio.sample_width != AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH:
logger.info(
f"Converting sample width from {audio.sample_width * 8}-bit"
f" to {AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH * 8}-bit"
)
audio = audio.set_sample_width(AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH)

# Convert to mono (Azure Speech prefers mono)
if audio.channels > AudioTranscriptHelper._DEFAULT_CHANNELS:
logger.info(f"Converting from {audio.channels} channels to mono")
audio = audio.set_channels(AudioTranscriptHelper._DEFAULT_CHANNELS)

# Create temporary WAV file with PCM encoding for best compatibility
with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio:
audio.export(
temp_audio.name,
format="wav",
parameters=AudioTranscriptHelper._DEFAULT_EXPORT_PARAMS,
)
logger.info(
f"Audio exported to: {temp_audio.name} (duration={len(audio)}ms, rate={audio.frame_rate}Hz, mono)"
)
return temp_audio.name
logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)")
return output_path
except Exception as e:
logger.warning(f"Failed to extract audio from video {video_path}: {e}")
return None
Loading