Sifal committed on
Commit 46cd48a · verified · Parent: f0dc5f4

rm flash att

Files changed (1)
  1. bert_layers_mosa.py  +10 -37
bert_layers_mosa.py CHANGED
@@ -50,13 +50,6 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
 
 from .bert_padding import *
 
-try:
-    import flash_attn_triton as flash_attn_triton
-
-    flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
-except ImportError:
-    flash_attn_qkvpacked_func = None
-
 logger = logging.getLogger(__name__)
 
 
@@ -177,12 +170,6 @@ class BertUnpadSelfAttention(nn.Module):
         self.p_dropout = config.attention_probs_dropout_prob
         self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
 
-        # Warn if defaulting to pytorch because of import issues
-        if flash_attn_qkvpacked_func is None:
-            warnings.warn(
-                "Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model)."
-            )
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -220,30 +207,16 @@ class BertUnpadSelfAttention(nn.Module):
         qkv = rearrange(
             qkv, "b s (t h d) -> b s t h d", t=3, h=self.num_attention_heads
         )
-        if self.p_dropout or flash_attn_qkvpacked_func is None:
-            # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
-            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
-            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
-            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
-            attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
-            attention_scores = attention_scores + bias
-            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-            attention_probs = self.dropout(attention_probs)
-            attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d
-        else:
-            # Triton implementation only supports 0 attention dropout
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
-            if convert_dtype:
-                # Triton implementation only supports fp16 and bf16
-                orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
-                bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
-                attention = flash_attn_qkvpacked_func(qkv, bias)
-                attention = attention.to(orig_dtype)
-                bias = bias.to(bias_dtype)
-            else:
-                attention = flash_attn_qkvpacked_func(qkv, bias)
+        # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
+        q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
+        k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
+        v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
+        attention_scores = torch.matmul(q, k) / math.sqrt(self.attention_head_size)
+        attention_scores = attention_scores + bias
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+        attention_probs = self.dropout(attention_probs)
+        attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d
+
 
         # attn_mask is 1 for attend and 0 for don't
         attention = unpad_input_only(
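For reference, here is a minimal standalone sketch of the pure-PyTorch attention path that this commit keeps. The tensor sizes, the zero bias, and the dropout probability are illustrative placeholders, not values from the repo; the real module builds qkv from its Wqkv projection and passes in the model's attention bias/mask.

# Sketch of the retained PyTorch attention path (shapes are illustrative).
import math
import torch
from einops import rearrange

batch, seqlen, n_heads, head_dim = 2, 8, 4, 16
hidden_size = n_heads * head_dim

# Stand-in for the Wqkv projection output: (batch, seqlen, 3 * hidden_size)
qkv = torch.randn(batch, seqlen, 3 * hidden_size)
qkv = rearrange(qkv, "b s (t h d) -> b s t h d", t=3, h=n_heads)

q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d

# Placeholder additive bias; the model supplies its own attention bias/mask here
bias = torch.zeros(batch, n_heads, seqlen, seqlen)

attention_scores = torch.matmul(q, k) / math.sqrt(head_dim) + bias
attention_probs = torch.softmax(attention_scores, dim=-1)
attention_probs = torch.nn.functional.dropout(attention_probs, p=0.1)  # stand-in for p_dropout
attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d

print(attention.shape)  # torch.Size([2, 8, 4, 16])

As the removed comments note, the Triton kernel supported only zero attention dropout and fp16/bf16 inputs; the PyTorch path kept here has neither restriction.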