Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,12 @@ JWT_SECRET=ace-step-ui-local-secret
# ── Optional ──────────────────────────────────────────────────────────────────
# Pexels API key for video backgrounds — https://www.pexels.com/api/
# PEXELS_API_KEY=

# ── Binary extra arguments ────────────────────────────────────────────────────
# Append extra CLI flags to the ace-qwen3 or dit-vae spawn invocations.
# Useful for hardware-specific tuning or debugging, e.g. limit CPU threads:
# ACE_QWEN3_EXTRA_ARGS=--threads 4
# DIT_VAE_EXTRA_ARGS=--threads 4
#
# DIT_VAE_EXTRA_ARGS=
# ACE_QWEN3_EXTRA_ARGS=
178 changes: 170 additions & 8 deletions server/src/services/acestep.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,36 +199,196 @@ function resolveAudioPath(audioUrl: string): string {
// <tmpDir>/request0.json → dit-vae → <tmpDir>/request00.wav
// ---------------------------------------------------------------------------

/** Run a binary and return its captured stdout/stderr. Throws on non-zero exit. */
async function runBinary(
/**
 * Parse a whitespace-separated list of extra CLI arguments from an env variable.
 *
 * Quoting: a single- or double-quoted segment may contain whitespace and stays
 * part of one argument. The quote characters of every complete quoted segment
 * are removed wherever the segment sits inside the token, so
 * `--name="hello world"` yields `--name=hello world`. Backslash escapes are not
 * supported, and a quote character without a closing partner is skipped (it
 * acts as a separator).
 *
 * Example: ACE_QWEN3_EXTRA_ARGS="--threads 4" → ['--threads', '4']
 *
 * @param envVar Raw env value; undefined, empty, or whitespace-only yields [].
 * @returns The parsed arguments in order of appearance.
 */
function parseExtraArgs(envVar: string | undefined): string[] {
  if (!envVar?.trim()) return [];

  // One token = a run of adjacent unquoted chunks and complete quoted segments.
  const tokenRe = /(?:[^\s"']+|"[^"]*"|'[^']*')+/g;
  // One complete quoted segment, capturing what is between the quotes.
  const quotedRe = /"([^"]*)"|'([^']*)'/g;

  const tokens = envVar.match(tokenRe) ?? [];
  // Unwrap each quoted segment in place instead of only trimming the token's
  // first/last character, which left a stray quote in e.g. --name="a b".
  return tokens.map((token) =>
    token.replace(quotedRe, (_whole: string, dq?: string, sq?: string) => dq ?? sq ?? ''),
  );
}

/**
 * Turn the outcome of a failed binary run into an Error.
 *
 * The detail text is the first non-empty of: stderr, stdout, or a literal
 * "exit code N" fallback. It is cut to at most 2000 characters before being
 * appended to "<label> failed: ".
 */
function buildBinaryError(label: string, result: { exitCode: number | null; stdout: string; stderr: string }): Error {
  const { exitCode, stdout, stderr } = result;

  let detail = stderr;
  if (!detail) detail = stdout;
  if (!detail) detail = `exit code ${exitCode}`;

  const clipped = detail.slice(0, 2000);
  return new Error(`${label} failed: ${clipped}`);
}

/**
 * Run a binary, streaming its stderr line by line to an optional callback, and
 * resolve with the captured output.
 *
 * @param bin    Executable to spawn (no shell is involved).
 * @param args   Argument vector passed to the executable as-is.
 * @param label  Short name used as the prefix of error messages.
 * @param env    Extra environment variables layered on top of process.env.
 * @param onLine Called once per non-blank stderr line (trimmed), including a
 *               final line that is not newline-terminated.
 * @returns Everything the process wrote to stdout and stderr.
 * @throws (rejects) with the error built by buildBinaryError when the exit code
 *         is not 0, or with "<label> process error: …" when the process emits
 *         an 'error' event (e.g. it could not be spawned).
 */
function runBinary(
  bin: string,
  args: string[],
  label: string,
  env?: NodeJS.ProcessEnv,
  onLine?: (line: string) => void,
): Promise<{ stdout: string; stderr: string }> {
  return new Promise((resolve, reject) => {
    const proc = spawn(bin, args, {
      shell: false,
      env: { ...process.env, ...env },
      stdio: ['ignore', 'pipe', 'pipe'],
    });

    // Decode at the stream level so a multi-byte UTF-8 character split across
    // two chunks is reassembled instead of being corrupted by per-chunk toString().
    proc.stdout.setEncoding('utf8');
    proc.stderr.setEncoding('utf8');

    let stdout = '';
    let stderr = '';
    let lineBuf = ''; // stderr text received since the last '\n'

    const emitLine = (raw: string): void => {
      const trimmed = raw.trim();
      if (trimmed && onLine) onLine(trimmed);
    };

    proc.stdout.on('data', (chunk: string) => { stdout += chunk; });
    // Single stderr listener: it both captures the text and feeds onLine.
    proc.stderr.on('data', (chunk: string) => {
      stderr += chunk;
      if (!onLine) return;
      lineBuf += chunk;
      const lines = lineBuf.split('\n');
      // The last piece is an incomplete line (or '') — keep it for the next chunk.
      lineBuf = lines.pop() ?? '';
      for (const line of lines) emitLine(line);
    });

    proc.on('close', (code) => {
      // Flush a trailing line that never received its newline.
      emitLine(lineBuf);
      lineBuf = '';
      if (code === 0) {
        resolve({ stdout, stderr });
      } else {
        reject(buildBinaryError(label, { exitCode: code, stdout, stderr }));
      }
    });
    proc.on('error', (err) => reject(new Error(`${label} process error: ${err.message}`)));
  });
}

// ---------------------------------------------------------------------------
// Live progress parsing — translates binary stderr lines into job.stage /
// job.progress updates that the polling API can return to the frontend.
//
// ace-qwen3 progress lines (all on stderr):
// [Phase1] step 100, 1 active, 19.0 tok/s — lyrics LM decode
// [Phase1] Decode 15871ms — Phase1 complete
// [Phase2] max_tokens: 800, … — captures audio-codes budget
// [Decode] step 50, 1 active, 51 total codes, 20.1 tok/s — audio LM decode
//
// dit-vae progress lines (all on stderr):
// [DiT] Starting: T=…, steps=8, … — captures DiT step count
// [DiT] step 1/8 t=1.000 — DiT diffusion step N/M
// [DiT] Total generation: … — DiT complete
// [VAE] Tiled decode: 28 tiles … — VAE starting
// [VAE] Tiled decode done: 28 tiles → … — VAE complete
//
// Progress scale: 0–50% ace-qwen3 | 50–100% dit-vae
// ---------------------------------------------------------------------------

// Progress budget across the two-binary pipeline. Each constant below is the
// overall job percentage reported when the named stage finishes:
//    0–30%  ace-qwen3 Phase1 (lyrics LM decode — step count varies, ~200–400)
//   30–50%  ace-qwen3 Phase2 (audio-codes LM decode)
//   50–85%  dit-vae DiT (diffusion steps — exact N/M known at runtime)
//   85–98%  dit-vae VAE (tiled audio decode); the final 100 is set on job success
const PROGRESS_LM_PHASE1_MAX = 30; // % at end of Phase1
const PROGRESS_LM_PHASE2_END = 50; // % at end of Phase2 (= start of dit-vae)
const PROGRESS_DIT_END = 85; // % at end of DiT diffusion
const PROGRESS_VAE_END = 98; // % at end of VAE decode (100 set on job success)

/**
 * Build the stderr line handler for an ace-qwen3 run.
 *
 * The returned callback inspects one trimmed stderr line at a time and, when
 * the line is a recognised progress marker, writes job.progress and/or
 * job.stage. Unrecognised lines are ignored. Progress written here stays
 * within 0..PROGRESS_LM_PHASE2_END.
 */
function makeLmProgressHandler(job: ActiveJob): (line: string) => void {
  // Assumed upper bound on Phase1 steps, used only to scale the bar.
  const PHASE1_STEP_CEIL = 400;
  // While Phase1 is still stepping, stay 2 points below its end mark; the
  // "[Phase1] Decode" completion line is what moves the bar to the mark itself.
  const phase1Cap = PROGRESS_LM_PHASE1_MAX - 2;
  const phase2Span = PROGRESS_LM_PHASE2_END - PROGRESS_LM_PHASE1_MAX;

  // Audio-codes budget; replaced when the binary prints "[Phase2] max_tokens".
  let audioCodeBudget = 800;

  // Each recogniser returns true when it consumed the line.
  const recognisers: Array<(line: string) => boolean> = [
    // "[Phase1] step 100, 1 active, 19.0 tok/s"
    (line) => {
      const found = line.match(/^\[Phase1\] step (\d+),.*?([\d.]+) tok\/s/);
      if (!found) return false;
      const stepNo = parseInt(found[1], 10);
      const tokPerSec = found[2];
      const scaled = Math.round((stepNo / PHASE1_STEP_CEIL) * phase1Cap);
      job.progress = Math.min(phase1Cap, scaled);
      job.stage = `LLM: generating lyrics — step ${stepNo} (${tokPerSec} tok/s)`;
      return true;
    },
    // "[Phase1] Decode 15871ms" — Phase1 finished
    (line) => {
      if (!/^\[Phase1\] Decode/.test(line)) return false;
      job.progress = PROGRESS_LM_PHASE1_MAX;
      job.stage = 'LLM: lyrics complete — generating audio codes…';
      return true;
    },
    // "[Phase2] max_tokens: 800, …" — remember the budget, no visible update
    (line) => {
      const found = line.match(/^\[Phase2\] max_tokens:\s*(\d+)/);
      if (!found) return false;
      audioCodeBudget = parseInt(found[1], 10) || 800;
      return true;
    },
    // "[Decode] step 50, 1 active, 51 total codes, 20.1 tok/s"
    (line) => {
      const found = line.match(/^\[Decode\] step \d+,.*?(\d+) total codes,.*?([\d.]+) tok\/s/);
      if (!found) return false;
      const codeCount = parseInt(found[1], 10);
      const tokPerSec = found[2];
      const advance = Math.round((codeCount / audioCodeBudget) * phase2Span);
      job.progress = PROGRESS_LM_PHASE1_MAX + Math.min(phase2Span, advance);
      job.stage = `LLM: audio codes — ${codeCount}/${audioCodeBudget} (${tokPerSec} tok/s)`;
      return true;
    },
  ];

  return (line: string) => {
    // First recogniser that consumes the line wins; the rest are skipped.
    recognisers.some((recognise) => recognise(line));
  };
}

/**
 * Returns an onLine callback for dit-vae stderr that updates job.stage and
 * job.progress as the DiT+VAE pipeline progresses.
 *
 * Recognised lines and their effect:
 *   "[DiT] Starting: … steps=N …"  remembers N as the expected step count
 *   "[DiT] step S/T"               progress within PROGRESS_LM_PHASE2_END..PROGRESS_DIT_END
 *   "[DiT] Total generation"       progress = PROGRESS_DIT_END
 *   "[VAE] Tiled decode: K tiles"  progress = PROGRESS_DIT_END
 *   "[VAE] Tiled decode done"      progress = PROGRESS_VAE_END
 * Any other line is ignored.
 *
 * @param job Job whose progress/stage fields are mutated in place.
 * @returns Callback taking one trimmed stderr line.
 */
function makeDitVaeProgressHandler(job: ActiveJob): (line: string) => void {
  // Last known DiT step count; used when a step line reports a total of 0.
  let ditTotalSteps = 8;
  const ditRange = PROGRESS_DIT_END - PROGRESS_LM_PHASE2_END;

  return (line: string) => {
    // DiT starting — capture step count: "[DiT] Starting: T=3470, S=1735, …, steps=8, …"
    const ditStart = line.match(/^\[DiT\] Starting:.*?steps=(\d+)/);
    if (ditStart) {
      ditTotalSteps = parseInt(ditStart[1], 10) || 8;
      return;
    }
    // DiT step: "[DiT] step 1/8 t=1.000"
    const ditStep = line.match(/^\[DiT\] step (\d+)\/(\d+)/);
    if (ditStep) {
      const step = parseInt(ditStep[1], 10);
      // A reported total of 0 would make step/total NaN or Infinity; fall back
      // to the step count captured earlier (always positive).
      const total = parseInt(ditStep[2], 10) || ditTotalSteps;
      ditTotalSteps = total;
      // Clamp so a step number past the total cannot overshoot the DiT band.
      const fraction = Math.min(1, step / total);
      job.progress = PROGRESS_LM_PHASE2_END + Math.round(fraction * ditRange);
      job.stage = `DiT: step ${step}/${total}`;
      return;
    }
    // DiT complete: "[DiT] Total generation: 16200.0 ms …"
    if (/^\[DiT\] Total generation/.test(line)) {
      job.progress = PROGRESS_DIT_END;
      job.stage = 'VAE: decoding audio…';
      return;
    }
    // VAE starting: "[VAE] Tiled decode: 28 tiles (chunk=256, overlap=64, stride=128)"
    const vaeStart = line.match(/^\[VAE\] Tiled decode:\s*(\d+) tiles/);
    if (vaeStart) {
      job.progress = PROGRESS_DIT_END;
      job.stage = `VAE: decoding ${vaeStart[1]} tiles…`;
      return;
    }
    // VAE done: "[VAE] Tiled decode done: 28 tiles → T_audio=…"
    if (/^\[VAE\] Tiled decode done/.test(line)) {
      job.progress = PROGRESS_VAE_END;
      job.stage = 'VAE: decode complete — writing audio…';
    }
  };
}

async function runViaSpawn(
jobId: string,
params: GenerationParams,
Expand Down Expand Up @@ -294,9 +454,10 @@ async function runViaSpawn(

const batchSize = Math.min(Math.max(params.batchSize ?? 1, 1), 8);
if (batchSize > 1) lmArgs.push('--batch', String(batchSize));
lmArgs.push(...parseExtraArgs(process.env.ACE_QWEN3_EXTRA_ARGS));

console.log(`[Spawn] Job ${jobId}: ace-qwen3 ${lmArgs.slice(0, 6).join(' ')} …`);
await runBinary(lmBin, lmArgs, 'ace-qwen3');
await runBinary(lmBin, lmArgs, 'ace-qwen3', undefined, makeLmProgressHandler(job));

// Collect enriched JSON files produced by ace-qwen3:
// request.json → request0.json [, request1.json, …] (placed alongside request.json)
Expand Down Expand Up @@ -344,9 +505,10 @@ async function runViaSpawn(
ditArgs.push('--repainting-start', String(params.repaintingStart));
if (params.repaintingEnd && params.repaintingEnd > 0)
ditArgs.push('--repainting-end', String(params.repaintingEnd));
ditArgs.push(...parseExtraArgs(process.env.DIT_VAE_EXTRA_ARGS));

console.log(`[Spawn] Job ${jobId}: dit-vae ${ditArgs.slice(0, 6).join(' ')} …`);
await runBinary(ditVaeBin, ditArgs, 'dit-vae');
await runBinary(ditVaeBin, ditArgs, 'dit-vae', undefined, makeDitVaeProgressHandler(job));

// ── Collect generated WAV files ─────────────────────────────────────────
// dit-vae places output WAVs alongside each enriched JSON:
Expand Down
Loading