Skip to content

Commit be95871

Browse files
authored
feat: Gemma4 text generation support (CORE-30) (Comfy-Org#13376)
* initial gemma4 support * parity with reference implementation outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize * Cleanup, video fixes * cleanup, enable fused rms norm by default * update comment * Cleanup * Update sd.py * Various fixes * Add fp8 scaled embedding support * small fixes * Translate think tokens * Fix image encoder attention mask type So it works with basic attention * Handle thinking tokens different only for Gemma4 * Code cleanup * Update nodes_textgen.py * Use embed scale class instead of buffer Slight difference to HF, but technically more accurate and simpler code * Default to fused rms_norm * Update gemma4.py
1 parent f756d80 commit be95871

11 files changed

Lines changed: 1453 additions & 42 deletions

File tree

comfy/ldm/modules/attention.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from comfy import model_management
1616

17+
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
18+
1719
if model_management.xformers_enabled():
1820
import xformers
1921
import xformers.ops
@@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
150152
b, _, dim_head = q.shape
151153
dim_head //= heads
152154

153-
scale = dim_head ** -0.5
155+
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
156+
n_rep = q.shape[-3] // k.shape[-3]
157+
k = k.repeat_interleave(n_rep, dim=-3)
158+
v = v.repeat_interleave(n_rep, dim=-3)
159+
160+
scale = kwargs.get("scale", dim_head ** -0.5)
154161

155162
h = heads
156163
if skip_reshape:
@@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
219226
b, _, dim_head = query.shape
220227
dim_head //= heads
221228

229+
if "scale" in kwargs:
230+
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
231+
query = query * (kwargs["scale"] * dim_head ** 0.5)
232+
222233
if skip_reshape:
223234
query = query.reshape(b * heads, -1, dim_head)
224235
value = value.reshape(b * heads, -1, dim_head)
@@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
290301
b, _, dim_head = q.shape
291302
dim_head //= heads
292303

293-
scale = dim_head ** -0.5
304+
scale = kwargs.get("scale", dim_head ** -0.5)
294305

295306
if skip_reshape:
296307
q, k, v = map(
@@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
500511
if mask.ndim == 3:
501512
mask = mask.unsqueeze(1)
502513

514+
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
515+
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
516+
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
517+
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
518+
503519
if SDP_BATCH_LIMIT >= b:
504-
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
520+
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
505521
if not skip_output_reshape:
506522
out = (
507523
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
519535
k[i : i + SDP_BATCH_LIMIT],
520536
v[i : i + SDP_BATCH_LIMIT],
521537
attn_mask=m,
522-
dropout_p=0.0, is_causal=False
538+
dropout_p=0.0, is_causal=False, **sdpa_extra
523539
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
524540
return out
525541

comfy/ops.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,6 +1246,93 @@ def _apply(self, fn, recurse=True): # This is to get torch.compile + moving wei
12461246
self._buffers[key] = fn(buf)
12471247
return self
12481248

1249+
class Embedding(manual_cast.Embedding):
1250+
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
1251+
strict, missing_keys, unexpected_keys, error_msgs):
1252+
weight_key = f"{prefix}weight"
1253+
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
1254+
if layer_conf is not None:
1255+
layer_conf = json.loads(layer_conf.numpy().tobytes())
1256+
1257+
# Only fp8 makes sense for embeddings (per-row dequant via index select).
1258+
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
1259+
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
1260+
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
1261+
self.quant_format = quant_format
1262+
qconfig = QUANT_ALGOS[quant_format]
1263+
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
1264+
weight = state_dict.pop(weight_key)
1265+
manually_loaded_keys = [weight_key]
1266+
1267+
scale_key = f"{prefix}weight_scale"
1268+
scale = state_dict.pop(scale_key, None)
1269+
if scale is not None:
1270+
scale = scale.float()
1271+
manually_loaded_keys.append(scale_key)
1272+
1273+
params = layout_cls.Params(
1274+
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
1275+
orig_dtype=MixedPrecisionOps._compute_dtype,
1276+
orig_shape=(self.num_embeddings, self.embedding_dim),
1277+
)
1278+
self.weight = torch.nn.Parameter(
1279+
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
1280+
requires_grad=False)
1281+
1282+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
1283+
for k in manually_loaded_keys:
1284+
if k in missing_keys:
1285+
missing_keys.remove(k)
1286+
else:
1287+
if layer_conf is not None:
1288+
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
1289+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
1290+
1291+
def state_dict(self, *args, destination=None, prefix="", **kwargs):
1292+
if destination is not None:
1293+
sd = destination
1294+
else:
1295+
sd = {}
1296+
1297+
if not hasattr(self, 'weight') or self.weight is None:
1298+
return sd
1299+
1300+
if isinstance(self.weight, QuantizedTensor):
1301+
sd_out = self.weight.state_dict("{}weight".format(prefix))
1302+
for k in sd_out:
1303+
sd[k] = sd_out[k]
1304+
1305+
quant_conf = {"format": self.quant_format}
1306+
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
1307+
else:
1308+
sd["{}weight".format(prefix)] = self.weight
1309+
return sd
1310+
1311+
def forward_comfy_cast_weights(self, input, out_dtype=None):
1312+
weight = self.weight
1313+
1314+
# Optimized path: lookup in fp8, dequantize only the selected rows.
1315+
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
1316+
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
1317+
if isinstance(qdata, QuantizedTensor):
1318+
scale = qdata._params.scale
1319+
qdata = qdata._qdata
1320+
else:
1321+
scale = None
1322+
1323+
x = torch.nn.functional.embedding(
1324+
input, qdata, self.padding_idx, self.max_norm,
1325+
self.norm_type, self.scale_grad_by_freq, self.sparse)
1326+
uncast_bias_weight(self, qdata, None, offload_stream)
1327+
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
1328+
x = x.to(dtype=target_dtype)
1329+
if scale is not None and scale != 1.0:
1330+
x = x * scale.to(dtype=target_dtype)
1331+
return x
1332+
1333+
# Fallback for non-quantized or weight_function (LoRA) case
1334+
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
1335+
12491336
return MixedPrecisionOps
12501337

12511338
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

comfy/rmsnorm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
RMSNorm = torch.nn.RMSNorm
55

6+
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
67
def rms_norm(x, weight=None, eps=1e-6):
78
if weight is None:
89
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

comfy/sd.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
import comfy.text_encoders.longcat_image
6666
import comfy.text_encoders.qwen35
6767
import comfy.text_encoders.ernie
68+
import comfy.text_encoders.gemma4
6869

6970
import comfy.model_patcher
7071
import comfy.lora
@@ -1271,6 +1272,9 @@ class TEModel(Enum):
12711272
QWEN35_9B = 26
12721273
QWEN35_27B = 27
12731274
MINISTRAL_3_3B = 28
1275+
GEMMA_4_E4B = 29
1276+
GEMMA_4_E2B = 30
1277+
GEMMA_4_31B = 31
12741278

12751279

12761280
def detect_te_model(sd):
@@ -1296,6 +1300,12 @@ def detect_te_model(sd):
12961300
return TEModel.BYT5_SMALL_GLYPH
12971301
return TEModel.T5_BASE
12981302
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
1303+
if 'model.layers.59.self_attn.q_norm.weight' in sd:
1304+
return TEModel.GEMMA_4_31B
1305+
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
1306+
return TEModel.GEMMA_4_E4B
1307+
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
1308+
return TEModel.GEMMA_4_E2B
12991309
if 'model.layers.47.self_attn.q_norm.weight' in sd:
13001310
return TEModel.GEMMA_3_12B
13011311
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1435,6 +1445,13 @@ class EmptyClass:
14351445
else:
14361446
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
14371447
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
1448+
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
1449+
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
1450+
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
1451+
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
1452+
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
1453+
clip_target.tokenizer = variant.tokenizer
1454+
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
14381455
elif te_model == TEModel.GEMMA_2_2B:
14391456
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
14401457
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

0 commit comments

Comments
 (0)