Commit: Upload model
Changed file: modeling_t5mimo.py (+12 added, −6 removed)
@@ -947,14 +947,20 @@ class T5Stack(T5PreTrainedModel):

Before (old lines 947–960) — NOTE: the content of the six removed lines was lost
in the page extraction; only their positions survive:

    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)

        [4 removed lines — content not recoverable from this extraction]
    else:
        [1 removed line — content not recoverable from this extraction]

        [1 removed line — content not recoverable from this extraction]
    else:
        encoder_extended_attention_mask = None

After (new lines 947–966) — relative indentation inferred from the diff's
context lines and Python semantics; the two trailing `else:` branches belong to
successively enclosing conditionals outside this hunk (TODO: confirm against the
full file):

    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)

        if self.config.is_mimo:                                                      # added
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)   # added
            encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(0)         # added
            encoder_extended_attention_mask = encoder_extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)  # added
        else:                                                                        # added
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)   # added
    else:
        if self.config.is_mimo:                                                      # added
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)   # added
            encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 2, 1, 3)  # added
            encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(3)         # added
        else:                                                                        # added
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)   # added

    else:
        encoder_extended_attention_mask = None

Summary of the change: when `config.is_mimo` is set, the inverted encoder
attention mask is additionally reshaped (unsqueeze/repeat when the mask was
auto-created as all-ones; permute/unsqueeze when a caller-supplied mask is
used) — presumably to add a MIMO series dimension; otherwise the standard
`invert_attention_mask` result is used unchanged. The added-line count (12)
and removed-line count (6) match the page's "+12 -6" summary.