Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ __pycache__/
.pytest_cache/
.ruff_cache/
/data
.idea
1 change: 1 addition & 0 deletions app/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ async def launch(
work_dir: str,
configfile: str | None,
snakemake_args: list[str] | None = None,
env_vars: dict[str, str] | None = None,
) -> None:
"""
Write the .run.sh wrapper script, launch it via nohup/disown,
Expand Down
3 changes: 2 additions & 1 deletion app/backends/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ async def launch(
work_dir: str,
configfile: str | None,
snakemake_args: list[str] | None = None,
env_vars: dict[str, str] | None = None,
) -> None:
self._validate_configfile(configfile)
snkmt_db_path = str(Path(work_dir).resolve() / SNKMT_DB_FILENAME)
wrapper_content = build_wrapper_script(
self._config.pixi_path, snkmt_db_path, configfile, snakemake_args
self._config.pixi_path, snkmt_db_path, configfile, snakemake_args, env_vars
)

wd = Path(work_dir)
Expand Down
3 changes: 2 additions & 1 deletion app/backends/slurm_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,13 @@ async def launch(
work_dir: str,
configfile: str | None,
snakemake_args: list[str] | None = None,
env_vars: dict[str, str] | None = None,
) -> None:
cfg = self._config
self._validate_configfile(configfile)
snkmt_db_path = f"{work_dir}/{SNKMT_DB_FILENAME}"
wrapper_content = build_wrapper_script(
cfg.pixi_path, snkmt_db_path, configfile, snakemake_args
cfg.pixi_path, snkmt_db_path, configfile, snakemake_args, env_vars
)

# Write wrapper script via SFTP, then make executable and launch detached
Expand Down
26 changes: 25 additions & 1 deletion app/config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
from __future__ import annotations

import re
from pathlib import Path

import yaml
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings

BACKEND_KEYS: frozenset[str] = frozenset({"slurm_ssh", "local"})

_VALID_ENV_VAR_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _validate_env_var_names(v: list[str] | None) -> list[str] | None:
if v is not None:
for name in v:
if not _VALID_ENV_VAR_NAME.match(name):
msg = f"allowed_env_vars entry is not a valid env var name: {name!r}"
raise ValueError(msg)
return v


class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
Expand Down Expand Up @@ -40,6 +52,12 @@ class SlurmSSHConfig(BaseModel):
scratch_dir: str
default_snakemake_args: list[str] = []
snkmt_db_sync_interval: float = 30.0
allowed_env_vars: list[str] | None = None

@field_validator("allowed_env_vars")
@classmethod
def _validate_allowed_env_vars(cls, v: list[str] | None) -> list[str] | None:
return _validate_env_var_names(v)


class LocalConfig(BaseModel):
Expand All @@ -50,6 +68,12 @@ class LocalConfig(BaseModel):
poll_interval: float = 5.0
default_snakemake_args: list[str] = []
snkmt_db_sync_interval: float = 30.0
allowed_env_vars: list[str] | None = None

@field_validator("allowed_env_vars")
@classmethod
def _validate_allowed_env_vars(cls, v: list[str] | None) -> list[str] | None:
return _validate_env_var_names(v)


def load_config(
Expand Down
5 changes: 5 additions & 0 deletions app/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AppState:
settings: Settings
health_cache: HealthCache
default_snakemake_args: list[str]
allowed_env_vars: list[str] | None = None
background_tasks: set[asyncio.Task[None]] = field(default_factory=set)


Expand All @@ -46,6 +47,10 @@ def get_default_snakemake_args(request: Request) -> list[str]:
return app_state(request).default_snakemake_args


def get_allowed_env_vars(request: Request) -> list[str] | None:
return app_state(request).allowed_env_vars


def provide_background_tasks(request: Request) -> set[asyncio.Task[None]]:
"""Return the shared background-task set."""
return app_state(request).background_tasks
Expand Down
1 change: 1 addition & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
settings=settings,
health_cache={"backend_ok": None, "checked_at": 0.0},
default_snakemake_args=backend_config.default_snakemake_args,
allowed_env_vars=backend_config.allowed_env_vars,
)

store.restore_from_disk()
Expand Down
5 changes: 5 additions & 0 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class JobCreate(BaseModel):
description="Relative paths within the work directory to cache, "
"e.g. ['data', 'resources']. Requires cache_key.",
)
env_vars: dict[str, str] | None = Field(
None,
description="Environment variables to merge into the Snakemake process "
"environment. Only keys listed in the server's allowed_env_vars config are accepted.",
)

@field_validator("cache_key")
@classmethod
Expand Down
16 changes: 16 additions & 0 deletions app/routes/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from app.backends.base import ComputeBackend
from app.deps import (
get_allowed_env_vars,
get_backend,
get_default_snakemake_args,
get_store,
Expand Down Expand Up @@ -100,8 +101,22 @@ async def create_job(
store: Annotated[JobStore, Depends(get_store)],
backend: Annotated[ComputeBackend, Depends(get_backend)],
default_snakemake_args: Annotated[list[str], Depends(get_default_snakemake_args)],
allowed_env_vars: Annotated[list[str] | None, Depends(get_allowed_env_vars)],
) -> JobResponse:
"""Submit a new Snakemake job for execution."""
if body.env_vars:
if allowed_env_vars is None:
raise HTTPException(
status_code=422,
detail="env_vars are not enabled; set allowed_env_vars in server config",
)
disallowed = set(body.env_vars) - set(allowed_env_vars)
if disallowed:
raise HTTPException(
status_code=422,
detail=f"env_vars keys not in allowed list: {sorted(disallowed)}",
)

source = body.workflow
is_url = source.startswith(("http://", "https://"))
if not is_url:
Expand Down Expand Up @@ -133,6 +148,7 @@ async def create_job(
extra_files=body.extra_files,
cache_key=body.cache_key,
cache_dirs=body.cache_dirs,
env_vars=body.env_vars,
),
),
name=f"execute-{job_id}",
Expand Down
3 changes: 2 additions & 1 deletion app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class ExecuteJobParams:
extra_files: dict[str, str] | None = None
cache_key: str | None = None
cache_dirs: list[str] | None = None
env_vars: dict[str, str] | None = None


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -198,7 +199,7 @@ async def execute_job(
await _restore_cache(backend, store, job_id, work_dir, params)

store.mark_running(job_id)
await backend.launch(job_id, work_dir, params.configfile, params.snakemake_args)
await backend.launch(job_id, work_dir, params.configfile, params.snakemake_args, params.env_vars)

def log_callback(line: str) -> None:
store.push_log(job_id, line)
Expand Down
8 changes: 7 additions & 1 deletion app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def build_wrapper_script(
snkmt_db_path: str,
configfile: str | None,
snakemake_args: list[str] | None,
env_vars: dict[str, str] | None = None,
) -> str:
"""Build the .run.sh bash wrapper script for a Snakemake workflow run."""
configfile_arg = f" --configfile {shlex.quote(configfile)}" if configfile else ""
Expand All @@ -41,9 +42,14 @@ def build_wrapper_script(
extra_args = " " + " ".join(shlex.quote(a) for a in snakemake_args)
pixi = shlex.quote(pixi_path)
snkmt = shlex.quote(snkmt_db_path)
exports = (
"\n".join(f"export {k}={shlex.quote(v)}" for k, v in env_vars.items()) + "\n"
if env_vars
else ""
)
return f"""\
#!/bin/bash
echo $$ > .pid
{exports}echo $$ > .pid
SNKMT_ARGS=""
if {pixi} run python -c "import snakemake_logger_plugin_snkmt" 2>/dev/null; then
SNKMT_ARGS="--logger snkmt --logger-snkmt-db {snkmt}"
Expand Down
1 change: 1 addition & 0 deletions config/config.local.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ local:
pixi_path: /home/appuser/.pixi/bin/pixi
poll_interval: 5
default_snakemake_args: []
allowed_env_vars: ["OETC_EMAIL", "OETC_PASSWORD"]

# App-level settings (optional, shown with defaults)
# HEALTH_CACHE_TTL_SECONDS: 300
Expand Down
1 change: 1 addition & 0 deletions config/config.slurm_ssh.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ slurm_ssh:
scratch_dir: /scratch/myuser/executor
poll_interval: 5
default_snakemake_args: ["--profile", "slurm"]
# allowed_env_vars: ["OETC_EMAIL", "OETC_PASSWORD"]

# App-level settings (optional, shown with defaults)
# HEALTH_CACHE_TTL_SECONDS: 300
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ async def async_client(store, mock_backend, settings):
settings=settings,
health_cache={"backend_ok": None, "checked_at": 0.0},
default_snakemake_args=backend_config.default_snakemake_args,
allowed_env_vars=backend_config.allowed_env_vars,
)

transport = ASGITransport(app=app)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,76 @@ def test_create_backend_slurm(self):
)
)
assert isinstance(backend, SlurmSSHBackend)


class TestEnvVars:
async def test_env_vars_rejected_when_feature_disabled(self, async_client):
response = await async_client.post(
"/jobs",
json={
"workflow": "https://github.com/org/repo.git",
"env_vars": {"OETC_EMAIL": "test@example.com"},
},
)
assert response.status_code == 422
assert "not enabled" in response.json()["detail"]

async def test_env_vars_accepted_when_in_whitelist(
self, store, mock_backend, settings
):
from httpx import ASGITransport, AsyncClient

from app.config import LocalConfig
from app.deps import AppState

backend_config = LocalConfig()
app.state.app = AppState(
store=store,
backend=mock_backend,
settings=settings,
health_cache={"backend_ok": None, "checked_at": 0.0},
default_snakemake_args=backend_config.default_snakemake_args,
allowed_env_vars=["OETC_EMAIL", "OETC_PASSWORD"],
)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.post(
"/jobs",
json={
"workflow": "https://github.com/org/repo.git",
"env_vars": {"OETC_EMAIL": "user@example.com"},
},
)
assert response.status_code == 201

async def test_env_vars_rejected_when_key_not_in_whitelist(
self, store, mock_backend, settings
):
from httpx import ASGITransport, AsyncClient

from app.config import LocalConfig
from app.deps import AppState

backend_config = LocalConfig()
app.state.app = AppState(
store=store,
backend=mock_backend,
settings=settings,
health_cache={"backend_ok": None, "checked_at": 0.0},
default_snakemake_args=backend_config.default_snakemake_args,
allowed_env_vars=["OETC_EMAIL"],
)
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.post(
"/jobs",
json={
"workflow": "https://github.com/org/repo.git",
"env_vars": {
"OETC_EMAIL": "user@example.com",
"MY_SECRET": "hunter2",
},
},
)
assert response.status_code == 422
assert "MY_SECRET" in response.json()["detail"]
2 changes: 2 additions & 0 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def test_successful_job(self, mock_backend, tmp_path):
"/scratch/test/jobs/test-job-id",
"config.yaml",
["--profile", "slurm"],
None,
)
mock_backend.monitor.assert_called_once()

Expand Down Expand Up @@ -132,6 +133,7 @@ async def test_minimal_args(self, mock_backend, tmp_path):
"/scratch/test/jobs/test-job-id",
None,
None,
None,
)

async def test_cache_restore_and_save(self, mock_backend, tmp_path):
Expand Down
Loading