zhouzaida committed
Commit · 7718375 · 1 Parent(s): 9e6c322

add sdpa back

modeling_kimi_vl.py CHANGED (+33 -1)
@@ -145,6 +145,38 @@ def multihead_attention(
     return attn_out
 
 
+def sdpa_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_cu_seqlens: Optional[torch.Tensor] = None,
+    k_cu_seqlens: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """SDPA attention.
+
+    Args:
+        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+    """
+    seq_length = q.shape[0]
+    attention_mask = torch.zeros(
+        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+    )
+    for i in range(1, len(q_cu_seqlens)):
+        attention_mask[
+            ...,
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+        ] = True
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    return attn_output
+
+
 def eager_attention(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -178,6 +210,7 @@ def eager_attention(
 
 VL_VISION_ATTENTION_FUNCTIONS = {
     "flash_attention_2": multihead_attention,
+    "sdpa": sdpa_attention,
     "eager": eager_attention,
 }
 
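For reference, a minimal sketch of how the new path would be exercised through the dispatch table. The packed shapes and the cu_seqlens values here are illustrative assumptions based on the docstring above, and the snippet assumes sdpa_attention and VL_VISION_ATTENTION_FUNCTIONS from modeling_kimi_vl.py are in scope:

import torch

# Hypothetical packed input: two sequences (lengths 3 and 5) concatenated
# along dim 0, i.e. shape (tot_seqlens, num_heads, head_dim) per the docstring.
num_heads, head_dim = 4, 8
q = torch.randn(8, num_heads, head_dim)
k = torch.randn(8, num_heads, head_dim)
v = torch.randn(8, num_heads, head_dim)

# Cumulative sequence lengths marking the packed-sequence boundaries.
cu_seqlens = torch.tensor([0, 3, 8])

# Dispatch through the table this commit extends; equivalent to calling
# sdpa_attention(q, k, v, cu_seqlens, cu_seqlens) directly.
attn_fn = VL_VISION_ATTENTION_FUNCTIONS["sdpa"]
out = attn_fn(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)
print(out.shape)  # torch.Size([8, 32]): heads merged by the final reshape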
@@ -2230,7 +2263,6 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _no_split_modules = ["PackingTransformer"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-
     def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
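One note on the design, offered as an observation rather than something stated in the commit: torch.nn.functional.scaled_dot_product_attention has no native variable-length (cu_seqlens) interface the way flash-attention's varlen kernels do, so sdpa_attention emulates packing with a block-diagonal boolean mask, at the cost of materializing a seq_length x seq_length tensor. A standalone sketch of the mask the loop builds, with made-up values for illustration:

import torch

# cu_seqlens [0, 3, 8] packs a length-3 and a length-5 sequence.
cu_seqlens = torch.tensor([0, 3, 8])
seq_length = int(cu_seqlens[-1])

mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    # A True block on the diagonal for each packed sequence; everything
    # else stays False and is masked out by SDPA.
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

print(mask[0].int())
# Rows 0-2 attend only to columns 0-2; rows 3-7 only to columns 3-7.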