# app.py (forked from Lightning-Universe/DiffusionWithAutoscaler)
# !pip install lightning_api_access
# !pip install 'git+https://github.com/Lightning-AI/stablediffusion.git@lit'
# !curl https://raw.githubusercontent.com/Lightning-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml -o v2-inference-v.yaml
import base64
import io
import os
import time

import lightning as L
import ldm
import torch

from autoscaler import AutoScaler
from cold_start_proxy import CustomColdStartProxy
from datatypes import BatchText, BatchResponse, Text, Image

# Allow PyTorch to fall back to the CPU for ops not implemented on Apple MPS.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
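
# The `datatypes` module is expected to define pydantic-style models roughly
# like this sketch (inferred from how they are used below; it is not the
# actual definitions, which live in datatypes.py):
#
#   class Text(BaseModel):          # a single prompt
#       text: str
#   class Image(BaseModel):         # a base64-encoded PNG
#       image: str
#   class BatchText(BaseModel):     # batched server input
#       inputs: List[Text]
#   class BatchResponse(BaseModel): # batched server output
#       outputs: List[Image]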


class DiffusionServer(L.app.components.PythonServer):
    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            input_type=BatchText,
            output_type=BatchResponse,
            **kwargs,
        )

    def setup(self):
        # Download the Stable Diffusion 2.0 checkpoint; `-C -` resumes a
        # partial download if the file is already present on disk.
        cmd = "curl -C - https://pl-public-data.s3.amazonaws.com/dream_stable_diffusion/768-v-ema.ckpt -o 768-v-ema.ckpt"
        os.system(cmd)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._model = ldm.lightning.LightningStableDiffusion(
            config_path="v2-inference-v.yaml",
            checkpoint_path="768-v-ema.ckpt",
            device=device,
        ).to(device)
        # TODO: run inference in float16 and under torch.no_grad()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def predict(self, requests):
        start = time.time()
        batch_size = len(requests.inputs)
        texts = [request.text for request in requests.inputs]
        images = self._model.predict_step(
            prompts=texts,
            batch_idx=0,  # the batch index is not used by the sampler
        )
        # Encode each generated PIL image as a base64 PNG string.
        results = []
        for image in images:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
            results.append(image_str)
        print(f"finished predicting with batch size {batch_size} in {time.time() - start:.2f} seconds")
        return BatchResponse(outputs=[{"image": image_str} for image_str in results])
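
# Example client call once the app is running (a sketch: the host/port depend
# on the deployment, and the JSON shape assumes the `Text`/`Image` models
# described above):
#
#   curl -X POST https://<app-url>/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "a photo of an astronaut riding a horse on the moon"}'
#
# The response is expected to carry a base64-encoded PNG under "image".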


component = AutoScaler(
    DiffusionServer,  # the component to scale
    cloud_compute=L.CloudCompute("gpu-rtx", disk_size=80),
    # autoscaler args
    min_replicas=1,
    max_replicas=3,
    endpoint="/predict",
    autoscale_up_interval=0,
    autoscale_down_interval=1800,  # 30 minutes
    max_batch_size=8,
    timeout_batching=2,  # seconds to wait while filling an incomplete batch
    input_type=Text,
    output_type=Image,
    cold_start_proxy=CustomColdStartProxy(),
)
app = L.LightningApp(component)
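
# To run locally (assumes the `lightning` CLI is installed):
#   lightning run app app.py
# To run on the Lightning cloud:
#   lightning run app app.py --cloud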