Skip to content

Commit f9c84c9

Browse files
Support Stable Audio 3 model. (Comfy-Org#14010)
1 parent 78b5dec commit f9c84c9

9 files changed

Lines changed: 1161 additions & 45 deletions

File tree

comfy/latent_formats.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,11 @@ class StableAudio1(LatentFormat):
152152
latent_dimensions = 1
153153
temporal_downscale_ratio = 2048
154154

155+
class StableAudio3(LatentFormat):
156+
latent_channels = 256
157+
latent_dimensions = 1
158+
temporal_downscale_ratio = 4096
159+
155160
class Flux(SD3):
156161
latent_channels = 16
157162
def __init__(self):

comfy/ldm/audio/dit.py

Lines changed: 208 additions & 42 deletions
Large diffs are not rendered by default.

comfy/ldm/audio/embedders.py

Lines changed: 28 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,15 +31,39 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
3131
)
3232

3333

34+
class ExpoFourierFeatures(nn.Module):
35+
"""Exponentially-spaced Fourier features (no learnable parameters)."""
36+
def __init__(self, dim, min_freq=0.5, max_freq=10000.0):
37+
super().__init__()
38+
self.dim = dim
39+
self.min_freq = min_freq
40+
self.max_freq = max_freq
41+
42+
def forward(self, t):
43+
in_dtype = t.dtype
44+
t = t.float()
45+
if t.dim() == 1:
46+
t = t.unsqueeze(-1)
47+
half_dim = self.dim // 2
48+
ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
49+
freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
50+
args = t * freqs * 2 * math.pi
51+
return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)
52+
53+
3454
class NumberEmbedder(nn.Module):
3555
def __init__(
3656
self,
3757
features: int,
3858
dim: int = 256,
59+
fourier_features_type="learned",
3960
):
4061
super().__init__()
4162
self.features = features
42-
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
63+
if fourier_features_type == "expo":
64+
self.embedding = nn.Sequential(ExpoFourierFeatures(dim=dim), comfy.ops.manual_cast.Linear(in_features=dim, out_features=features))
65+
else:
66+
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
4367

4468
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
4569
if not torch.is_tensor(x):
@@ -77,14 +101,15 @@ class NumberConditioner(Conditioner):
77101
def __init__(self,
78102
output_dim: int,
79103
min_val: float=0,
80-
max_val: float=1
104+
max_val: float=1,
105+
fourier_features_type: str = "learned",
81106
):
82107
super().__init__(output_dim, output_dim)
83108

84109
self.min_val = min_val
85110
self.max_val = max_val
86111

87-
self.embedder = NumberEmbedder(features=output_dim)
112+
self.embedder = NumberEmbedder(features=output_dim, fourier_features_type=fourier_features_type)
88113

89114
def forward(self, floats, device=None):
90115
# Cast the inputs to floats

0 commit comments

Comments
 (0)