Skip to content

Commit c6068da

Browse files
committed
switch backend to listen on free port and require auth
- instead of hard-coded port we dynamically pick a free one - backend<->frontend channel is guarded by a random auth token
1 parent 7d7faba commit c6068da

21 files changed

Lines changed: 297 additions & 128 deletions

backend/app_factory.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
from __future__ import annotations
44

5+
import base64
6+
import hmac
7+
from collections.abc import Awaitable, Callable
58
from typing import TYPE_CHECKING
69

710
from fastapi import FastAPI, Request
811
from fastapi.exceptions import RequestValidationError
912
from fastapi.middleware.cors import CORSMiddleware
1013
from fastapi.responses import JSONResponse
14+
from starlette.responses import Response as StarletteResponse
1115

1216
from _routes._errors import HTTPError
1317
from _routes.generation import router as generation_router
@@ -36,6 +40,7 @@ def create_app(
3640
handler: "AppHandler",
3741
allowed_origins: list[str] | None = None,
3842
title: str = "LTX-2 Video Generation Server",
43+
auth_token: str = "",
3944
) -> FastAPI:
4045
"""Create a configured FastAPI app bound to the provided handler."""
4146
init_state_service(handler)
@@ -48,6 +53,37 @@ def create_app(
4853
allow_headers=["*"],
4954
)
5055

56+
@app.middleware("http")
57+
async def _auth_middleware( # pyright: ignore[reportUnusedFunction]
58+
request: Request,
59+
call_next: Callable[[Request], Awaitable[StarletteResponse]],
60+
) -> StarletteResponse:
61+
if not auth_token:
62+
return await call_next(request)
63+
if request.method == "OPTIONS":
64+
return await call_next(request)
65+
def _token_matches(candidate: str) -> bool:
66+
return hmac.compare_digest(candidate, auth_token)
67+
68+
# WebSocket: check query param
69+
if request.headers.get("upgrade", "").lower() == "websocket":
70+
if _token_matches(request.query_params.get("token", "")):
71+
return await call_next(request)
72+
return JSONResponse(status_code=401, content={"error": "Unauthorized"})
73+
# HTTP: Bearer or Basic auth
74+
auth_header = request.headers.get("authorization", "")
75+
if auth_header.startswith("Bearer ") and _token_matches(auth_header[7:]):
76+
return await call_next(request)
77+
if auth_header.startswith("Basic "):
78+
try:
79+
decoded = base64.b64decode(auth_header[6:]).decode()
80+
_, _, password = decoded.partition(":")
81+
if _token_matches(password):
82+
return await call_next(request)
83+
except Exception:
84+
pass
85+
return JSONResponse(status_code=401, content={"error": "Unauthorized"})
86+
5187
async def _route_http_error_handler(request: Request, exc: Exception) -> JSONResponse:
5288
if isinstance(exc, HTTPError):
5389
log_http_error(request, exc)

backend/ltx2_server.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def patched_sdpa(
9999
# Constants & Paths
100100
# ============================================================
101101

102-
PORT = 8000
102+
PORT = 0
103103

104104

105105
def _get_device() -> torch.device:
@@ -219,7 +219,10 @@ def _resolve_force_api_generations() -> bool:
219219
)
220220

221221
handler = build_initial_state(runtime_config, DEFAULT_APP_SETTINGS)
222-
app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS)
222+
223+
auth_token = os.environ.get("LTX_AUTH_TOKEN", "")
224+
225+
app = create_app(handler=handler, allowed_origins=DEFAULT_ALLOWED_ORIGINS, auth_token=auth_token)
223226

224227

225228
def precache_model_files(model_dir: Path) -> int:
@@ -257,9 +260,10 @@ def log_hardware_info() -> None:
257260

258261

259262
if __name__ == "__main__":
263+
import asyncio
260264
import uvicorn
261265

262-
port = int(os.environ.get("LTX_PORT", PORT))
266+
port = int(os.environ.get("LTX_PORT", "") or PORT)
263267
logger.info("=" * 60)
264268
logger.info("LTX-2 Video Generation Server (FastAPI + Uvicorn)")
265269
log_hardware_info()
@@ -285,4 +289,26 @@ def log_hardware_info() -> None:
285289
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
286290
},
287291
}
288-
uvicorn.run(app, host="127.0.0.1", port=port, log_level="info", access_log=False, log_config=log_config)
292+
293+
import socket as _socket
294+
295+
# Bind the socket ourselves so we know the actual port before uvicorn starts.
296+
sock = _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM)
297+
sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEADDR, 1)
298+
sock.bind(("127.0.0.1", port))
299+
actual_port = int(sock.getsockname()[1])
300+
301+
config = uvicorn.Config(app, host="127.0.0.1", port=actual_port, log_level="info", access_log=False, log_config=log_config)
302+
server = uvicorn.Server(config)
303+
304+
_orig_startup = server.startup
305+
306+
async def _startup_with_ready_msg(sockets: list[_socket.socket] | None = None) -> None:
307+
await _orig_startup(sockets=sockets)
308+
if server.started:
309+
# Machine-parseable ready message — Electron matches this line
310+
print(f"Server running on http://127.0.0.1:{actual_port}", flush=True)
311+
312+
server.startup = _startup_with_ready_msg # type: ignore[assignment]
313+
314+
asyncio.run(server.serve(sockets=[sock]))

backend/tests/test_auth.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Tests for shared-secret authentication middleware."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
7+
from starlette.testclient import TestClient
8+
9+
from app_factory import create_app
10+
11+
12+
def test_request_without_token_returns_401(test_state):
13+
app = create_app(handler=test_state, auth_token="test-secret")
14+
with TestClient(app) as client:
15+
response = client.get("/health")
16+
assert response.status_code == 401
17+
assert response.json() == {"error": "Unauthorized"}
18+
19+
20+
def test_request_with_correct_bearer_token(test_state):
21+
app = create_app(handler=test_state, auth_token="test-secret")
22+
with TestClient(app) as client:
23+
response = client.get("/health", headers={"Authorization": "Bearer test-secret"})
24+
assert response.status_code == 200
25+
26+
27+
def test_request_with_correct_basic_auth(test_state):
28+
app = create_app(handler=test_state, auth_token="test-secret")
29+
credentials = base64.b64encode(b":test-secret").decode()
30+
with TestClient(app) as client:
31+
response = client.get("/health", headers={"Authorization": f"Basic {credentials}"})
32+
assert response.status_code == 200
33+
34+
35+
def test_request_with_wrong_token_returns_401(test_state):
36+
app = create_app(handler=test_state, auth_token="test-secret")
37+
with TestClient(app) as client:
38+
response = client.get("/health", headers={"Authorization": "Bearer wrong-token"})
39+
assert response.status_code == 401
40+
41+
42+
def test_health_without_token_returns_401(test_state):
43+
"""Health endpoint is NOT exempt from auth."""
44+
app = create_app(handler=test_state, auth_token="test-secret")
45+
with TestClient(app) as client:
46+
response = client.get("/health")
47+
assert response.status_code == 401
48+
49+
50+
def test_no_auth_token_disables_middleware(test_state):
51+
"""When auth_token is empty string, auth is disabled (dev/test mode)."""
52+
app = create_app(handler=test_state, auth_token="")
53+
with TestClient(app) as client:
54+
response = client.get("/health")
55+
assert response.status_code == 200
56+
57+
58+
def test_websocket_with_token_query_param(test_state):
59+
app = create_app(handler=test_state, auth_token="test-secret")
60+
with TestClient(app) as client:
61+
# WebSocket upgrade without token should fail with 401
62+
response = client.get(
63+
"/ws/download/test",
64+
headers={"upgrade": "websocket", "connection": "upgrade"},
65+
)
66+
assert response.status_code == 401
67+
68+
# WebSocket upgrade with correct token query param
69+
response = client.get(
70+
"/ws/download/test?token=test-secret",
71+
headers={"upgrade": "websocket", "connection": "upgrade"},
72+
)
73+
# The route may not exist, but auth should pass (not 401)
74+
assert response.status_code != 401

electron/config.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ import path from 'path'
33
import os from 'os'
44
import { getProjectAssetsPath } from './app-state'
55

6-
export const PYTHON_PORT = 8000
7-
export const BACKEND_BASE_URL = `http://localhost:${PYTHON_PORT}`
86
export const isDev = !app.isPackaged
97

108
// Get directory - works in both CJS and ESM contexts

electron/gpu.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import { execSync } from 'child_process'
2-
import { BACKEND_BASE_URL } from './config'
32
import { logger } from './logger'
4-
import { getPythonPath } from './python-backend'
3+
import { getAuthToken, getBackendUrl, getPythonPath } from './python-backend'
54

65
// Check if NVIDIA GPU is available
76
export async function checkGPU(): Promise<{ available: boolean; name?: string; vram?: number }> {
87
try {
8+
const url = getBackendUrl()
9+
if (!url) throw new Error('Backend URL not available yet')
910
// Try to get GPU info from the backend API first (more reliable)
10-
const response = await fetch(`${BACKEND_BASE_URL}/api/gpu-info`, {
11+
const headers: Record<string, string> = { 'Content-Type': 'application/json' }
12+
const token = getAuthToken()
13+
if (token) headers['Authorization'] = `Bearer ${token}`
14+
const response = await fetch(`${url}/api/gpu-info`, {
1115
method: 'GET',
12-
headers: { 'Content-Type': 'application/json' },
16+
headers,
1317
})
1418

1519
if (response.ok) {

electron/ipc/app-handlers.ts

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import { app, ipcMain } from 'electron'
22
import path from 'path'
33
import fs from 'fs'
4-
import { BACKEND_BASE_URL } from '../config'
54
import { checkGPU } from '../gpu'
65
import { isPythonReady, downloadPythonEmbed } from '../python-setup'
7-
import { getBackendHealthStatus, startPythonBackend } from '../python-backend'
6+
import { getBackendHealthStatus, getBackendUrl, getAuthToken, startPythonBackend } from '../python-backend'
87
import { getMainWindow } from '../window'
98
import { getAnalyticsState, setAnalyticsEnabled, sendAnalyticsEvent } from '../analytics'
109

@@ -68,8 +67,8 @@ function markLicenseAccepted(settingsPath: string): void {
6867
}
6968

7069
export function registerAppHandlers(): void {
71-
ipcMain.handle('get-backend-url', () => {
72-
return BACKEND_BASE_URL
70+
ipcMain.handle('get-backend', () => {
71+
return { url: getBackendUrl() ?? '', token: getAuthToken() ?? '' }
7372
})
7473

7574
ipcMain.handle('get-models-path', () => {

electron/preload.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ const { contextBridge, ipcRenderer } = require('electron')
33

44
// Expose protected methods to the renderer process
55
contextBridge.exposeInMainWorld('electronAPI', {
6-
// Get the backend URL
7-
getBackendUrl: (): Promise<string> => ipcRenderer.invoke('get-backend-url'),
6+
// Get the backend URL and auth token
7+
getBackend: (): Promise<{ url: string; token: string }> => ipcRenderer.invoke('get-backend'),
88

99
// Get the path where models are stored
1010
getModelsPath: (): Promise<string> => ipcRenderer.invoke('get-models-path'),
@@ -138,7 +138,7 @@ interface BackendHealthStatus {
138138
declare global {
139139
interface Window {
140140
electronAPI: {
141-
getBackendUrl: () => Promise<string>
141+
getBackend: () => Promise<{ url: string; token: string }>
142142
getModelsPath: () => Promise<string>
143143
readLocalFile: (filePath: string) => Promise<{ data: string; mimeType: string }>
144144
checkGpu: () => Promise<{ available: boolean; name?: string; vram?: number }>

0 commit comments

Comments
 (0)