@@ -31,15 +31,39 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
3131 )
3232
3333
34+ class ExpoFourierFeatures (nn .Module ):
35+ """Exponentially-spaced Fourier features (no learnable parameters)."""
36+ def __init__ (self , dim , min_freq = 0.5 , max_freq = 10000.0 ):
37+ super ().__init__ ()
38+ self .dim = dim
39+ self .min_freq = min_freq
40+ self .max_freq = max_freq
41+
42+ def forward (self , t ):
43+ in_dtype = t .dtype
44+ t = t .float ()
45+ if t .dim () == 1 :
46+ t = t .unsqueeze (- 1 )
47+ half_dim = self .dim // 2
48+ ramp = torch .linspace (0 , 1 , half_dim , device = t .device , dtype = torch .float32 )
49+ freqs = torch .exp (ramp * (math .log (self .max_freq ) - math .log (self .min_freq )) + math .log (self .min_freq ))
50+ args = t * freqs * 2 * math .pi
51+ return torch .cat ([args .cos (), args .sin ()], dim = - 1 ).to (in_dtype )
52+
53+
3454class NumberEmbedder (nn .Module ):
3555 def __init__ (
3656 self ,
3757 features : int ,
3858 dim : int = 256 ,
59+ fourier_features_type = "learned" ,
3960 ):
4061 super ().__init__ ()
4162 self .features = features
42- self .embedding = TimePositionalEmbedding (dim = dim , out_features = features )
63+ if fourier_features_type == "expo" :
64+ self .embedding = nn .Sequential (ExpoFourierFeatures (dim = dim ), comfy .ops .manual_cast .Linear (in_features = dim , out_features = features ))
65+ else :
66+ self .embedding = TimePositionalEmbedding (dim = dim , out_features = features )
4367
4468 def forward (self , x : Union [List [float ], Tensor ]) -> Tensor :
4569 if not torch .is_tensor (x ):
@@ -77,14 +101,15 @@ class NumberConditioner(Conditioner):
77101 def __init__ (self ,
78102 output_dim : int ,
79103 min_val : float = 0 ,
80- max_val : float = 1
104+ max_val : float = 1 ,
105+ fourier_features_type : str = "learned" ,
81106 ):
82107 super ().__init__ (output_dim , output_dim )
83108
84109 self .min_val = min_val
85110 self .max_val = max_val
86111
87- self .embedder = NumberEmbedder (features = output_dim )
112+ self .embedder = NumberEmbedder (features = output_dim , fourier_features_type = fourier_features_type )
88113
89114 def forward (self , floats , device = None ):
90115 # Cast the inputs to floats
0 commit comments