rm flash att
bert_layers_mosa.py  +10 -37
CHANGED
@@ -50,13 +50,6 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
 
 from .bert_padding import *
 
-try:
-    import flash_attn_triton as flash_attn_triton
-
-    flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
-except ImportError:
-    flash_attn_qkvpacked_func = None
-
 logger = logging.getLogger(__name__)
 
 
@@ -177,12 +170,6 @@ class BertUnpadSelfAttention(nn.Module):
         self.p_dropout = config.attention_probs_dropout_prob
         self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
 
-        # Warn if defaulting to pytorch because of import issues
-        if flash_attn_qkvpacked_func is None:
-            warnings.warn(
-                "Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model)."
-            )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,30 +207,16 @@
         qkv = rearrange(
             qkv, "b s (t h d) -> b s t h d", t=3, h=self.num_attention_heads
         )
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
-            # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
-            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
-            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
-            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
-            attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
-            attention_scores = attention_scores + bias
-            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-            attention_probs = self.dropout(attention_probs)
-            attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d
-        else:
-            # Triton implementation only supports 0 attention dropout
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
-            if convert_dtype:
-                # Triton implementation only supports fp16 and bf16
-                orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
-                attention = flash_attn_qkvpacked_func(qkv, bias)
-                attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
-            else:
-                attention = flash_attn_qkvpacked_func(qkv, bias)
+        # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
+        q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
+        k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
+        v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
+        attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
+        attention_scores = attention_scores + bias
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+        attention_probs = self.dropout(attention_probs)
+        attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d
+
 
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(