diff --git a/pyproject.toml b/pyproject.toml index ed9ab048e..5e3d4ddce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,7 @@ speech = [ # all includes all functional dependencies excluding the ones from the "dev" extra all = [ "accelerate>=1.7.0", + "av>=14.0.0", "azure-ai-ml>=1.27.1", "azure-cognitiveservices-speech>=1.44.0", "azureml-mlflow>=1.60.0", diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index b0d0ad2a9..5d28cfb91 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -8,6 +8,8 @@ from abc import ABC from typing import Optional +import av + from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score from pyrit.prompt_converter import AzureSpeechAudioToTextConverter @@ -16,6 +18,76 @@ logger = logging.getLogger(__name__) +def _is_compliant_wav(input_path: str, *, sample_rate: int, channels: int) -> bool: + """ + Check if the audio file is already a compliant WAV with the target format. + + Args: + input_path (str): Path to the audio file. + sample_rate (int): Expected sample rate in Hz. + channels (int): Expected number of channels. + + Returns: + bool: True if the file is already compliant, False otherwise. + """ + try: + with av.open(input_path) as container: + if not container.streams.audio: + return False + stream = container.streams.audio[0] + codec_name = stream.codec_context.name + is_pcm_s16 = codec_name == "pcm_s16le" + is_correct_rate = stream.rate == sample_rate + is_correct_channels = stream.channels == channels + return container.format.name == "wav" and is_pcm_s16 and is_correct_rate and is_correct_channels + except Exception: + return False + + +def _audio_to_wav(input_path: str, *, sample_rate: int, channels: int) -> str: + """ + Convert any audio or video file to a normalised PCM WAV using PyAV. + + If the input is already a compliant WAV (correct sample rate, channels, and codec), + returns the original path without re-encoding. 
+ + Args: + input_path (str): Source audio or video file. + sample_rate (int): Target sample rate in Hz. + channels (int): Target number of channels (1 = mono). + + Returns: + str: Path to the WAV file (original if compliant, otherwise a temporary file). + """ + # Skip conversion if already compliant + if _is_compliant_wav(input_path, sample_rate=sample_rate, channels=channels): + logger.debug(f"Audio file already compliant, skipping conversion: {input_path}") + return input_path + + layout = "mono" if channels == 1 else "stereo" + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + output_path = tmp.name + + with av.open(input_path) as in_container: + with av.open(output_path, "w", format="wav") as out_container: + out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate, layout=layout) + resampler = av.AudioResampler(format="s16", layout=layout, rate=sample_rate) + + for frame in in_container.decode(audio=0): + for out_frame in resampler.resample(frame): + for packet in out_stream.encode(out_frame): + out_container.mux(packet) + + for out_frame in resampler.resample(None): + for packet in out_stream.encode(out_frame): + out_container.mux(packet) + + for packet in out_stream.encode(None): + out_container.mux(packet) + + return output_path + + class AudioTranscriptHelper(ABC): # noqa: B024 """ Abstract base class for audio scorers that process audio by transcribing and scoring the text. 
@@ -29,7 +101,6 @@ class AudioTranscriptHelper(ABC): # noqa: B024 _DEFAULT_SAMPLE_RATE = 16000 # 16kHz - Azure Speech optimal rate _DEFAULT_CHANNELS = 1 # Mono - Azure Speech prefers mono _DEFAULT_SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample) - _DEFAULT_EXPORT_PARAMS = ["-acodec", "pcm_s16le"] # 16-bit PCM for best compatibility def __init__( self, @@ -149,7 +220,7 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: logger.info(f"Audio transcription: WAV file size = {file_size} bytes") try: - converter = AzureSpeechAudioToTextConverter() + converter = AzureSpeechAudioToTextConverter(use_entra_auth=True) logger.info("Audio transcription: Starting Azure Speech transcription...") result = await converter.convert_async(prompt=wav_path, input_type="audio_path") logger.info(f"Audio transcription: Result = '{result.output_text}'") @@ -171,25 +242,12 @@ def _ensure_wav_format(self, audio_path: str) -> str: Returns: str: Path to WAV file (original if already WAV, or converted temporary file). - - Raises: - ModuleNotFoundError: If pydub is not installed. """ - try: - from pydub import AudioSegment - except ModuleNotFoundError as e: - logger.error("Could not import pydub. Install it via 'pip install pydub'") - raise e - - audio = AudioSegment.from_file(audio_path) - audio = ( - audio.set_frame_rate(self._DEFAULT_SAMPLE_RATE) - .set_channels(self._DEFAULT_CHANNELS) - .set_sample_width(self._DEFAULT_SAMPLE_WIDTH) + return _audio_to_wav( + audio_path, + sample_rate=self._DEFAULT_SAMPLE_RATE, + channels=self._DEFAULT_CHANNELS, ) - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav: - audio.export(temp_wav.name, format="wav") - return temp_wav.name def _extract_audio_from_video(self, video_path: str) -> Optional[str]: """ @@ -201,9 +259,6 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]: Returns: str: a path to the extracted audio file (WAV format) or returns None if extraction fails. 
- - Raises: - ModuleNotFoundError: If pydub/ffmpeg is not installed. """ return AudioTranscriptHelper.extract_audio_from_video(video_path) @@ -218,57 +273,16 @@ def extract_audio_from_video(video_path: str) -> Optional[str]: Returns: str: a path to the extracted audio file (WAV format) or returns None if extraction fails. - - Raises: - ModuleNotFoundError: If pydub/ffmpeg is not installed. """ try: - from pydub import AudioSegment - except ModuleNotFoundError as e: - logger.error("Could not import pydub. Install it via 'pip install pydub'") - raise e - - try: - # Extract audio from video using pydub (requires ffmpeg) logger.info(f"Extracting audio from video: {video_path}") - audio = AudioSegment.from_file(video_path) - logger.info( - f"Audio extracted: duration={len(audio)}ms, channels={audio.channels}, " - f"sample_width={audio.sample_width}, frame_rate={audio.frame_rate}" + output_path = _audio_to_wav( + video_path, + sample_rate=AudioTranscriptHelper._DEFAULT_SAMPLE_RATE, + channels=AudioTranscriptHelper._DEFAULT_CHANNELS, ) - - # Optimize for Azure Speech recognition: - # Azure Speech works best with 16kHz mono audio (same as Azure TTS output) - if audio.frame_rate != AudioTranscriptHelper._DEFAULT_SAMPLE_RATE: - logger.info( - f"Resampling audio from {audio.frame_rate}Hz to {AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz" - ) - audio = audio.set_frame_rate(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE) - - # Ensure 16-bit audio - if audio.sample_width != AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH: - logger.info( - f"Converting sample width from {audio.sample_width * 8}-bit" - f" to {AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH * 8}-bit" - ) - audio = audio.set_sample_width(AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH) - - # Convert to mono (Azure Speech prefers mono) - if audio.channels > AudioTranscriptHelper._DEFAULT_CHANNELS: - logger.info(f"Converting from {audio.channels} channels to mono") - audio = audio.set_channels(AudioTranscriptHelper._DEFAULT_CHANNELS) 
- - # Create temporary WAV file with PCM encoding for best compatibility - with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio: - audio.export( - temp_audio.name, - format="wav", - parameters=AudioTranscriptHelper._DEFAULT_EXPORT_PARAMS, - ) - logger.info( - f"Audio exported to: {temp_audio.name} (duration={len(audio)}ms, rate={audio.frame_rate}Hz, mono)" - ) - return temp_audio.name + logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)") + return output_path except Exception as e: logger.warning(f"Failed to extract audio from video {video_path}: {e}") return None