Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,14 +827,20 @@ def dvcheck(self) -> str:
@property
def max_seq_q_cond(self) -> str:
if self.tile.max_seq_q != 0:
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
if self.mode == "group":
return f" && (t.max_seqlen_q <= {self.tile.max_seq_q})"
else:
return f" && (t.seqlen_q <= {self.tile.max_seq_q})"
else:
return ""

@property
def extra_cond(self) -> str:
if self.tr_load == "t" and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128 and self.tile.F_bhdq == 128:
return " && (t.seqlen_k <= 256)"
if self.mode == "group":
return " && (t.max_seqlen_k <= 256)"
else:
return " && (t.seqlen_k <= 256)"
else:
return ""

Expand Down Expand Up @@ -1057,7 +1063,7 @@ def get_bwd_blobs(
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
continue
if (mode == "group" or ("no" not in mask)) and tile.max_seq_q != 0:
if ("no" not in mask) and tile.max_seq_q != 0:
continue
if (bias == "no" or bias == "alibi") and dbias == "t":
continue
Expand Down
4 changes: 4 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def seqtune(self, max_bm0: int) -> str:
if self.bm0 == max_bm0 or self.bm0 == 64:
return "true/*fall back to largest tile*/"
else:
if self.mode == "group":
return f"a.max_seqlen_q <= {self.bm0}"
return f"a.seqlen_q <= {self.bm0}"

@property
Expand Down Expand Up @@ -1136,6 +1138,8 @@ def get_pipelines(
):
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip # group mode spad+dpad

# # qr_async_trload_v3 bf16/fp16 not ready
# if (hdim, hdim_v) == (128, 128):
Expand Down