diff --git a/Dockerfile.cloudrun b/Dockerfile.cloudrun index e0caab0..4fe9c2e 100644 --- a/Dockerfile.cloudrun +++ b/Dockerfile.cloudrun @@ -52,6 +52,17 @@ RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 && curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 \ && python3 -m pip install --no-cache-dir --upgrade pip setuptools wheel +# Install PyTorch with CUDA 12.6 support BEFORE audio-separator[gpu]. +# Without this, `pip install ".[gpu]"` pulls the default CPU-only PyTorch wheel +# from PyPI and Separator silently falls back to CPU (~10× slower). +# Cloud Run L4 GPUs have NVIDIA driver 570 (supports up to CUDA 12.8), so cu126 +# works. cu130 would fail with "NVIDIA driver is too old". +# Installing torch first means audio-separator[gpu] sees it already satisfied. +RUN pip install --no-cache-dir \ + torch==2.6.0+cu126 \ + torchvision==0.21.0+cu126 \ + --index-url https://download.pytorch.org/whl/cu126 + # Install audio-separator with GPU support and API dependencies COPY . /tmp/audio-separator-src RUN cd /tmp/audio-separator-src \ diff --git a/audio_separator/remote/README.md b/audio_separator/remote/README.md index 4dc4c3a..7622663 100644 --- a/audio_separator/remote/README.md +++ b/audio_separator/remote/README.md @@ -200,6 +200,22 @@ audio-separator-remote separate audio.wav \ --vr_aggression 10 ``` +**Large files (>30 MiB):** + +When the deployment runs on Cloud Run, request bodies are capped at 32 MiB. For larger inputs the CLI automatically uploads the file to GCS first and tells the server to fetch from `gs://...`, bypassing the limit. This is transparent — the same `separate` command works for any file size: + +```bash +# Same command, file size detected automatically +audio-separator-remote separate big_song.wav --preset vocal_balanced +``` + +Requirements when the GCS path activates: +- Application Default Credentials on the laptop (`gcloud auth application-default login`) +- Write permission on the input bucket (defaults to `nomadkaraoke-audio-separator-outputs`) +- The Cloud Run service account needs read permission on the same bucket (it already does for the default bucket) + +Override the bucket with `--gcs-bucket my-bucket` or by setting `AUDIO_SEPARATOR_GCS_INPUT_BUCKET`. Uploaded inputs are deleted after the job finishes (success or failure); the bucket's lifecycle policy is the safety net if cleanup fails. + **Check job status:** ```bash @@ -236,6 +252,7 @@ audio-separator-remote --version **Global Options:** - `--api_url`: Override the API URL +- `--gcs-bucket`: Bucket used for the >30 MiB upload fallback (env: `AUDIO_SEPARATOR_GCS_INPUT_BUCKET`, default: `nomadkaraoke-audio-separator-outputs`) - `--timeout`: Set timeout for polling (default: 600 seconds) - `--poll_interval`: Set polling interval (default: 10 seconds) - `--debug`: Enable debug logging diff --git a/audio_separator/remote/cli.py b/audio_separator/remote/cli.py index 6dfb2ce..3ea625b 100644 --- a/audio_separator/remote/cli.py +++ b/audio_separator/remote/cli.py @@ -5,10 +5,66 @@ import os import sys import time +import uuid from importlib import metadata from audio_separator.remote import AudioSeparatorAPIClient +# Cloud Run hard-limits request bodies to 32 MiB. Use 30 MiB threshold so a +# little request overhead won't push us over. Larger files go via GCS. +GCS_UPLOAD_THRESHOLD_BYTES = 30 * 1024 * 1024 +DEFAULT_GCS_INPUT_BUCKET = "nomadkaraoke-audio-separator-outputs" +GCS_INPUT_PREFIX = "cli-uploads" + + +def upload_to_gcs(file_path: str, bucket_name: str, logger: logging.Logger) -> str: + """Upload a local file to GCS and return its gs:// URI. + + Requires `google-cloud-storage` and Application Default Credentials + (run `gcloud auth application-default login` on the laptop). + """ + try: + from google.cloud import storage + except ImportError as e: + raise RuntimeError( + "google-cloud-storage is required to upload files larger than " + f"{GCS_UPLOAD_THRESHOLD_BYTES // (1024 * 1024)} MiB. " + "Install it with: pip install google-cloud-storage" + ) from e + + filename = os.path.basename(file_path) + blob_path = f"{GCS_INPUT_PREFIX}/{uuid.uuid4()}-{filename}" + gcs_uri = f"gs://{bucket_name}/{blob_path}" + + size_mib = os.path.getsize(file_path) / (1024 * 1024) + logger.info(f"Uploading {size_mib:.1f} MiB to {gcs_uri} (server fetches from GCS, bypasses Cloud Run 32 MiB limit)") + + client = storage.Client() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_path) + blob.upload_from_filename(file_path) + + logger.info(f"Upload complete: {gcs_uri}") + return gcs_uri + + +def delete_from_gcs(gcs_uri: str, logger: logging.Logger) -> None: + """Best-effort delete of a GCS object. Logs but doesn't raise on failure.""" + try: + from google.cloud import storage + + without_prefix = gcs_uri[len("gs://"):] + slash_idx = without_prefix.index("/") + bucket_name = without_prefix[:slash_idx] + blob_path = without_prefix[slash_idx + 1:] + + client = storage.Client() + bucket = client.bucket(bucket_name) + bucket.blob(blob_path).delete() + logger.info(f"Cleaned up uploaded input: {gcs_uri}") + except Exception as e: + logger.warning(f"Failed to delete {gcs_uri}: {e} (bucket lifecycle will reclaim it)") + def main(): """Main entry point for the remote CLI.""" @@ -104,6 +160,13 @@ def main(): parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging") parser.add_argument("--log_level", default="info", help="Log level (default: info)") parser.add_argument("--api_url", help="API URL (overrides AUDIO_SEPARATOR_API_URL env var)") + parser.add_argument( + "--gcs-bucket", + help=( + f"GCS bucket for uploading files >{GCS_UPLOAD_THRESHOLD_BYTES // (1024 * 1024)} MiB " + f"(overrides AUDIO_SEPARATOR_GCS_INPUT_BUCKET env var, default: {DEFAULT_GCS_INPUT_BUCKET})" + ), + ) args = parser.parse_args() @@ -145,9 +208,12 @@ def main(): # Create API client api_client = AudioSeparatorAPIClient(api_url, logger) + # Resolve GCS bucket for large-file uploads + gcs_bucket = args.gcs_bucket or os.environ.get("AUDIO_SEPARATOR_GCS_INPUT_BUCKET", DEFAULT_GCS_INPUT_BUCKET) + # Handle commands if args.command == "separate": - handle_separate_command(args, api_client, logger) + handle_separate_command(args, api_client, logger, gcs_bucket) elif args.command == "status": handle_status_command(args, api_client, logger) elif args.command == "models": @@ -159,14 +225,35 @@ def main(): sys.exit(1) -def handle_separate_command(args, api_client: AudioSeparatorAPIClient, logger: logging.Logger): +def handle_separate_command(args, api_client: AudioSeparatorAPIClient, logger: logging.Logger, gcs_bucket: str): """Handle the separate command.""" for audio_file in args.audio_files: - logger.info(f"Uploading '{audio_file}' to audio separator...") + logger.info(f"Processing '{audio_file}'...") + + # Decide upload path: small files go via multipart POST, large files via GCS + # to bypass the Cloud Run 32 MiB request body limit. + uploaded_gcs_uri = None + try: + file_size = os.path.getsize(audio_file) + use_gcs = file_size > GCS_UPLOAD_THRESHOLD_BYTES + except OSError as e: + logger.error(f"❌ Cannot read '{audio_file}': {e}") + continue try: + if use_gcs: + logger.info( + f"File is {file_size / (1024 * 1024):.1f} MiB (>{GCS_UPLOAD_THRESHOLD_BYTES // (1024 * 1024)} MiB), " + "uploading via GCS" + ) + uploaded_gcs_uri = upload_to_gcs(audio_file, gcs_bucket, logger) + source_kwargs = {"file_path": None, "gcs_uri": uploaded_gcs_uri} + else: + source_kwargs = {"file_path": audio_file, "gcs_uri": None} + # Prepare parameters for separation kwargs = { + **source_kwargs, "model": args.model, "models": args.models, "preset": args.preset, @@ -213,7 +300,7 @@ def handle_separate_command(args, api_client: AudioSeparatorAPIClient, logger: l } # Use the convenience method that handles everything - result = api_client.separate_audio_and_wait(audio_file, **kwargs) + result = api_client.separate_audio_and_wait(**kwargs) if result["status"] == "completed": if "downloaded_files" in result: @@ -227,6 +314,9 @@ def handle_separate_command(args, api_client: AudioSeparatorAPIClient, logger: l except Exception as e: logger.error(f"❌ Error processing '{audio_file}': {e}") + finally: + if uploaded_gcs_uri: + delete_from_gcs(uploaded_gcs_uri, logger) def handle_status_command(args, api_client: AudioSeparatorAPIClient, logger: logging.Logger): diff --git a/pyproject.toml b/pyproject.toml index 89e61a4..a9de5ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "audio-separator" -version = "0.44.1" +version = "0.44.2" description = "Easy to use audio stem separation, using various models from UVR trained primarily by @Anjok07" authors = ["Andrew Beveridge "] license = "MIT" diff --git a/tests/integration/test_remote_api_integration.py b/tests/integration/test_remote_api_integration.py index db35633..ab78ead 100644 --- a/tests/integration/test_remote_api_integration.py +++ b/tests/integration/test_remote_api_integration.py @@ -487,7 +487,7 @@ def test_cli_separate_command_integration(self, mock_client_class, test_audio_fi logger = Mock() # Execute the command - handle_separate_command(args, mock_client, logger) + handle_separate_command(args, mock_client, logger, "test-bucket") # Verify the API client method was called mock_client.separate_audio_and_wait.assert_called_once() diff --git a/tests/unit/test_remote_cli.py b/tests/unit/test_remote_cli.py index 49c4aa0..25dec46 100644 --- a/tests/unit/test_remote_cli.py +++ b/tests/unit/test_remote_cli.py @@ -232,13 +232,14 @@ def test_handle_separate_command_success(self, mock_api_client, mock_logger, moc "downloaded_files": ["output1.wav", "output2.wav"] } - handle_separate_command(args, mock_api_client, mock_logger) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") # Verify API client was called with correct parameters mock_api_client.separate_audio_and_wait.assert_called_once() call_args = mock_api_client.separate_audio_and_wait.call_args - assert call_args[0][0] == mock_audio_file # First positional argument should be the audio file kwargs = call_args[1] + assert kwargs["file_path"] == mock_audio_file # Small file uses upload path + assert kwargs["gcs_uri"] is None assert kwargs["model"] == "test_model.ckpt" assert kwargs["timeout"] == 600 assert kwargs["download"] is True @@ -266,7 +267,7 @@ def test_handle_separate_command_with_multiple_models(self, mock_api_client, moc "downloaded_files": ["output1.wav", "output2.wav"] } - handle_separate_command(args, mock_api_client, mock_logger) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") call_args = mock_api_client.separate_audio_and_wait.call_args kwargs = call_args[1] @@ -293,7 +294,7 @@ def test_handle_separate_command_error(self, mock_api_client, mock_logger, mock_ "error": "Processing failed" } - handle_separate_command(args, mock_api_client, mock_logger) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") # Verify error was logged mock_logger.error.assert_called() @@ -314,11 +315,114 @@ def test_handle_separate_command_exception(self, mock_api_client, mock_logger, m mock_api_client.separate_audio_and_wait.side_effect = Exception("API error") - handle_separate_command(args, mock_api_client, mock_logger) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") # Verify error was logged mock_logger.error.assert_called() + def _make_separate_args(self, audio_file, **overrides): + """Build a Mock args object for handle_separate_command with sensible defaults.""" + args = Mock() + args.audio_files = [audio_file] + args.model = "test_model.ckpt" + args.models = None + args.preset = None + args.timeout = 600 + args.poll_interval = 10 + for attr in ['output_format', 'output_bitrate', 'normalization', 'amplification', 'single_stem', + 'invert_spect', 'sample_rate', 'use_soundfile', 'use_autocast', 'custom_output_names', + 'mdx_segment_size', 'mdx_overlap', 'mdx_batch_size', 'mdx_hop_length', 'mdx_enable_denoise', + 'vr_batch_size', 'vr_window_size', 'vr_aggression', 'vr_enable_tta', 'vr_high_end_process', + 'vr_enable_post_process', 'vr_post_process_threshold', 'demucs_segment_size', 'demucs_shifts', + 'demucs_overlap', 'demucs_segments_enabled', 'mdxc_segment_size', 'mdxc_override_model_segment_size', + 'mdxc_overlap', 'mdxc_batch_size', 'mdxc_pitch_shift']: + setattr(args, attr, None) + for k, v in overrides.items(): + setattr(args, k, v) + return args + + @patch('audio_separator.remote.cli.delete_from_gcs') + @patch('audio_separator.remote.cli.upload_to_gcs') + @patch('audio_separator.remote.cli.os.path.getsize') + def test_handle_separate_large_file_uploads_to_gcs( + self, mock_getsize, mock_upload, mock_delete, mock_api_client, mock_logger, mock_audio_file + ): + """Files over the threshold go via GCS, not multipart upload.""" + from audio_separator.remote.cli import GCS_UPLOAD_THRESHOLD_BYTES + + mock_getsize.return_value = GCS_UPLOAD_THRESHOLD_BYTES + 1 + mock_upload.return_value = "gs://test-bucket/cli-uploads/abc-song.wav" + mock_api_client.separate_audio_and_wait.return_value = { + "status": "completed", "downloaded_files": ["song_(Vocals).wav"], + } + + args = self._make_separate_args(mock_audio_file) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") + + mock_upload.assert_called_once_with(mock_audio_file, "test-bucket", mock_logger) + kwargs = mock_api_client.separate_audio_and_wait.call_args[1] + assert kwargs["file_path"] is None + assert kwargs["gcs_uri"] == "gs://test-bucket/cli-uploads/abc-song.wav" + mock_delete.assert_called_once_with("gs://test-bucket/cli-uploads/abc-song.wav", mock_logger) + + @patch('audio_separator.remote.cli.delete_from_gcs') + @patch('audio_separator.remote.cli.upload_to_gcs') + @patch('audio_separator.remote.cli.os.path.getsize') + def test_handle_separate_cleanup_runs_when_separation_fails( + self, mock_getsize, mock_upload, mock_delete, mock_api_client, mock_logger, mock_audio_file + ): + """If separation raises after GCS upload, we still clean up the uploaded object.""" + from audio_separator.remote.cli import GCS_UPLOAD_THRESHOLD_BYTES + + mock_getsize.return_value = GCS_UPLOAD_THRESHOLD_BYTES + 1 + mock_upload.return_value = "gs://test-bucket/cli-uploads/abc-song.wav" + mock_api_client.separate_audio_and_wait.side_effect = Exception("separator died") + + args = self._make_separate_args(mock_audio_file) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") + + mock_delete.assert_called_once_with("gs://test-bucket/cli-uploads/abc-song.wav", mock_logger) + + @patch('audio_separator.remote.cli.delete_from_gcs') + @patch('audio_separator.remote.cli.upload_to_gcs') + @patch('audio_separator.remote.cli.os.path.getsize') + def test_handle_separate_upload_failure_skips_separation_and_cleanup( + self, mock_getsize, mock_upload, mock_delete, mock_api_client, mock_logger, mock_audio_file + ): + """If GCS upload fails, we don't call the API and don't try to clean up something we never created.""" + from audio_separator.remote.cli import GCS_UPLOAD_THRESHOLD_BYTES + + mock_getsize.return_value = GCS_UPLOAD_THRESHOLD_BYTES + 1 + mock_upload.side_effect = RuntimeError("ADC creds missing") + + args = self._make_separate_args(mock_audio_file) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") + + mock_api_client.separate_audio_and_wait.assert_not_called() + mock_delete.assert_not_called() + mock_logger.error.assert_called() + + @patch('audio_separator.remote.cli.upload_to_gcs') + @patch('audio_separator.remote.cli.os.path.getsize') + def test_handle_separate_small_file_skips_gcs( + self, mock_getsize, mock_upload, mock_api_client, mock_logger, mock_audio_file + ): + """Files at or below the threshold never touch GCS.""" + from audio_separator.remote.cli import GCS_UPLOAD_THRESHOLD_BYTES + + mock_getsize.return_value = GCS_UPLOAD_THRESHOLD_BYTES + mock_api_client.separate_audio_and_wait.return_value = { + "status": "completed", "downloaded_files": [], + } + + args = self._make_separate_args(mock_audio_file) + handle_separate_command(args, mock_api_client, mock_logger, "test-bucket") + + mock_upload.assert_not_called() + kwargs = mock_api_client.separate_audio_and_wait.call_args[1] + assert kwargs["file_path"] == mock_audio_file + assert kwargs["gcs_uri"] is None + def test_handle_status_command_success(self, mock_api_client, mock_logger): """Test successful status command handling.""" args = Mock() @@ -441,4 +545,107 @@ def test_handle_download_command_exception(self, mock_api_client, mock_logger): handle_download_command(args, mock_api_client, mock_logger) - mock_logger.error.assert_called() \ No newline at end of file + mock_logger.error.assert_called() + + +class TestGCSHelpers: + """Direct tests for the GCS upload/delete helper functions.""" + + def test_upload_to_gcs_builds_expected_blob_path_and_uri(self, mock_audio_file, mock_logger): + """Object key uses cli-uploads/{uuid}-{filename} so collisions are impossible.""" + from audio_separator.remote.cli import upload_to_gcs, GCS_INPUT_PREFIX + + from google.cloud import storage as gcs_storage + with patch.object(gcs_storage, 'Client') as mock_client_cls: + mock_blob = MagicMock() + mock_bucket = MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_client = MagicMock() + mock_client.bucket.return_value = mock_bucket + mock_client_cls.return_value = mock_client + + gcs_uri = upload_to_gcs(mock_audio_file, "my-bucket", mock_logger) + + mock_client.bucket.assert_called_once_with("my-bucket") + blob_path = mock_bucket.blob.call_args[0][0] + filename = os.path.basename(mock_audio_file) + assert blob_path.startswith(f"{GCS_INPUT_PREFIX}/") + assert blob_path.endswith(f"-{filename}") + assert gcs_uri == f"gs://my-bucket/{blob_path}" + mock_blob.upload_from_filename.assert_called_once_with(mock_audio_file) + + def test_upload_to_gcs_raises_clear_error_when_lib_missing(self, mock_audio_file, mock_logger): + """If google-cloud-storage isn't installed, surface a helpful install hint.""" + from audio_separator.remote.cli import upload_to_gcs + + real_import = __import__ + + def fake_import(name, *args, **kwargs): + if name == "google.cloud" or name.startswith("google.cloud"): + raise ImportError("no google.cloud") + return real_import(name, *args, **kwargs) + + with patch('builtins.__import__', side_effect=fake_import): + with pytest.raises(RuntimeError, match="google-cloud-storage is required"): + upload_to_gcs(mock_audio_file, "my-bucket", mock_logger) + + def test_delete_from_gcs_parses_uri_and_deletes_blob(self, mock_logger): + """Verify gs://bucket/path/to/object splits into the correct bucket + blob path.""" + from audio_separator.remote.cli import delete_from_gcs + from google.cloud import storage as gcs_storage + + with patch.object(gcs_storage, 'Client') as mock_client_cls: + mock_blob = MagicMock() + mock_bucket = MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_client = MagicMock() + mock_client.bucket.return_value = mock_bucket + mock_client_cls.return_value = mock_client + + delete_from_gcs("gs://my-bucket/cli-uploads/abc-def/lying.wav", mock_logger) + + mock_client.bucket.assert_called_once_with("my-bucket") + mock_bucket.blob.assert_called_once_with("cli-uploads/abc-def/lying.wav") + mock_blob.delete.assert_called_once() + + def test_delete_from_gcs_swallows_errors(self, mock_logger): + """Best-effort delete: a failure is logged as a warning, never raised.""" + from audio_separator.remote.cli import delete_from_gcs + from google.cloud import storage as gcs_storage + + with patch.object(gcs_storage, 'Client', side_effect=Exception("network down")): + delete_from_gcs("gs://my-bucket/some/object.wav", mock_logger) + + mock_logger.warning.assert_called_once() + + +class TestBucketResolution: + """Bucket resolution priority: --gcs-bucket flag > AUDIO_SEPARATOR_GCS_INPUT_BUCKET env > default.""" + + @patch('sys.argv', ['audio-separator-remote', '--gcs-bucket', 'flag-bucket', 'separate', 'fake.wav']) + @patch('audio_separator.remote.cli.AudioSeparatorAPIClient') + @patch('audio_separator.remote.cli.handle_separate_command') + @patch.dict(os.environ, {'AUDIO_SEPARATOR_API_URL': 'https://test', 'AUDIO_SEPARATOR_GCS_INPUT_BUCKET': 'env-bucket'}) + def test_flag_wins_over_env(self, mock_handle_separate, mock_client_class): + main() + # gcs_bucket is the 4th positional arg to handle_separate_command + assert mock_handle_separate.call_args[0][3] == 'flag-bucket' + + @patch('sys.argv', ['audio-separator-remote', 'separate', 'fake.wav']) + @patch('audio_separator.remote.cli.AudioSeparatorAPIClient') + @patch('audio_separator.remote.cli.handle_separate_command') + @patch.dict(os.environ, {'AUDIO_SEPARATOR_API_URL': 'https://test', 'AUDIO_SEPARATOR_GCS_INPUT_BUCKET': 'env-bucket'}) + def test_env_used_when_flag_absent(self, mock_handle_separate, mock_client_class): + main() + assert mock_handle_separate.call_args[0][3] == 'env-bucket' + + @patch('sys.argv', ['audio-separator-remote', 'separate', 'fake.wav']) + @patch('audio_separator.remote.cli.AudioSeparatorAPIClient') + @patch('audio_separator.remote.cli.handle_separate_command') + def test_default_used_when_neither_flag_nor_env(self, mock_handle_separate, mock_client_class, monkeypatch): + from audio_separator.remote.cli import DEFAULT_GCS_INPUT_BUCKET + + monkeypatch.setenv('AUDIO_SEPARATOR_API_URL', 'https://test') + monkeypatch.delenv('AUDIO_SEPARATOR_GCS_INPUT_BUCKET', raising=False) + main() + assert mock_handle_separate.call_args[0][3] == DEFAULT_GCS_INPUT_BUCKET \ No newline at end of file