zhouzaida committed
Commit 7718375 · Parent: 9e6c322

add sdpa back

Files changed (1): modeling_kimi_vl.py  +33 -1
modeling_kimi_vl.py CHANGED
@@ -145,6 +145,38 @@ def multihead_attention(
     return attn_out
 
 
+def sdpa_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_cu_seqlens: Optional[torch.Tensor] = None,
+    k_cu_seqlens: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """SDPA attention.
+
+    Args:
+        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+    """
+    seq_length = q.shape[0]
+    attention_mask = torch.zeros(
+        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+    )
+    for i in range(1, len(q_cu_seqlens)):
+        attention_mask[
+            ...,
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+        ] = True
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    return attn_output
+
+
 def eager_attention(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -178,6 +210,7 @@ def eager_attention(
 
 VL_VISION_ATTENTION_FUNCTIONS = {
     "flash_attention_2": multihead_attention,
+    "sdpa": sdpa_attention,
     "eager": eager_attention,
 }
 
@@ -2230,7 +2263,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _no_split_modules = ["PackingTransformer"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-
     def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
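
For context, a minimal self-contained sketch (not part of the commit) of what the new sdpa_attention does for packed sequences: the cumulative boundaries in q_cu_seqlens carve a boolean (seq, seq) mask into block-diagonal regions, so each token attends only within its own packed sequence. Shapes and values below are illustrative.

import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 8
# Two packed sequences of lengths 3 and 5 -> cumulative boundaries [0, 3, 8].
q_cu_seqlens = torch.tensor([0, 3, 8])
tot = int(q_cu_seqlens[-1])

# Packed layout: (tot_seqlens, num_heads, head_dim), as in the docstring.
q = torch.randn(tot, num_heads, head_dim)
k = torch.randn(tot, num_heads, head_dim)
v = torch.randn(tot, num_heads, head_dim)

# Same mask construction as the patch: True marks positions allowed to attend.
mask = torch.zeros(1, tot, tot, dtype=torch.bool)
for i in range(1, len(q_cu_seqlens)):
    s, e = q_cu_seqlens[i - 1], q_cu_seqlens[i]
    mask[..., s:e, s:e] = True

# SDPA expects (..., seq, head_dim); move heads in front of the sequence dim.
out = F.scaled_dot_product_attention(
    q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), mask, dropout_p=0.0
)
out = out.transpose(0, 1).reshape(tot, -1)  # (tot_seqlens, num_heads * head_dim)
print(out.shape)  # torch.Size([8, 32])

Note that a boolean attn_mask for F.scaled_dot_product_attention means True = attend, and that k_cu_seqlens is accepted but unused here, presumably to keep the signature uniform with multihead_attention and eager_attention in VL_VISION_ATTENTION_FUNCTIONS.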