-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodify_model.py
More file actions
321 lines (265 loc) · 12.6 KB
/
modify_model.py
File metadata and controls
321 lines (265 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import math
import types
import warnings
from typing import Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
rotate_half,
)
def zero_ablate_attention_forward(
self,
hidden_states,
attention_mask = None,
position_ids = None,
past_key_value = None,
output_attentions = False,
use_cache = False,
cache_position = None,
position_embeddings = None, # will become mandatory in v4.45
**kwargs,
):
if output_attentions:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
num_key_value_groups = self.num_key_value_groups
key_states = repeat_kv(key_states, num_key_value_groups)
value_states = repeat_kv(value_states, num_key_value_groups)
########## Do ablation: set attention head output to zero ##########
for head_id in self.ablate_head_id_list:
value_states[:, head_id, :, :] = 0
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None
def zero_ablate_attn_head(layer, head_id_list, modelname=None):
"""
Replace the forward function with a custom forward function
"""
modified_module = layer.self_attn
modified_module.ablate_head_id_list = head_id_list
modified_module.forward = types.MethodType(zero_ablate_attention_forward, modified_module)
return modified_module
def mean_ablate_attention_forward(
self,
hidden_states,
attention_mask = None,
position_ids = None,
past_key_value = None,
output_attentions = False,
use_cache = False,
cache_position = None,
position_embeddings = None, # will become mandatory in v4.45
**kwargs,
):
if output_attentions:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
bsz, q_len, _ = hidden_states.size()
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
##################################################################
if self.collect_stats:
self.mean_head_stats = self.mean_head_stats * self.num_samples + attn_output.mean(dim=2).sum(dim=0)
self.num_samples += attn_output.shape[0]
self.mean_head_stats /= self.num_samples
else:
for head_id in self.ablate_head_id:
attn_output[:, head_id, :, :] = self.mean_head_stats[head_id].half()
##################################################################
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None
def mean_ablate_attn_head(layer, head_id, num_heads, collect_stats=False):
"""
replace the forward function of LlamaAttention with a custom forward function `llama_custom_attention_forward`
"""
modified_module = layer.self_attn
modified_module.ablate_head_id = head_id
modified_module.collect_stats = collect_stats
if collect_stats:
head_dim = modified_module.head_dim
modified_module.mean_head_stats = torch.zeros((num_heads, head_dim)).cuda()
modified_module.num_samples = 0
modified_module.num_heads = num_heads
modified_module.forward = types.MethodType(mean_ablate_attention_forward, modified_module)
return modified_module
def activation_patching_attention_forward(
self,
hidden_states,
attention_mask = None,
position_ids = None,
past_key_value = None,
output_attentions = False,
use_cache = False,
cache_position = None,
position_embeddings = None, # will become mandatory in v4.45
**kwargs,
):
if output_attentions:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
################################################################################
### attn_output.shape [batch_size, #attn_head, seq_len, attn_head_feat]
# This saves the outputs at this layer
if self.collect_pre_act:
feat = attn_output.clone()
if attn_output.shape[2] != 1:
torch.save(feat, f"{self.model}_L{self.layer_id}_feat.pt")
else:
if attn_output.shape[2] != 1:
feat = torch.load(f"{self.model}_L{self.layer_id}_feat.pt")
start_idx = self.ablate_head_id_list[0]
end_idx = start_idx + 1
attn_output[:, start_idx:end_idx, -6:, :] = feat[:, start_idx:end_idx, -6:, :]
################################################################################
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None
def activation_patching_attn_head(layer, head_id_list, collect_pre_act, args):
"""
replace the forward function of LlamaAttention with a custom forward function `llama_custom_attention_forward`
"""
modified_module = layer.self_attn
modified_module.ablate_head_id_list = head_id_list
modified_module.collect_pre_act = collect_pre_act
modified_module.layer_id = args.layer_id
modified_module.model = args.model
modified_module.forward = types.MethodType(activation_patching_attention_forward, modified_module)
return modified_module