diff --git a/google/genai/errors.py b/google/genai/errors.py index 63d9334b9..650aff37a 100644 --- a/google/genai/errors.py +++ b/google/genai/errors.py @@ -271,4 +271,20 @@ class UnknownApiResponseError(ValueError): """Raised when the response from the API cannot be parsed as JSON.""" pass + +class FileProcessingError(Exception): + """Error related to file processing in the API. + + This exception is raised when a file fails to reach the ACTIVE state + required for using it in content generation requests. + """ + + def __init__( + self, message: str, response_json: Optional[dict[str, Any]] = None + ) -> None: + self.message = message + self.details = response_json or {} + super().__init__(message) + + ExperimentalWarning = _common.ExperimentalWarning diff --git a/google/genai/models.py b/google/genai/models.py index ddee88f2d..f9e51e787 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -17,6 +17,7 @@ import json import logging +import time from typing import Any, AsyncIterator, Awaitable, Iterator, Optional, Union from urllib.parse import urlencode @@ -4472,6 +4473,86 @@ def _Video_to_vertex( return to_object +def _ensure_file_active( + api_client: BaseApiClient, + file_obj: types.File, + max_retries: int = 3, + retry_delay_seconds: int = 5, +) -> types.File: + """Ensure a file object is in ACTIVE state before using it in content generation. + + Args: + api_client: The API client to use for requests. + file_obj: The file object to check. + max_retries: Maximum number of retries for checking file state. + retry_delay_seconds: Delay between retries in seconds. + + Returns: + The file object, refreshed if necessary. + + Raises: + errors.FileProcessingError: If the file fails to become ACTIVE within the retry limit. + """ + if hasattr(file_obj, 'name') and file_obj.name and hasattr(file_obj, 'state'): + if file_obj.state == types.FileState.PROCESSING: + logger.info( + f'File {file_obj.name} is in PROCESSING state. Waiting for it to become ACTIVE.' + ) + for attempt in range(max_retries): + time.sleep(retry_delay_seconds) + try: + file_id = file_obj.name.split('/')[-1] + response = api_client.request('GET', f'files/{file_id}', {}, None) + response_dict = {} if not response.body else json.loads(response.body) + refreshed_file = types.File._from_response( + response=response_dict, kwargs={} + ) + logger.info(f'File {file_obj.name} state: {refreshed_file.state}') + if refreshed_file.state == types.FileState.ACTIVE: + return refreshed_file + if refreshed_file.state == types.FileState.FAILED: + error_msg = 'File processing failed' + if hasattr(refreshed_file, 'error') and refreshed_file.error: + error_msg = f'{error_msg}: {refreshed_file.error.message}' + raise errors.FileProcessingError(error_msg) + except errors.FileProcessingError: + raise + except Exception as e: + logger.warning(f'Error refreshing file state: {e}') + logger.warning( + f'File {file_obj.name} did not become ACTIVE after {max_retries} attempts. ' + 'This may cause the content generation to fail.' + ) + return file_obj + + +def _process_contents_for_generation( + api_client: BaseApiClient, + contents: Union[types.ContentListUnion, types.ContentListUnionDict], +) -> list[types.Content]: + """Process the contents, ensuring all File objects are in the ACTIVE state. + + Args: + api_client: The API client to use for requests. + contents: The contents to process. + + Returns: + The processed contents. + """ + processed_contents = t.t_contents(contents) + + def process_file_in_item(item: types.Content) -> types.Content: + if isinstance(item, types.Content): + if hasattr(item, 'parts') and item.parts: + for part in item.parts: + if hasattr(part, 'file_data') and part.file_data: + if isinstance(part.file_data, types.File): + part.file_data = _ensure_file_active(api_client, part.file_data) + return item + + return [process_file_in_item(item) for item in processed_contents] + + class Models(_api_module.BaseModule): def _generate_content( @@ -4481,6 +4562,7 @@ def _generate_content( contents: Union[types.ContentListUnion, types.ContentListUnionDict], config: Optional[types.GenerateContentConfigOrDict] = None, ) -> types.GenerateContentResponse: + contents = _process_contents_for_generation(self._api_client, contents) parameter_model = types._GenerateContentParameters( model=model, contents=contents, @@ -4562,6 +4644,7 @@ def _generate_content_stream( contents: Union[types.ContentListUnion, types.ContentListUnionDict], config: Optional[types.GenerateContentConfigOrDict] = None, ) -> Iterator[types.GenerateContentResponse]: + contents = _process_contents_for_generation(self._api_client, contents) parameter_model = types._GenerateContentParameters( model=model, contents=contents, @@ -6440,6 +6523,7 @@ async def _generate_content( contents: Union[types.ContentListUnion, types.ContentListUnionDict], config: Optional[types.GenerateContentConfigOrDict] = None, ) -> types.GenerateContentResponse: + contents = _process_contents_for_generation(self._api_client, contents) parameter_model = types._GenerateContentParameters( model=model, contents=contents, @@ -6521,6 +6605,7 @@ async def _generate_content_stream( contents: Union[types.ContentListUnion, types.ContentListUnionDict], config: Optional[types.GenerateContentConfigOrDict] = None, ) -> Awaitable[AsyncIterator[types.GenerateContentResponse]]: + contents = _process_contents_for_generation(self._api_client, contents) parameter_model = types._GenerateContentParameters( model=model, contents=contents, diff --git a/google/genai/tests/models/test_file_state_handling.py b/google/genai/tests/models/test_file_state_handling.py new file mode 100644 index 000000000..9838b331c --- /dev/null +++ b/google/genai/tests/models/test_file_state_handling.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for file state handling in content generation.""" + +import json +import unittest +from unittest import mock + +import pytest + +from google.genai import errors +from google.genai import types +from google.genai.models import _ensure_file_active, _process_contents_for_generation +from google.genai.types import FileState + + +def _make_response_body(state: str, error_message: str = None) -> bytes: + """Create a mock API response body for a file.""" + data = { + 'name': 'files/test123', + 'displayName': 'Test File', + 'mimeType': 'video/mp4', + 'state': state, + } + if error_message: + data['error'] = {'message': error_message} + return json.dumps(data).encode() + + +class TestFileStateHandling(unittest.TestCase): + """Test file state handling functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.api_client = mock.MagicMock() + self.file_obj = types.File( + name='files/test123', + display_name='Test File', + mime_type='video/mp4', + uri='https://example.com/files/test123', + state=types.FileState.PROCESSING, + ) + + def test_ensure_file_active_with_processing_file(self): + """Test that _ensure_file_active waits for a PROCESSING file to become ACTIVE.""" + response_mock = mock.MagicMock() + response_mock.body = _make_response_body('ACTIVE') + self.api_client.request.return_value = response_mock + + result = _ensure_file_active( + self.api_client, self.file_obj, max_retries=1, retry_delay_seconds=0 + ) + + self.api_client.request.assert_called_once_with( + 'GET', 'files/test123', {}, None + ) + self.assertEqual(result.state, types.FileState.ACTIVE) + + def test_ensure_file_active_with_failed_file(self): + """Test that _ensure_file_active raises FileProcessingError for a FAILED file.""" + response_mock = mock.MagicMock() + response_mock.body = _make_response_body( + 'FAILED', error_message='File processing failed' + ) + self.api_client.request.return_value = response_mock + + with pytest.raises(errors.FileProcessingError) as excinfo: + _ensure_file_active( + self.api_client, self.file_obj, max_retries=1, retry_delay_seconds=0 + ) + + assert 'File processing failed' in str(excinfo.value) + + def test_ensure_file_active_with_retries_exhausted(self): + """Test that _ensure_file_active returns original file after exhausting retries.""" + response_mock = mock.MagicMock() + response_mock.body = _make_response_body('PROCESSING') + self.api_client.request.return_value = response_mock + + result = _ensure_file_active( + self.api_client, self.file_obj, max_retries=2, retry_delay_seconds=0 + ) + + self.assertEqual(self.api_client.request.call_count, 2) + self.assertEqual(result, self.file_obj) + self.assertEqual(result.state, types.FileState.PROCESSING) + + def test_ensure_file_active_with_already_active_file(self): + """Test that _ensure_file_active returns immediately for an already ACTIVE file.""" + active_file = types.File( + name='files/active123', + display_name='Active File', + mime_type='video/mp4', + state=types.FileState.ACTIVE, + ) + + result = _ensure_file_active( + self.api_client, active_file, max_retries=1, retry_delay_seconds=0 + ) + + self.api_client.request.assert_not_called() + self.assertEqual(result, active_file) + self.assertEqual(result.state, types.FileState.ACTIVE) + + +class TestProcessContentsFunction(unittest.TestCase): + """Test the _process_contents_for_generation function.""" + + def setUp(self): + """Set up test fixtures.""" + self.api_client = mock.MagicMock() + self.processing_file = types.File( + name='files/processing123', + display_name='Processing File', + mime_type='video/mp4', + uri='https://example.com/files/processing123', + state=types.FileState.PROCESSING, + ) + self.active_file = types.File( + name='files/active123', + display_name='Active File', + mime_type='video/mp4', + uri='https://example.com/files/active123', + state=types.FileState.ACTIVE, + ) + + def test_process_contents_with_files(self): + """Test that _process_contents_for_generation can handle various file scenarios.""" + file_in_list = [self.processing_file, 'Process this file'] + file_in_parts = types.Content( + role='user', + parts=[types.Part(text="Here's a video:"), self.processing_file], + ) + multiple_files = [ + types.Content( + role='user', + parts=[types.Part(text='First video:'), self.processing_file], + ), + types.Content( + role='user', + parts=[types.Part(text='Second video:'), self.active_file], + ), + ] + + with mock.patch( + 'google.genai.models._ensure_file_active', side_effect=lambda client, f: f + ): + for test_content in [file_in_list, file_in_parts, multiple_files]: + with mock.patch( + 'google.genai.models.t.t_contents', + return_value=( + test_content + if isinstance(test_content, list) + else [test_content] + ), + ): + result = _process_contents_for_generation( + self.api_client, test_content + ) + self.assertTrue(result) + + +if __name__ == '__main__': + unittest.main()