ammarnasr committed on
Commit 2cb54a7 · verified · 1 Parent(s): 4e58620

Upload model

Files changed (1)
  1. modeling_t5mimo.py +13 -17
modeling_t5mimo.py CHANGED
@@ -277,7 +277,7 @@ class T5Attention(nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length, device=None, multivar_dim=None):
         """Compute binned relative position bias"""
         if device is None:
             device = self.relative_attention_bias.weight.device
@@ -293,7 +293,10 @@ class T5Attention(nn.Module):
         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         if self.config.is_mimo:
+            if multivar_dim == None:
+                raise ValueError(f"multivar_dim can not be None when config.is_mimo=True")
             values = values.unsqueeze(0)  # shape (1, 1, num_heads, query_length, key_length)
+            values = values.repeat(1, multivar_dim, 1, 1, 1)  # shape (1, multivar_dim, num_heads, query_length, key_length)
         return values

     def forward(
@@ -319,6 +322,7 @@ class T5Attention(nn.Module):
             batch_size, multivar_dim, seq_length = hidden_states.shape[:3]
         else:
             batch_size, seq_length = hidden_states.shape[:2]
+            multivar_dim = None
         real_seq_length = seq_length

         if past_key_value is not None:
@@ -406,13 +410,13 @@ class T5Attention(nn.Module):
         if position_bias is None:
             if not self.has_relative_attention_bias:
                 if self.config.is_mimo:
-                    position_bias = torch.zeros((1, 1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
+                    position_bias = torch.zeros((1, multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 else:
                     position_bias = torch.zeros((1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype)
                 if self.gradient_checkpointing and self.training:
                     position_bias.requires_grad = True
             else:
-                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+                position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device, multivar_dim=multivar_dim)


             # if key and values are already calculated
@@ -924,8 +928,9 @@ class T5Stack(T5PreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.

         if self.config.is_mimo:
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, (input_shape[0], input_shape[2]))
-            extended_attention_mask = extended_attention_mask.transpose(1, 2).unsqueeze(2)
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask[:, 0, :], (input_shape[0], input_shape[2]))
+            extended_attention_mask = extended_attention_mask.unsqueeze(0)
+            extended_attention_mask = extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
         else:
             extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

@@ -943,11 +948,9 @@ class T5Stack(T5PreTrainedModel):
             encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long)

         if self.config.is_mimo:
-
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask).transpose(1, 2)
-
-            encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(2)
-
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(0)
+            encoder_extended_attention_mask = encoder_extended_attention_mask.repeat(1, input_shape[1], 1, 1, 1)
         else:
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

@@ -1488,13 +1491,6 @@ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
             if decoder_attention_mask is not None:
                 decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

-        if hidden_states is not None and decoder_input_ids is not None:
-            if len(hidden_states.shape) == 4:
-                batch_size, multivar_seqs, seq_length, model_dim = hidden_states.shape
-                if len(decoder_input_ids.shape) == 2:
-                    decoder_input_ids = decoder_input_ids.unsqueeze(1).repeat(1, multivar_seqs, 1)
-
-


         # Decode
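For readers following the shape bookkeeping, here is a minimal sketch of what the updated compute_bias path does when config.is_mimo is true: the usual (1, num_heads, query_length, key_length) bias is lifted to five dimensions and tiled along a multivariate axis. The sizes below are arbitrary placeholders, not values taken from the model.

import torch

# Placeholder sizes for illustration only; num_heads and multivar_dim are assumptions.
num_heads, query_length, key_length, multivar_dim = 8, 16, 16, 3

# Stand-in for the bias produced by the relative_attention_bias embedding.
values = torch.randn(1, num_heads, query_length, key_length)  # (1, H, Q, K)

# The MIMO branch from the diff: add a series axis, then tile it multivar_dim times.
values = values.unsqueeze(0)                      # (1, 1, H, Q, K)
values = values.repeat(1, multivar_dim, 1, 1, 1)  # (1, M, H, Q, K)

assert values.shape == (1, multivar_dim, num_heads, query_length, key_length)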
 
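Relatedly, the no-bias branch now allocates a zero bias whose second axis is multivar_dim instead of a fixed singleton, and the T5Stack mask branches follow the same unsqueeze(0) + repeat pattern. Below is a hedged sketch of how such a five-dimensional bias broadcasts against attention scores, assuming the MIMO scores are laid out as (batch, multivar_dim, heads, query_len, key_len); the sizes are again placeholders.

import torch

# Assumed MIMO score layout and placeholder sizes; not taken from the repository.
batch_size, multivar_dim, n_heads, q_len, k_len = 2, 3, 8, 16, 16

scores = torch.randn(batch_size, multivar_dim, n_heads, q_len, k_len)

# No-bias case from the diff: one zero bias slice per series rather than a single
# broadcast slot of size 1.
position_bias = torch.zeros(1, multivar_dim, n_heads, q_len, k_len)

# Broadcasting over the batch dimension still works, series by series.
assert (scores + position_bias).shape == scores.shape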