Upload model
modeling_t5mimo.py CHANGED (+13 -17)
@@ -277,7 +277,7 @@ class T5Attention(nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets
 
-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length, device=None, multivar_dim=None):
         """Compute binned relative position bias"""
         if device is None:
             device = self.relative_attention_bias.weight.device
@@ -293,7 +293,10 @@ class T5Attention(nn.Module):
         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         if self.config.is_mimo:
+            if multivar_dim == None:
+                raise ValueError(f"multivar_dim can not be None when config.is_mimo=True")
             values = values.unsqueeze(0)  # shape (1, 1, num_heads, query_length, key_length)
+            values = values.repeat(1, multivar_dim, 1, 1, 1)  # shape (1, multivar_dim, num_heads, query_length, key_length)
         return values
 
     def forward(
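A minimal, self-contained shape trace of the new MIMO branch in compute_bias, with illustrative sizes and a random tensor standing in for the learned relative_attention_bias embedding:

    import torch

    num_heads, q_len, k_len, multivar_dim = 8, 16, 16, 3
    values = torch.randn(q_len, k_len, num_heads)     # stand-in for the embedding output
    values = values.permute([2, 0, 1]).unsqueeze(0)   # (1, num_heads, q_len, k_len)
    values = values.unsqueeze(0)                      # (1, 1, num_heads, q_len, k_len)
    values = values.repeat(1, multivar_dim, 1, 1, 1)  # (1, multivar_dim, num_heads, q_len, k_len)
    assert values.shape == (1, multivar_dim, num_heads, q_len, k_len)

Each of the multivar_dim series ends up sharing the same per-head bias, replicated along the new second axis.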
@@ -319,6 +322,7 @@ class T5Attention(nn.Module):
             batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
         else:
             batch_size, seq_length = hidden_states.shape[:2]
+            multivar_dim=None
         real_seq_length = seq_length
 
         if past_key_value is not None:
@@ -406,13 +410,13 @@ class T5Attention(nn.Module):
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 if self.config.is_mimo:
-                    position_bias = torch.zeros((1,
+                    position_bias = torch.zeros((1,multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 else:
                     position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
-                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device, multivar_dim=multivar_dim)
 
 
         # if key and values are already calculated
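The diff does not show how the attention scores themselves are shaped in the MIMO case; assuming they follow the same (batch, multivar_dim, n_heads, q_len, k_len) layout as the bias, a rough broadcasting check looks like this:

    import torch

    batch, multivar_dim, n_heads, q_len, k_len = 2, 3, 8, 16, 16
    scores = torch.randn(batch, multivar_dim, n_heads, q_len, k_len)     # assumed MIMO score layout
    position_bias = torch.zeros(1, multivar_dim, n_heads, q_len, k_len)  # matches the 5-D bias above
    scores = scores + position_bias                                      # broadcasts over dim 0 (batch)
    assert scores.shape == (batch, multivar_dim, n_heads, q_len, k_len)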
@@ -924,8 +928,9 @@ class T5Stack(T5PreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
 
         if self.config.is_mimo:
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, (input_shape[0], input_shape[2]))
-            extended_attention_mask = extended_attention_mask.
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask[:,0,:], (input_shape[0], input_shape[2]))
+            extended_attention_mask = extended_attention_mask.unsqueeze(0)
+            extended_attention_mask = extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
         else:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
 
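A quick shape trace of the three MIMO mask lines above, with made-up sizes and torch.zeros standing in for the (batch, 1, 1, seq) additive mask that get_extended_attention_mask returns for a 2-D mask; the encoder-side mask in the next hunk follows the same unsqueeze/repeat pattern:

    import torch

    batch, multivar_dim, seq = 2, 3, 16                   # illustrative input_shape = (batch, multivar_dim, seq)
    extended = torch.zeros(batch, 1, 1, seq)              # stand-in for get_extended_attention_mask(mask[:, 0, :], ...)
    extended = extended.unsqueeze(0)                      # (1, batch, 1, 1, seq)
    extended = extended.repeat(1, multivar_dim, 1, 1, 1)  # (1, batch * multivar_dim, 1, 1, seq)
    print(extended.shape)                                 # torch.Size([1, 6, 1, 1, 16])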
@@ -943,11 +948,9 @@ class T5Stack(T5PreTrainedModel):
             encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)
 
         if self.config.is_mimo:
-
-            encoder_extended_attention_mask =
-
-            encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(2)
-
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(0)
+            encoder_extended_attention_mask = encoder_extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
         else:
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
 
@@ -1488,13 +1491,6 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
             if decoder_attention_mask is not None:
                 decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
 
-        if hidden_states is not None and decoder_input_ids is not None:
-            if len(hidden_states.shape) == 4:
-                batch_size, multivar_seqs, seq_length , model_dim = hidden_states.shape
-            if len(decoder_input_ids.shape) == 2:
-                decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, multivar_seqs, 1)
-
-
 
 
         # Decode
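For reference, the deleted block above expanded a 2-D decoder_input_ids across the multivariate axis of a 4-D hidden_states. A minimal trace of that removed expansion (illustrative sizes):

    import torch

    batch, multivar_seqs, seq_length = 2, 3, 10
    decoder_input_ids = torch.zeros(batch, seq_length, dtype=torch.long)
    decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, multivar_seqs, 1)
    assert decoder_input_ids.shape == (batch, multivar_seqs, seq_length)

With it gone, callers presumably have to supply decoder_input_ids already shaped for the multivariate case.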