Upload model
modeling_t5mimo.py  (+18 -6)
@@ -314,6 +314,7 @@ class T5Attention(nn.Module):
         # Input is (batch_size, seq_length, dim)
         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+
         if self.config.is_mimo:
             batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
         else:
@@ -402,7 +403,6 @@ class T5Attention(nn.Module):
 
 
 
-
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 if self.config.is_mimo:
@@ -414,6 +414,7 @@ class T5Attention(nn.Module):
             else:
                 position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
 
+
             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
@@ -427,6 +428,7 @@ class T5Attention(nn.Module):
 
 
 
+
         if self.pruned_heads:
             mask = torch.ones(position_bias.shape[1])
             mask[list(self.pruned_heads)] = 0
@@ -434,6 +436,7 @@ class T5Attention(nn.Module):
         else:
             position_bias_masked = position_bias
 
+
         scores += position_bias_masked
         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (batch_size, n_heads, seq_length, key_length)
         attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)  # (batch_size, n_heads, seq_length, key_length)
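A note on the shape bookkeeping in the T5Attention hunks above: the is_mimo branch reads the batch, variate, and sequence dimensions off a 4-D hidden state, while the standard path keeps T5's usual 3-D layout. A minimal sketch, assuming a hypothetical (batch_size, multivar_dim, seq_length, d_model) MIMO hidden state (the sizes below are made up):

    import torch

    B, M, L, D = 2, 3, 5, 16  # hypothetical batch, variates, sequence length, model dim

    # MIMO path: hidden_states carries an extra per-variate axis, so the first three
    # dims unpack as batch, variates, sequence length (as in the diff hunk).
    hidden_states_mimo = torch.randn(B, M, L, D)
    batch_size, multivar_dim, seq_length = hidden_states_mimo.shape[:3]
    assert (batch_size, multivar_dim, seq_length) == (B, M, L)

    # Standard path: 3-D (batch_size, seq_length, d_model), per the comment in the hunk.
    hidden_states_std = torch.randn(B, L, D)
    batch_size, seq_length = hidden_states_std.shape[:2]
    assert (batch_size, seq_length) == (B, L)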
@@ -909,20 +912,24 @@ class T5Stack(T5PreTrainedModel):
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
-
         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+            if self.config.is_mimo:
+                attention_mask = torch.ones(batch_size,multivar_seqs, mask_seq_length, device=inputs_embeds.device)
+            else:
+                attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
 
 
 
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
+
         if self.config.is_mimo:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, (input_shape[0], input_shape[2]))
-            extended_attention_mask = extended_attention_mask.
+            extended_attention_mask = extended_attention_mask.transpose(1,2).unsqueeze(2)
         else:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
 
+
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.is_decoder and encoder_hidden_states is not None:
@@ -934,11 +941,16 @@ class T5Stack(T5PreTrainedModel):
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
+
             if self.config.is_mimo:
-
-                encoder_extended_attention_mask =
+
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask).transpose(1,2)
+
+                encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(2)
+
             else:
                 encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+
 
         else:
             encoder_extended_attention_mask = None
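The substantive change in the T5Stack hunk is the MIMO handling of the self-attention mask: a per-variate padding mask is built, and the extended mask is reshaped with transpose(1,2).unsqueeze(2). A shape sketch of what that reshaping does, assuming (not taken from the repo) that the mask is 3-D (batch, variates, seq), that get_extended_attention_mask yields the usual additive (batch, 1, variates, seq) bias for a 3-D mask (emulated below with plain tensor ops), and that MIMO attention scores carry a per-variate axis (batch, variates, heads, seq, key_length):

    import torch

    B, M, L, H = 2, 3, 5, 4  # hypothetical batch, variates, sequence length, heads

    # Per-variate padding mask, analogous to torch.ones(batch_size, multivar_seqs, mask_seq_length).
    attention_mask = torch.ones(B, M, L)

    # Emulate the extended mask: broadcastable additive bias, 0 for keep, large negative for masked.
    extended = attention_mask[:, None, :, :]                      # (B, 1, M, L)
    extended = (1.0 - extended) * torch.finfo(torch.float32).min

    # The two ops added in the diff move the variate axis next to batch and insert
    # singleton head/query axes: (B, 1, M, L) -> (B, M, 1, L) -> (B, M, 1, 1, L).
    extended = extended.transpose(1, 2).unsqueeze(2)
    print(extended.shape)             # torch.Size([2, 3, 1, 1, 5])

    # Broadcasts against assumed per-variate attention scores (B, M, H, L, L).
    scores = torch.randn(B, M, H, L, L)
    print((scores + extended).shape)  # torch.Size([2, 3, 4, 5, 5])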
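The cross-attention branch follows the same pattern with invert_attention_mask. A companion sketch, assuming a 2-D encoder padding mask (as encoder_hidden_shape suggests) and per-variate cross-attention scores (batch, variates, heads, tgt_len, src_len); invert_attention_mask is again emulated rather than called, and its exact output for this model is an assumption:

    import torch

    B, M, H, T, S = 2, 3, 4, 5, 7  # hypothetical batch, variates, heads, target len, source len

    # 2-D encoder padding mask, as created from encoder_hidden_shape in the hunk.
    encoder_attention_mask = torch.ones(B, S)

    # Emulated invert_attention_mask: additive bias, 0 for keep, large negative for masked.
    inverted = encoder_attention_mask[:, None, None, :]                 # (B, 1, 1, S)
    inverted = (1.0 - inverted) * torch.finfo(torch.float32).min

    # transpose(1, 2) swaps two singleton axes here; unsqueeze(2) adds the extra axis
    # needed to broadcast over the per-variate score tensor: (B, 1, 1, S) -> (B, 1, 1, 1, S).
    cross_mask = inverted.transpose(1, 2).unsqueeze(2)
    print(cross_mask.shape)             # torch.Size([2, 1, 1, 1, 7])

    scores = torch.randn(B, M, H, T, S)
    print((scores + cross_mask).shape)  # torch.Size([2, 3, 4, 5, 7])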