import torch
from torch import Tensor, cat, nn


class SpanMeanPooler(nn.Module):
    """Pooler that takes the mean hidden state over spans. If the start or end index is
    negative, a learned embedding is used instead. The indices are expected to have the
    shape [batch_size, num_indices]. The resulting embeddings are concatenated, so the
    output shape is [batch_size, num_indices * input_dim].

    Note that this is a slightly modified version of
    pie_modules.models.components.pooler.SpanMaxPooler, i.e. we changed the aggregation
    method from torch.amax to torch.mean.

    Args:
        input_dim: The input dimension of the hidden state.
        num_indices: The number of indices to pool.

    Returns:
        The pooled hidden states with shape [batch_size, num_indices * input_dim].
    """

    def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_indices = num_indices
        # learned fallback embeddings, one per span, used when a span is missing
        self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
        nn.init.normal_(self.missing_embeddings)

    def forward(
        self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
    ) -> Tensor:
        batch_size, seq_len, hidden_size = hidden_state.shape
        if start_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of start indices [{start_indices.shape[1]}] has to be the same "
                f"as num_indices [{self.num_indices}]"
            )
        if end_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of end indices [{end_indices.shape[1]}] has to be the same "
                f"as num_indices [{self.num_indices}]"
            )

        # check that start_indices are before end_indices; spans with a negative
        # start or end index are treated as missing and are exempt from the check
        mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
        mask_start_before_end = start_indices < end_indices
        mask_valid = mask_start_before_end | ~mask_both_positive
        if not torch.all(mask_valid):
            raise ValueError(
                f"values in start_indices have to be smaller than respective values in "
                f"end_indices, but start_indices=\n{start_indices}\n and "
                f"end_indices=\n{end_indices}"
            )

        # times num_indices due to concat
        result = torch.zeros(
            batch_size, hidden_size * self.num_indices, device=hidden_state.device
        )
        for batch_idx in range(batch_size):
            current_start_indices = start_indices[batch_idx]
            current_end_indices = end_indices[batch_idx]
            # mean-pool each valid span; substitute the learned embedding for missing spans
            current_embeddings = [
                (
                    torch.mean(
                        hidden_state[
                            batch_idx, current_start_indices[i] : current_end_indices[i], :
                        ],
                        dim=0,
                    )
                    if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
                    else self.missing_embeddings[i]
                )
                for i in range(self.num_indices)
            ]
            result[batch_idx] = cat(current_embeddings, 0)

        return result

    @property
    def output_dim(self) -> int:
        return self.input_dim * self.num_indices
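

# A minimal usage sketch (not part of the original module; the shapes and index
# values below are illustrative assumptions): pool two spans per example from a
# batch of encoder hidden states, marking one span as missing via negative
# indices so the learned fallback embedding is used for it.
if __name__ == "__main__":
    torch.manual_seed(0)
    pooler = SpanMeanPooler(input_dim=8, num_indices=2)
    hidden_state = torch.randn(2, 10, 8)  # [batch_size, seq_len, input_dim]
    # second example's second span is missing -> learned embedding is used
    start_indices = torch.tensor([[1, 4], [0, -1]])
    end_indices = torch.tensor([[3, 7], [2, -1]])
    pooled = pooler(hidden_state, start_indices=start_indices, end_indices=end_indices)
    print(pooled.shape)  # torch.Size([2, 16]), i.e. [batch_size, num_indices * input_dim]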