diff --git a/.gitignore b/.gitignore index 6204778..d8dad13 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ __pycache__/ .pytest_cache/ .ruff_cache/ /data +.idea \ No newline at end of file diff --git a/app/backends/base.py b/app/backends/base.py index 8227752..8c50f7d 100644 --- a/app/backends/base.py +++ b/app/backends/base.py @@ -177,6 +177,7 @@ async def launch( work_dir: str, configfile: str | None, snakemake_args: list[str] | None = None, + env_vars: dict[str, str] | None = None, ) -> None: """ Write the .run.sh wrapper script, launch it via nohup/disown, diff --git a/app/backends/local.py b/app/backends/local.py index 73a99a3..ee951f9 100644 --- a/app/backends/local.py +++ b/app/backends/local.py @@ -88,11 +88,12 @@ async def launch( work_dir: str, configfile: str | None, snakemake_args: list[str] | None = None, + env_vars: dict[str, str] | None = None, ) -> None: self._validate_configfile(configfile) snkmt_db_path = str(Path(work_dir).resolve() / SNKMT_DB_FILENAME) wrapper_content = build_wrapper_script( - self._config.pixi_path, snkmt_db_path, configfile, snakemake_args + self._config.pixi_path, snkmt_db_path, configfile, snakemake_args, env_vars ) wd = Path(work_dir) diff --git a/app/backends/slurm_ssh.py b/app/backends/slurm_ssh.py index f2a4879..5755f19 100644 --- a/app/backends/slurm_ssh.py +++ b/app/backends/slurm_ssh.py @@ -284,12 +284,13 @@ async def launch( work_dir: str, configfile: str | None, snakemake_args: list[str] | None = None, + env_vars: dict[str, str] | None = None, ) -> None: cfg = self._config self._validate_configfile(configfile) snkmt_db_path = f"{work_dir}/{SNKMT_DB_FILENAME}" wrapper_content = build_wrapper_script( - cfg.pixi_path, snkmt_db_path, configfile, snakemake_args + cfg.pixi_path, snkmt_db_path, configfile, snakemake_args, env_vars ) # Write wrapper script via SFTP, then make executable and launch detached diff --git a/app/config.py b/app/config.py index 74388a5..ba630c2 100644 --- a/app/config.py +++ b/app/config.py @@ -1,13 +1,25 @@ from __future__ import annotations +import re from pathlib import Path import yaml -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from pydantic_settings import BaseSettings BACKEND_KEYS: frozenset[str] = frozenset({"slurm_ssh", "local"}) +_VALID_ENV_VAR_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def _validate_env_var_names(v: list[str] | None) -> list[str] | None: + if v is not None: + for name in v: + if not _VALID_ENV_VAR_NAME.match(name): + msg = f"allowed_env_vars entry is not a valid env var name: {name!r}" + raise ValueError(msg) + return v + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -40,6 +52,12 @@ class SlurmSSHConfig(BaseModel): scratch_dir: str default_snakemake_args: list[str] = [] snkmt_db_sync_interval: float = 30.0 + allowed_env_vars: list[str] | None = None + + @field_validator("allowed_env_vars") + @classmethod + def _validate_allowed_env_vars(cls, v: list[str] | None) -> list[str] | None: + return _validate_env_var_names(v) class LocalConfig(BaseModel): @@ -50,6 +68,12 @@ class LocalConfig(BaseModel): poll_interval: float = 5.0 default_snakemake_args: list[str] = [] snkmt_db_sync_interval: float = 30.0 + allowed_env_vars: list[str] | None = None + + @field_validator("allowed_env_vars") + @classmethod + def _validate_allowed_env_vars(cls, v: list[str] | None) -> list[str] | None: + return _validate_env_var_names(v) def load_config( diff --git a/app/deps.py b/app/deps.py index 85d26d0..a1dd4a0 100644 --- a/app/deps.py +++ b/app/deps.py @@ -27,6 +27,7 @@ class AppState: settings: Settings health_cache: HealthCache default_snakemake_args: list[str] + allowed_env_vars: list[str] | None = None background_tasks: set[asyncio.Task[None]] = field(default_factory=set) @@ -46,6 +47,10 @@ def get_default_snakemake_args(request: Request) -> list[str]: return app_state(request).default_snakemake_args +def get_allowed_env_vars(request: Request) -> list[str] | None: + return app_state(request).allowed_env_vars + + def provide_background_tasks(request: Request) -> set[asyncio.Task[None]]: """Return the shared background-task set.""" return app_state(request).background_tasks diff --git a/app/main.py b/app/main.py index 477a4dc..bd80bad 100644 --- a/app/main.py +++ b/app/main.py @@ -40,6 +40,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: settings=settings, health_cache={"backend_ok": None, "checked_at": 0.0}, default_snakemake_args=backend_config.default_snakemake_args, + allowed_env_vars=backend_config.allowed_env_vars, ) store.restore_from_disk() diff --git a/app/models.py b/app/models.py index 0c6b3a5..bb9628f 100644 --- a/app/models.py +++ b/app/models.py @@ -75,6 +75,11 @@ class JobCreate(BaseModel): description="Relative paths within the work directory to cache, " "e.g. ['data', 'resources']. Requires cache_key.", ) + env_vars: dict[str, str] | None = Field( + None, + description="Environment variables to merge into the Snakemake process " + "environment. Only keys listed in the server's allowed_env_vars config are accepted.", + ) @field_validator("cache_key") @classmethod diff --git a/app/routes/jobs.py b/app/routes/jobs.py index 3f5b7a7..4746b2c 100644 --- a/app/routes/jobs.py +++ b/app/routes/jobs.py @@ -14,6 +14,7 @@ from app.backends.base import ComputeBackend from app.deps import ( + get_allowed_env_vars, get_backend, get_default_snakemake_args, get_store, @@ -100,8 +101,22 @@ async def create_job( store: Annotated[JobStore, Depends(get_store)], backend: Annotated[ComputeBackend, Depends(get_backend)], default_snakemake_args: Annotated[list[str], Depends(get_default_snakemake_args)], + allowed_env_vars: Annotated[list[str] | None, Depends(get_allowed_env_vars)], ) -> JobResponse: """Submit a new Snakemake job for execution.""" + if body.env_vars: + if allowed_env_vars is None: + raise HTTPException( + status_code=422, + detail="env_vars are not enabled; set allowed_env_vars in server config", + ) + disallowed = set(body.env_vars) - set(allowed_env_vars) + if disallowed: + raise HTTPException( + status_code=422, + detail=f"env_vars keys not in allowed list: {sorted(disallowed)}", + ) + source = body.workflow is_url = source.startswith(("http://", "https://")) if not is_url: @@ -133,6 +148,7 @@ async def create_job( extra_files=body.extra_files, cache_key=body.cache_key, cache_dirs=body.cache_dirs, + env_vars=body.env_vars, ), ), name=f"execute-{job_id}", diff --git a/app/tasks.py b/app/tasks.py index 55ad35f..18ce6e0 100644 --- a/app/tasks.py +++ b/app/tasks.py @@ -27,6 +27,7 @@ class ExecuteJobParams: extra_files: dict[str, str] | None = None cache_key: str | None = None cache_dirs: list[str] | None = None + env_vars: dict[str, str] | None = None logger = logging.getLogger(__name__) @@ -198,7 +199,7 @@ async def execute_job( await _restore_cache(backend, store, job_id, work_dir, params) store.mark_running(job_id) - await backend.launch(job_id, work_dir, params.configfile, params.snakemake_args) + await backend.launch(job_id, work_dir, params.configfile, params.snakemake_args, params.env_vars) def log_callback(line: str) -> None: store.push_log(job_id, line) diff --git a/app/utils.py b/app/utils.py index ab69e3c..16d768b 100644 --- a/app/utils.py +++ b/app/utils.py @@ -33,6 +33,7 @@ def build_wrapper_script( snkmt_db_path: str, configfile: str | None, snakemake_args: list[str] | None, + env_vars: dict[str, str] | None = None, ) -> str: """Build the .run.sh bash wrapper script for a Snakemake workflow run.""" configfile_arg = f" --configfile {shlex.quote(configfile)}" if configfile else "" @@ -41,9 +42,14 @@ def build_wrapper_script( extra_args = " " + " ".join(shlex.quote(a) for a in snakemake_args) pixi = shlex.quote(pixi_path) snkmt = shlex.quote(snkmt_db_path) + exports = ( + "\n".join(f"export {k}={shlex.quote(v)}" for k, v in env_vars.items()) + "\n" + if env_vars + else "" + ) return f"""\ #!/bin/bash -echo $$ > .pid +{exports}echo $$ > .pid SNKMT_ARGS="" if {pixi} run python -c "import snakemake_logger_plugin_snkmt" 2>/dev/null; then SNKMT_ARGS="--logger snkmt --logger-snkmt-db {snkmt}" diff --git a/config/config.local.example.yaml b/config/config.local.example.yaml index 091cbaf..72481a4 100644 --- a/config/config.local.example.yaml +++ b/config/config.local.example.yaml @@ -3,6 +3,7 @@ local: pixi_path: /home/appuser/.pixi/bin/pixi poll_interval: 5 default_snakemake_args: [] + allowed_env_vars: ["OETC_EMAIL", "OETC_PASSWORD"] # App-level settings (optional, shown with defaults) # HEALTH_CACHE_TTL_SECONDS: 300 diff --git a/config/config.slurm_ssh.example.yaml b/config/config.slurm_ssh.example.yaml index 0fef806..da5b1e8 100644 --- a/config/config.slurm_ssh.example.yaml +++ b/config/config.slurm_ssh.example.yaml @@ -6,6 +6,7 @@ slurm_ssh: scratch_dir: /scratch/myuser/executor poll_interval: 5 default_snakemake_args: ["--profile", "slurm"] + # allowed_env_vars: ["OETC_EMAIL", "OETC_PASSWORD"] # App-level settings (optional, shown with defaults) # HEALTH_CACHE_TTL_SECONDS: 300 diff --git a/tests/conftest.py b/tests/conftest.py index 4d7078e..a546ae1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,6 +252,7 @@ async def async_client(store, mock_backend, settings): settings=settings, health_cache={"backend_ok": None, "checked_at": 0.0}, default_snakemake_args=backend_config.default_snakemake_args, + allowed_env_vars=backend_config.allowed_env_vars, ) transport = ASGITransport(app=app) diff --git a/tests/test_api.py b/tests/test_api.py index 5fc9983..127e8aa 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -828,3 +828,76 @@ def test_create_backend_slurm(self): ) ) assert isinstance(backend, SlurmSSHBackend) + + +class TestEnvVars: + async def test_env_vars_rejected_when_feature_disabled(self, async_client): + response = await async_client.post( + "/jobs", + json={ + "workflow": "https://github.com/org/repo.git", + "env_vars": {"OETC_EMAIL": "test@example.com"}, + }, + ) + assert response.status_code == 422 + assert "not enabled" in response.json()["detail"] + + async def test_env_vars_accepted_when_in_whitelist( + self, store, mock_backend, settings + ): + from httpx import ASGITransport, AsyncClient + + from app.config import LocalConfig + from app.deps import AppState + + backend_config = LocalConfig() + app.state.app = AppState( + store=store, + backend=mock_backend, + settings=settings, + health_cache={"backend_ok": None, "checked_at": 0.0}, + default_snakemake_args=backend_config.default_snakemake_args, + allowed_env_vars=["OETC_EMAIL", "OETC_PASSWORD"], + ) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/jobs", + json={ + "workflow": "https://github.com/org/repo.git", + "env_vars": {"OETC_EMAIL": "user@example.com"}, + }, + ) + assert response.status_code == 201 + + async def test_env_vars_rejected_when_key_not_in_whitelist( + self, store, mock_backend, settings + ): + from httpx import ASGITransport, AsyncClient + + from app.config import LocalConfig + from app.deps import AppState + + backend_config = LocalConfig() + app.state.app = AppState( + store=store, + backend=mock_backend, + settings=settings, + health_cache={"backend_ok": None, "checked_at": 0.0}, + default_snakemake_args=backend_config.default_snakemake_args, + allowed_env_vars=["OETC_EMAIL"], + ) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/jobs", + json={ + "workflow": "https://github.com/org/repo.git", + "env_vars": { + "OETC_EMAIL": "user@example.com", + "MY_SECRET": "hunter2", + }, + }, + ) + assert response.status_code == 422 + assert "MY_SECRET" in response.json()["detail"] diff --git a/tests/test_execute.py b/tests/test_execute.py index 33e6374..b7de128 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -56,6 +56,7 @@ async def test_successful_job(self, mock_backend, tmp_path): "/scratch/test/jobs/test-job-id", "config.yaml", ["--profile", "slurm"], + None, ) mock_backend.monitor.assert_called_once() @@ -132,6 +133,7 @@ async def test_minimal_args(self, mock_backend, tmp_path): "/scratch/test/jobs/test-job-id", None, None, + None, ) async def test_cache_restore_and_save(self, mock_backend, tmp_path):