import abc
import logging
from typing import Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from adapters import AutoAdapterModel
from pie_modules.models import SequencePairSimilarityModelWithPooler
from pie_modules.models.components.pooler import MENTION_POOLING
from pie_modules.models.sequence_classification_with_pooler import (
    InputType,
    OutputType,
    SequenceClassificationModelWithPooler,
    SequenceClassificationModelWithPoolerBase,
    TargetType,
    separate_arguments_by_prefix,
)
from pytorch_ie import PyTorchIEModel
from torch import FloatTensor, Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from src.models.components.pooler import SpanMeanPooler

logger = logging.getLogger(__name__)


class SequenceClassificationModelWithPoolerBase2(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    """Extends the base model with mean aggregation for mention pooling.

    If the pooler config requests mention pooling with aggregate="mean", a
    SpanMeanPooler is set up; all other configurations are delegated to the
    base class.
    """

    def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
        aggregate = self.pooler_config.get("aggregate", "max")
        if self.pooler_config["type"] == MENTION_POOLING and aggregate != "max":
            if aggregate == "mean":
                pooler_config = dict(self.pooler_config)
                pooler_config.pop("type")
                pooler_config.pop("aggregate")
                pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
                return pooler, pooler.output_dim
            else:
                raise ValueError(f"Unknown aggregation method: {aggregate}")
        else:
            return super().setup_pooler(input_dim)


class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase2, abc.ABC
):
    """Optionally loads an adapter on top of the base transformer model."""

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        if self.adapter_name_or_path is None:
            return super().setup_base_model()
        else:
            config = AutoConfig.from_pretrained(self.model_name_or_path)
            if self.is_from_pretrained:
                model = AutoAdapterModel.from_config(config=config)
            else:
                model = AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
            # Load the adapter in any case: it appears to be neither saved in the
            # model state nor restored from a serialized state.
            logger.info(f"load adapter: {self.adapter_name_or_path}")
            model.load_adapter(self.adapter_name_or_path, source="hf", set_active=True)
            return model


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPooler2(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerBase2
):
    pass


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


@PyTorchIEModel.register()
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    pass


def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the maximum pairwise cosine similarity between two embedding sets."""
    # Normalize the embeddings
    embeddings_normalized = F.normalize(embeddings, p=2, dim=1)  # shape: (n, k)
    embeddings_normalized_pair = F.normalize(embeddings_pair, p=2, dim=1)  # shape: (m, k)

    # Compute the cosine similarity matrix
    cosine_sim = torch.mm(embeddings_normalized, embeddings_normalized_pair.T)  # shape: (n, m)

    # Take the overall maximum; this returns a scalar (0-dim tensor)
    max_cosine_sim = torch.max(cosine_sim)
    return max_cosine_sim
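# Illustrative sanity check for get_max_cosine_sim (a sketch, not part of the
# model code; the shapes are arbitrary): two token-embedding matrices of shape
# (n, k) and (m, k) reduce to a single scalar cosine similarity.
#
#     a = torch.randn(3, 8)
#     b = torch.randn(5, 8)
#     sim = get_max_cosine_sim(a, b)
#     assert sim.shape == torch.Size([])
#     assert -1.0 - 1e-6 <= sim.item() <= 1.0 + 1e-6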
def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """Extract the token embeddings of the first span per batch element.

    Note that only the first start/end index of each element is used, i.e. one
    span per sequence.
    """
    result = []
    for embeds, starts, ends in zip(embeddings, start_indices, end_indices):
        span_embeds = embeds[starts[0] : ends[0]]
        result.append(span_embeds)
    return result


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    """Scores a span pair by the maximum cosine similarity between the token
    embeddings of the two spans, instead of pooling them into fixed vectors."""

    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        output = self.model(**model_inputs)
        hidden_state = output.last_hidden_state
        # The pooler (and dropout) of the base class is bypassed on purpose:
        # pooled_output = self.pooler(hidden_state, **pooler_inputs)
        # pooled_output = self.dropout(pooled_output)
        span_embeds = get_span_embeddings(hidden_state, **pooler_inputs)
        return span_embeds

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        sanitized_inputs = separate_arguments_by_prefix(
            # Note: the order of the prefixes matters because one is a prefix of
            # the other, so the longer one needs to come first!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )

        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        logits_list = [
            get_max_cosine_sim(span_embeds, span_embeds_pair)
            for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
        ]
        logits = torch.stack(logits_list)

        result = {"logits": logits}
        if targets is not None:
            labels = targets["scores"]
            loss = self.loss_fct(logits, labels)
            result["loss"] = loss

        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        return SequenceClassifierOutput(**result)


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    pass
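# Hedged usage sketch (assumptions flagged): the adapter-enabled variants forward
# **kwargs to the pie_modules base classes, which are assumed here to also accept
# `model_name_or_path` (the attribute is read in setup_base_model above). Both
# identifier strings below are placeholders, not tested checkpoints.
#
#     model = SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
#         model_name_or_path="bert-base-cased",          # assumed base-class kwarg
#         adapter_name_or_path="some-org/some-adapter",  # placeholder adapter id
#     )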