ammarnasr committed
Commit 69336ee · verified · 1 Parent(s): 248a174

Upload model

Files changed (1): modeling_t5mimo.py (+18 -6)
modeling_t5mimo.py CHANGED

@@ -314,6 +314,7 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+
         if self.config.is_mimo:
             batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
         else:
@@ -402,7 +403,6 @@ class T5Attention(nn.Module):



-
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 if self.config.is_mimo:
@@ -414,6 +414,7 @@ class T5Attention(nn.Module):
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

+
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
@@ -427,6 +428,7 @@ class T5Attention(nn.Module):



+
         if self.pruned_heads:
             mask = torch.ones(position_bias.shape[1])
             mask[list(self.pruned_heads)] = 0
@@ -434,6 +436,7 @@ class T5Attention(nn.Module):
         else:
             position_bias_masked = position_bias

+
         scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (batch_size, n_heads, seq_length, key_length)
         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # (batch_size, n_heads, seq_length, key_length)
@@ -909,20 +912,24 @@ class T5Stack(T5PreTrainedModel):
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
-
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+            if self.config.is_mimo:
+                attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
+            else:
+                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)



         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
+
         if self.config.is_mimo:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, (input_shape[0], input_shape[2]))
-            extended_attention_mask = extended_attention_mask.unsqueeze(1)
+            extended_attention_mask = extended_attention_mask.transpose(1,2).unsqueeze(2)
         else:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

+
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.is_decoder and encoder_hidden_states is not None:
@@ -934,11 +941,16 @@ class T5Stack(T5PreTrainedModel):
                 encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
+
             if self.config.is_mimo:
-                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
+
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask).transpose(1,2)
+
+                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(2)
+
             else:
                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+

         else:
             encoder_extended_attention_mask = None
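For readers tracing the T5Stack hunks: the commit swaps the old `unsqueeze(1)` reshaping for `transpose(1,2).unsqueeze(2)` so the extended self-attention mask lines up with per-variate attention scores. The sketch below is a shape-only reconstruction, not code from this repository; it assumes MIMO hidden states of shape (batch, multivar, seq, d_model) and attention scores of shape (batch, multivar, n_heads, q_len, k_len), and it inlines what transformers' `get_extended_attention_mask` does for a 3-D mask (insert a head axis, convert 0/1 keep-flags into additive values).

```python
# Shape-only sketch of the new MIMO self-attention mask path (reconstruction,
# not taken from modeling_t5mimo.py). Assumed score layout: (b, m, h, q, k).
import torch

batch, multivar, seq, n_heads = 2, 3, 5, 4

# Default mask is now built per variate: (batch, multivar, seq)
attention_mask = torch.ones(batch, multivar, seq)

# For a 3-D mask, transformers' get_extended_attention_mask inserts a head axis
# and turns 0/1 keep-flags into additive values (0 keeps, large negative drops).
extended = attention_mask[:, None, :, :]                       # (batch, 1, multivar, seq)
extended = (1.0 - extended) * torch.finfo(torch.float32).min   # additive form

# The commit's reshaping: move the variate axis next to batch, then add a
# singleton query axis so the mask broadcasts over heads and query positions.
extended = extended.transpose(1, 2).unsqueeze(2)               # (batch, multivar, 1, 1, seq)

scores = torch.randn(batch, multivar, n_heads, seq, seq)       # assumed (b, m, h, q, k)
print((scores + extended).shape)                               # torch.Size([2, 3, 4, 5, 5])
```

Under this assumed layout, the previous `unsqueeze(1)` would have produced a (batch, 1, 1, multivar, seq) mask that does not broadcast against the per-variate scores, which appears to be what the commit fixes.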
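The cross-attention branch follows the same pattern. The companion sketch below (again a reconstruction under the same assumed score layout) uses the 2-D default `encoder_attention_mask` from the diff; `invert_attention_mask` then yields a (batch, 1, 1, src_len) additive mask, the new `transpose(1,2)` is a no-op on that shape, and `unsqueeze(2)` still gives a tensor that broadcasts over per-variate cross-attention scores. A 3-D per-variate encoder mask would go through the same (batch, multivar, 1, 1, src_len) path as the self-attention case.

```python
# Companion sketch for the cross-attention branch (reconstruction, not repo code).
import torch

batch, multivar, n_heads, tgt_len, src_len = 2, 3, 4, 5, 7

# 2-D default from the diff: torch.ones(encoder_hidden_shape, dtype=torch.long)
encoder_attention_mask = torch.ones(batch, src_len, dtype=torch.long)

# invert_attention_mask on a 2-D mask: insert head/query axes, make additive.
inv = encoder_attention_mask[:, None, None, :].float()         # (batch, 1, 1, src_len)
inv = (1.0 - inv) * torch.finfo(torch.float32).min

# The commit's reshaping; transpose(1, 2) changes nothing on the 2-D path.
inv = inv.transpose(1, 2).unsqueeze(2)                         # (batch, 1, 1, 1, src_len)

cross_scores = torch.randn(batch, multivar, n_heads, tgt_len, src_len)
print((cross_scores + inv).shape)                              # torch.Size([2, 3, 4, 5, 7])
```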