Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions Assets/runpod-backend.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RunPod Serverless Backend - Feature Handler
// This file manages UI feature flags and parameter visibility for RunPod backend

// Feature-set changer for the RunPod serverless backend.
// Returns a pair [flagsToAdd, flagsToRemove]. As a side effect it tags (or untags)
// a fixed list of core parameters with an extra feature flag while RunPod is the active architecture.
featureSetChangers.push(() => {
    // Parameter definitions are not available yet: nothing to add, nothing to remove.
    if (!gen_param_types) {
        return [[], []];
    }

    const runpodActive = currentModelHelper.curArch === 'runpod_serverless';

    // Property used to remember a parameter's feature flag before it was modified here.
    const backupKey = 'original_feature_flag_runpod';
    // Extra flag appended to the listed parameters while RunPod is active.
    const incompatibleFlag = '__runpod_incompatible__';
    const hiddenParamIds = new Set([
        'vaetilesize', 'vaetileoverlap', 'automaticvae',
        'clipstopatlayer', 'modelspecificenhancements'
    ]);

    for (const param of gen_param_types) {
        if (!hiddenParamIds.has(param.id)) {
            continue;
        }
        const hasBackup = param.hasOwnProperty(backupKey);
        if (runpodActive) {
            // Back up the untouched flag only once so repeated runs do not stack the marker.
            if (!hasBackup) {
                param[backupKey] = param.feature_flag;
            }
            const originalFlag = param[backupKey];
            if (originalFlag) {
                param.feature_flag = `${originalFlag},${incompatibleFlag}`;
            } else {
                param.feature_flag = incompatibleFlag;
            }
        } else if (hasBackup) {
            // Switched away from RunPod: restore the saved flag and drop the backup.
            param.feature_flag = param[backupKey];
            delete param[backupKey];
        }
    }

    // Any other architecture: only make sure the RunPod flag itself is removed.
    if (!runpodActive) {
        return [[], ['runpod_serverless']];
    }

    // Features to add for RunPod backend
    const addFlags = ['runpod_serverless', 'prompt', 'images'];
    // Features to remove for RunPod backend
    const removeFlags = [
        'sampling', 'refiners', 'controlnet', 'variation_seed',
        'video', 'autowebui', 'comfyui', 'frameinterps', 'ipadapter',
        'sdxl', 'cascade', 'sd3', 'seamless', 'freeu', 'teacache',
        'text2video', 'yolov8', 'aitemplate', 'endstepsearly',
        'dynamic_thresholding', 'zero_negative'
    ];

    console.log(`[runpod-backend] Adding feature flags: ${addFlags.join(', ')}`);
    console.log(`[runpod-backend] Removing feature flags: ${removeFlags.join(', ')}`);

    return [addFlags, removeFlags];
});

// Re-evaluate the feature set and parameter visibility every time the selected model changes.
// The hook is registered only when the page actually exposes addModelChangeCallback.
if (typeof addModelChangeCallback === 'function') {
    const onModelChanged = () => {
        console.log(`[runpod-backend] Model changed to: ${currentModelHelper.curArch}`);

        // Update the feature set first, then parameter visibility
        reviseBackendFeatureSet();
        hideUnsupportableParams();
    };
    addModelChangeCallback(onModelChanged);
}

// Initial parameter setup, deferred by 500 ms after this script is evaluated.
setTimeout(function () {
    console.log('[runpod-backend] Initial parameter setup starting');
    reviseBackendFeatureSet();
    hideUnsupportableParams();
}, 500);

// Emitted immediately at script evaluation, before the deferred setup above runs.
console.log('[runpod-backend] Feature handler loaded');
92 changes: 72 additions & 20 deletions RunPodServerlessBackend.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ namespace Hartsy.Extensions.RunPodServerless;
public class RunPodServerlessBackend : AbstractT2IBackend
{
/// <summary>Cache of remote models by subtype.</summary>
public ConcurrentDictionary<string, Dictionary<string, JObject>> RemoteModels = null;
public ConcurrentDictionary<string, Dictionary<string, JObject>> RemoteModels = new();

/// <summary>Model names by subtype (for internal tracking).</summary>
public new ConcurrentDictionary<string, List<string>> Models = new();

public Session Session = null;

Expand Down Expand Up @@ -61,10 +64,19 @@ public class Settings : AutoConfiguration

[ConfigComment("Attempt to refresh models from worker on backend init")]
public bool AutoRefresh = false;

[ConfigComment("Worker SwarmUI port (default: 7801)")]
public int WorkerPort = 7801;
}

public Settings Config => (Settings)SettingsRaw;

/// <summary>Builds the RunPod proxy URL for a worker from its ID and the configured worker port.</summary>
/// <param name="workerId">ID of the worker to address.</param>
/// <returns>A URL of the form <c>https://{workerId}-{port}.proxy.runpod.net</c>, where the port is <c>Config.WorkerPort</c>.</returns>
public string GetWorkerPublicUrl(string workerId) =>
    $"https://{workerId}-{Config.WorkerPort}.proxy.runpod.net";

/// <summary>Auto-throw exception if response indicates session error.</summary>
public static void AutoThrowException(JObject data)
{
Expand Down Expand Up @@ -394,8 +406,11 @@ public async Task RefreshModelsFromWorkerAsync(Session session = null)
modelMetadata[modelName] = new JObject
{
["name"] = modelName,
["title"] = modelName,
["local"] = false,
["subtype"] = subtypeLocal
["subtype"] = subtypeLocal,
["architecture"] = "stable-diffusion-v1",
["is_supported_model_format"] = true
};
}
}
Expand All @@ -417,8 +432,6 @@ public async Task RefreshModelsFromWorkerAsync(Session session = null)
int total = tempModels.Values.Sum(list => list.Count);
if (total > 0)
{
RemoteModels ??= new ConcurrentDictionary<string, Dictionary<string, JObject>>();
Models ??= new ConcurrentDictionary<string, List<string>>();
foreach (var kv in tempModels)
{
Models[kv.Key] = kv.Value;
Expand All @@ -428,6 +441,10 @@ public async Task RefreshModelsFromWorkerAsync(Session session = null)
RemoteModels[kv.Key] = kv.Value;
}
Logs.Debug($"[RunPodServerless] Model refresh complete: {total} models across {tempModels.Count} subtypes");
foreach (var kv in RemoteModels)
{
Logs.Verbose($"[RunPodServerless] RemoteModels['{kv.Key}'] = {kv.Value.Count} models");
}
return;
}

Expand Down Expand Up @@ -665,37 +682,72 @@ public async Task ClearWorkerStateAsync()
}
}

/// <summary>Wake up worker and poll until ready.</summary>
/// <summary>Wake up worker and poll until ready. Uses job status to get worker ID, then probes worker URL directly.</summary>
public async Task<WorkerInfo> WakeupAndWaitForWorkerAsync(RunPodApiClient client, int keepaliveDuration)
{
int keepaliveInterval = 30;
Logs.Verbose($"[RunPodServerless] Initiating wakeup with {keepaliveDuration}s keepalive...");
await client.WakeupWorkerAsync(keepaliveDuration, keepaliveInterval);
await Task.Delay(2000);
string jobId = await client.WakeupWorkerAsync(keepaliveDuration, keepaliveInterval);
int maxWaitSeconds = Config.StartupTimeoutSec;
int pollIntervalMs = Config.PollIntervalMs;
int maxAttempts = (maxWaitSeconds * 1000) / pollIntervalMs;
Logs.Info($"[RunPodServerless] Polling for worker ready (max {maxWaitSeconds}s, interval {pollIntervalMs}ms)...");
string workerId = null;
string publicUrl = null;

AddLoadStatus("Waiting for RunPod worker to start...");

for (int attempt = 0; attempt < maxAttempts; attempt++)
{
try
{
WorkerReadyResponse readyResponse = await client.CheckWorkerReadyAsync();
if (readyResponse.Ready)
if (string.IsNullOrEmpty(workerId))
{
int elapsedSeconds = (attempt * pollIntervalMs) / 1000;
Logs.Info($"[RunPodServerless] Worker ready after {elapsedSeconds}s");
return new WorkerInfo
JObject jobStatus = await client.GetJobStatusAsync(jobId);
string status = jobStatus["status"]?.ToString();
workerId = jobStatus["workerId"]?.ToString();

Logs.Verbose($"[RunPodServerless] Job {jobId} status: {status}, workerId: {workerId ?? "(none)"}");

if (status == "FAILED")
{
string error = jobStatus["error"]?.ToString() ?? "Job failed";
throw new Exception($"Wakeup job failed: {error}");
}

if (!string.IsNullOrEmpty(workerId))
{
PublicUrl = readyResponse.PublicUrl,
SessionId = readyResponse.SessionId,
WorkerId = readyResponse.WorkerId,
Version = readyResponse.Version
};
publicUrl = GetWorkerPublicUrl(workerId);
Logs.Info($"[RunPodServerless] Worker ID obtained: {workerId}, URL: {publicUrl}");
AddLoadStatus($"Worker starting: {workerId}");
}
}
if (!string.IsNullOrEmpty(readyResponse.Error))

if (!string.IsNullOrEmpty(publicUrl))
{
Logs.Verbose($"[RunPodServerless] Worker not ready: {readyResponse.Error}");
try
{
JObject sessionReq = new() { ["local"] = true };
JObject sessionResp = await client.CallSwarmUIAsync(publicUrl, "/API/GetNewSession", sessionReq, timeoutSeconds: 10);

string sessionId = sessionResp["session_id"]?.ToString();
if (!string.IsNullOrEmpty(sessionId))
{
int elapsedSeconds = (attempt * pollIntervalMs) / 1000;
Logs.Info($"[RunPodServerless] Worker ready after {elapsedSeconds}s (session: {sessionId})");
AddLoadStatus("Worker ready!");
return new WorkerInfo
{
PublicUrl = publicUrl,
SessionId = sessionId,
WorkerId = workerId,
Version = sessionResp["version"]?.ToString()
};
}
}
catch (Exception probeEx)
{
Logs.Verbose($"[RunPodServerless] Worker probe failed (still starting): {probeEx.Message}");
}
}
}
catch (Exception ex)
Expand Down
29 changes: 24 additions & 5 deletions RunPodServerlessExtension.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Linq;
using System.Collections.Generic;
using FreneticUtilities.FreneticExtensions;
using Newtonsoft.Json.Linq;
using SwarmUI.Accounts;
using SwarmUI.Core;
using SwarmUI.Utils;
Expand All @@ -25,6 +26,8 @@ public class RunPodServerlessExtension : Extension
/// <summary>Pre-init hook: logs the extension's startup message and adds the RunPod frontend script to <c>ScriptFiles</c>.</summary>
public override void OnPreInit()
{
Logs.Init("Initializing Hartsy's RunPod Serverless Backend Extension...");
// Register frontend script file for UI feature flags and parameter handling
ScriptFiles.Add("Assets/runpod-backend.js");
}

public override void OnInit()
Expand Down Expand Up @@ -59,13 +62,29 @@ public override void OnInit()
{
ModelsAPI.ExtraModelProviders["runpod_serverless"] = (string subtype) =>
{
RunPodServerlessBackend[] backs = [.. Program.Backends.RunningBackendsOfType<RunPodServerlessBackend>().Where(b => b.RemoteModels is not null)];
IEnumerable<Dictionary<string, Newtonsoft.Json.Linq.JObject>> sets = backs.Select(b => b.RemoteModels.GetValueOrDefault(subtype)).Where(s => s is not null);
if (!sets.Any())
Dictionary<string, JObject> result = new();
var backends = Program.Backends.RunningBackendsOfType<RunPodServerlessBackend>().ToList();
Logs.Verbose($"[RunPodServerless] ExtraModelProviders callback for subtype '{subtype}': found {backends.Count} running backend(s)");
foreach (var backend in backends)
{
return [];
if (backend.RemoteModels != null && backend.RemoteModels.TryGetValue(subtype, out var models) && models != null)
{
Logs.Verbose($"[RunPodServerless] Backend has {models.Count} models for subtype '{subtype}'");
foreach (var kvp in models)
{
if (!result.ContainsKey(kvp.Key))
{
result[kvp.Key] = kvp.Value;
}
}
}
else
{
Logs.Verbose($"[RunPodServerless] Backend has no models for subtype '{subtype}' (RemoteModels null: {backend.RemoteModels == null})");
}
}
return sets.Aggregate((a, b) => a.Union(b).PairsToDictionary(false));
Logs.Verbose($"[RunPodServerless] Returning {result.Count} total models for subtype '{subtype}'");
return result;
};
Logs.Debug("Registered RunPod Serverless models provider for extra remote models.");
}
Expand Down
62 changes: 58 additions & 4 deletions WebAPI/RunPodApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ public class RunPodApiClient(string apiKey, string endpointId)
/// <summary>Shared HTTP client using SwarmUI's infrastructure.</summary>
public HttpClient HttpClient = NetworkBackendUtils.MakeHttpClient();

/// <summary>Wake up worker with keepalive. Returns immediately after initiating wakeup.</summary>
/// <summary>Wake up worker with keepalive. Submits async job and returns job ID immediately.</summary>
/// <param name="keepaliveDuration">How long to keep worker alive in seconds (default: 3600)</param>
/// <param name="keepaliveInterval">Ping interval in seconds (default: 30)</param>
public async Task WakeupWorkerAsync(int keepaliveDuration = 3600, int keepaliveInterval = 30, CancellationToken cancel = default)
/// <returns>Job ID for tracking the wakeup request</returns>
public async Task<string> WakeupWorkerAsync(int keepaliveDuration = 3600, int keepaliveInterval = 30, CancellationToken cancel = default)
{
JObject payload = new()
{
Expand All @@ -29,7 +30,41 @@ public async Task WakeupWorkerAsync(int keepaliveDuration = 3600, int keepaliveI
}
};
Logs.Verbose($"[RunPodApiClient] Sending wakeup signal (duration: {keepaliveDuration}s, interval: {keepaliveInterval}s)");
await CallRunPodHandlerAsync(payload, useSync: false, cancel);
string jobId = await SubmitJobAsync(payload, cancel);
Logs.Verbose($"[RunPodApiClient] Wakeup job submitted: {jobId}");
return jobId;
}

/// <summary>Submit a job to the RunPod async <c>/run</c> endpoint and return its job ID immediately, without waiting for the job to execute.</summary>
/// <param name="payload">JSON body posted to the endpoint.</param>
/// <param name="cancel">Token used to cancel the HTTP request and the response body read.</param>
/// <returns>The non-empty value of the <c>id</c> field in the RunPod response.</returns>
/// <exception cref="HttpRequestException">RunPod answered with a non-success status code.</exception>
/// <exception cref="Exception">The response contained no job ID (missing, null or empty <c>id</c>).</exception>
public async Task<string> SubmitJobAsync(JObject payload, CancellationToken cancel = default)
{
    string url = $"https://api.runpod.ai/v2/{endpointId}/run";
    using HttpRequestMessage request = new(HttpMethod.Post, url)
    {
        Content = new StringContent(payload.ToString(), Encoding.UTF8, "application/json")
    };
    request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
    Logs.Verbose("[RunPodApiClient] Submitting async job to RunPod");
    using HttpResponseMessage response = await HttpClient.SendAsync(request, cancel);
    // The body is needed on both paths: as the error detail on failure, as the JSON result on success.
    string content = await response.Content.ReadAsStringAsync(cancel);
    if (!response.IsSuccessStatusCode)
    {
        throw new HttpRequestException($"RunPod API call failed ({response.StatusCode}): {content}");
    }
    JObject result = JObject.Parse(content);
    // A JSON null "id" is a non-null token whose ToString() is "", so a plain null check would
    // let an empty job ID through; treat empty the same as missing.
    string jobId = result["id"]?.ToString();
    if (string.IsNullOrEmpty(jobId))
    {
        throw new Exception("No job ID in RunPod response");
    }
    return jobId;
}

/// <summary>Get the current status of a job from the RunPod <c>/status/{jobId}</c> endpoint. Returns the full status response.</summary>
/// <param name="jobId">ID of the job to query.</param>
/// <param name="cancel">Token used to cancel the HTTP request and the response body read.</param>
/// <returns>The parsed JSON body of the status response.</returns>
/// <exception cref="HttpRequestException">RunPod answered with a non-success status code.</exception>
public async Task<JObject> GetJobStatusAsync(string jobId, CancellationToken cancel = default)
{
    string url = $"https://api.runpod.ai/v2/{endpointId}/status/{jobId}";
    using HttpRequestMessage request = new(HttpMethod.Get, url);
    request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
    using HttpResponseMessage response = await HttpClient.SendAsync(request, cancel);
    string content = await response.Content.ReadAsStringAsync(cancel);
    // Fail with the HTTP status and body (same shape as SubmitJobAsync) instead of handing an
    // error body to the JSON parser or returning it as if it were a job status.
    if (!response.IsSuccessStatusCode)
    {
        throw new HttpRequestException($"RunPod API call failed ({response.StatusCode}): {content}");
    }
    return JObject.Parse(content);
}

/// <summary>Check if worker is ready for generation.</summary>
Expand Down Expand Up @@ -172,11 +207,13 @@ public async Task<JObject> CallRunPodHandlerAsync(JObject payload, bool useSync,
return result["output"] as JObject ?? result;
}

/// <summary>Poll job status for async RunPod calls.</summary>
/// <summary>Poll job status for async RunPod calls. Extracts worker info when IN_PROGRESS.</summary>
public async Task<JObject> PollJobStatusAsync(string jobId, CancellationToken cancel)
{
string url = $"https://api.runpod.ai/v2/{endpointId}/status/{jobId}";
int maxAttempts = 300;
string lastWorkerId = null;

for (int attempt = 0; attempt < maxAttempts; attempt++)
{
cancel.ThrowIfCancellationRequested();
Expand All @@ -186,10 +223,27 @@ public async Task<JObject> PollJobStatusAsync(string jobId, CancellationToken ca
string content = await response.Content.ReadAsStringAsync(cancel);
JObject result = JObject.Parse(content);
string status = result["status"]?.ToString();
string workerId = result["workerId"]?.ToString();

if (!string.IsNullOrEmpty(workerId) && workerId != lastWorkerId)
{
lastWorkerId = workerId;
Logs.Verbose($"[RunPodApiClient] Job {jobId} assigned to worker: {workerId}");
}

if (status == "COMPLETED")
{
// Ensure workerId is in the result for caller to use
if (!string.IsNullOrEmpty(lastWorkerId) && result["workerId"] == null)
{
result["workerId"] = lastWorkerId;
}
return result;
}
else if (status == "IN_PROGRESS")
{
Logs.Verbose($"[RunPodApiClient] Job {jobId} in progress on worker {workerId}...");
}
else if (status == "FAILED")
{
string error = result["error"]?.ToString() ?? "Job failed";
Expand Down