From 026a02da5796fb4e902167dc49dfa6ef1891dee8 Mon Sep 17 00:00:00 2001
From: calmdown539 <111472480+calmdown539@users.noreply.github.com>
Date: Fri, 26 Dec 2025 10:09:01 +0800
Subject: [PATCH] Add the implementation of the MultiHeadAttention layer

Add the implementation of the MultiHeadAttention layer
---
 examples/singa_peft/examples/model/trans.py | 72 +++++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/examples/singa_peft/examples/model/trans.py b/examples/singa_peft/examples/model/trans.py
index 14a5da870..ecf987dd5 100644
--- a/examples/singa_peft/examples/model/trans.py
+++ b/examples/singa_peft/examples/model/trans.py
@@ -372,3 +372,75 @@ def matmul4d(x1, x2):
         ys.append(yb)
     y = autograd.cat(ys, axis=0)
     return y
+
+class MultiHeadAttention(layer.Layer):
+    def __init__(self, d_model=512, n_head=8):
+        super(MultiHeadAttention, self).__init__()
+        self.d_k = d_model // n_head
+        assert (
+            self.d_k * n_head == d_model
+        ), "d_model must be divisible by n_head"
+        self.d_model = d_model
+        self.d_v = self.d_k
+        self.n_head = n_head
+        self.W_Q = Linear3D(d_model, self.d_k * n_head)
+        self.W_K = Linear3D(d_model, self.d_k * n_head)
+        self.W_V = Linear3D(d_model, self.d_v * n_head)
+
+        self.scaled_dot_product_attention = ScaledDotProductAttention(d_model, n_head)
+        self.linear = Linear3D(self.d_v * n_head, d_model)
+        self.add = layer.Add()
+        self.layer_norm = LayerNorm(d_model)
+
+    def forward(self, query, key, value, attn_mask):
+        """
+        Args:
+            query: [batch_size, len_q, d_model]
+            key: [batch_size, len_k, d_model]
+            value: [batch_size, len_v(=len_k), d_model]
+            attn_mask: [batch_size, seq_len, seq_len]
+        Returns: output [batch_size, len_q, d_model], attn [batch_size, n_heads, seq_len, seq_len]
+        """
+        residual = query
+        batch_size = query.shape[0]
+
+        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
+        Q = self.W_Q(query)
+        Q = autograd.reshape(Q, [batch_size, -1, self.n_head, self.d_k])
+        Q = autograd.transpose(Q, [0, 2, 1, 3])
+
+        K = self.W_K(key)
+        K = autograd.reshape(K, [batch_size, -1, self.n_head, self.d_k])
+        K = autograd.transpose(K, [0, 2, 1, 3])
+
+        V = self.W_V(value)
+        V = autograd.reshape(V, [batch_size, -1, self.n_head, self.d_v])
+        V = autograd.transpose(V, [0, 2, 1, 3])
+
+        # Q: [batch_size, n_heads, len_q, d_k]
+        # K: [batch_size, n_heads, len_k, d_k]
+        # V: [batch_size, n_heads, len_v(=len_k), d_v]
+
+        # attn_mask: [batch_size, n_heads, seq_len, seq_len]
+        attn_mask = MultiHeadAttention._get_attn_mask(attn_mask, self.n_head)
+
+        # context: [batch_size, n_heads, len_q, d_v]
+        # attn: [batch_size, n_heads, seq_len, seq_len]
+        context, attn = self.scaled_dot_product_attention(Q, K, V, attn_mask)
+        context = autograd.transpose(context, [0, 2, 1, 3])
+        # context: [batch_size, len_q, n_heads * d_v]
+        context = autograd.reshape(context, [batch_size, -1, self.n_head * self.d_v])
+
+        output = self.linear(context)
+        output = self.add(output, residual)
+        # [batch_size, len_q, d_model]
+        output = self.layer_norm(output)
+        return output, attn
+
+    @staticmethod
+    def _get_attn_mask(attn_mask, n_head):
+        batch_size, seq_q_len, seq_k_len = attn_mask.shape[0], attn_mask.shape[1], attn_mask.shape[2]
+        attn_mask_np = tensor.to_numpy(attn_mask)
+        attn_mask_np = np.expand_dims(attn_mask_np, axis=1)  # insert a head axis
+        attn_mask_np = np.broadcast_to(attn_mask_np, (batch_size, n_head, seq_q_len, seq_k_len))
+        return tensor.from_numpy(np.ascontiguousarray(attn_mask_np))  # copy the read-only broadcast view
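
Usage note: a minimal self-attention sketch of how the new MultiHeadAttention layer
could be exercised, illustrative only and not part of the diff above. The trans
import path, the all-zeros mask (assumed to mean "no position is masked"), and lazy
parameter initialization of the helper layers defined earlier in trans.py are
assumptions, not something this patch guarantees.

    import numpy as np
    from singa import device, tensor
    from trans import MultiHeadAttention  # hypothetical import path

    dev = device.get_default_device()
    batch, seq_len, d_model, n_head = 2, 16, 512, 8

    # Random input; self-attention passes the same tensor as query, key and value.
    x = tensor.Tensor((batch, seq_len, d_model), dev)
    x.gaussian(0.0, 1.0)

    # All-zeros mask, assumed here to mean that no position is masked out.
    mask = tensor.from_numpy(np.zeros((batch, seq_len, seq_len), dtype=np.float32))
    mask.to_device(dev)

    mha = MultiHeadAttention(d_model=d_model, n_head=n_head)
    output, attn = mha(x, x, x, mask)
    print(output.shape)  # expected (batch, seq_len, d_model)
    print(attn.shape)    # expected (batch, n_head, seq_len, seq_len)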