ammarnasr committed
Commit 3407056 · verified · 1 Parent(s): 5d50178

Upload model

Files changed (1): modeling_t5mimo.py (+12 -6)
modeling_t5mimo.py CHANGED
@@ -947,14 +947,20 @@ class T5Stack(T5PreTrainedModel):
                 if encoder_attention_mask is None:
                     encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
 
-                if self.config.is_mimo:
-                    encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-                    encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(0)
-                    encoder_extended_attention_mask = encoder_extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
+                if self.config.is_mimo:
+                    encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                    encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(0)
+                    encoder_extended_attention_mask = encoder_extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
+                else:
+                    encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
             else:
-                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                if self.config.is_mimo:
+                    encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+                    encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 2, 1, 3)
+                    encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(3)
+                else:
+                    encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
 
-
         else:
             encoder_extended_attention_mask = None
 
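The shape bookkeeping the commit introduces is easier to see on dummy tensors. The following is a minimal standalone sketch, not the repository's code: `invert_mask` stands in for the `self.invert_attention_mask` helper called in the diff, and `batch`, `n_series`, and `src_len` are made-up sizes; the exact ranks inside the model depend on how the MIMO inputs and `input_shape` are laid out. The point is that both new `else:` branches keep `encoder_extended_attention_mask` defined and broadcastable whether or not `config.is_mimo` is set.

import torch

def invert_mask(mask: torch.Tensor) -> torch.Tensor:
    # Stand-in for transformers' invert_attention_mask: turn a (batch, src_len)
    # 0/1 mask into an additive mask of shape (batch, 1, 1, src_len), with 0.0
    # at positions to keep and a large negative value at masked positions.
    extended = mask[:, None, None, :].to(torch.float32)
    return (1.0 - extended) * torch.finfo(torch.float32).min

batch, n_series, src_len = 2, 3, 7  # hypothetical MIMO sizes
encoder_attention_mask = torch.ones(batch, src_len, dtype=torch.long)

# First branch (is_mimo): unsqueeze(0) then repeat(1, input_shape[1], 1, 1, 1)
# yields a 5D additive mask with the series count folded into dim 1.
mimo_mask = invert_mask(encoder_attention_mask)      # (2, 1, 1, 7)
mimo_mask = mimo_mask.unsqueeze(0)                   # (1, 2, 1, 1, 7)
mimo_mask = mimo_mask.repeat(1, n_series, 1, 1, 1)   # (1, 6, 1, 1, 7)
print(mimo_mask.shape)

# Branch added in the second half of the hunk: permute(0, 2, 1, 3) swaps two
# singleton axes of the standard 4D mask, and unsqueeze(3) inserts one more
# broadcast axis, again giving a 5D mask.
other_mask = invert_mask(encoder_attention_mask)     # (2, 1, 1, 7)
other_mask = other_mask.permute(0, 2, 1, 3)          # (2, 1, 1, 7)
other_mask = other_mask.unsqueeze(3)                 # (2, 1, 1, 1, 7)
print(other_mask.shape)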