import copy
from itertools import chain
from typing import Dict, Optional, Sequence, Type

import torch
from pie_modules.annotations import BinaryCorefRelation
from pie_modules.document.processing.text_pair import shift_span
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory
from pie_modules.taskmodules.re_text_classification_with_indices import (
    ModelTargetType as REModelTargetType,
)
from pie_modules.taskmodules.re_text_classification_with_indices import (
    TaskOutputType as RETaskOutputType,
)
from pytorch_ie import Document, TaskModule
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations


class SharpBracketMarkerFactory(MarkerFactory):
    """Marker factory that renders argument markers as angle-bracket tags.

    For a role whose marker is ``H`` and a label ``PER`` it produces:

    * start marker:  ``<H:PER>``  (``<H>`` when no label is given)
    * end marker:    ``</H:PER>`` (``</H>`` when no label is given)
    * append marker: ``<H=PER>``  (``<H>`` when no label is given)
    """

    def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
        # end markers get a leading slash, XML-style
        closing_slash = "" if is_start else "/"
        label_suffix = "" if label is None else f":{label}"
        return "<" + closing_slash + self._get_role_marker(role) + label_suffix + ">"

    def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
        role_marker = self._get_role_marker(role)
        label_suffix = "" if label is None else f"={label}"
        return f"<{role_marker}{label_suffix}>"


@TaskModule.register()
class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
    RETextClassificationWithIndicesTaskModule
):
    """RE text classification task module with optional angle-bracket argument markers.

    Args:
        use_sharp_marker: If True, markers are created by ``SharpBracketMarkerFactory``;
            otherwise the default ``MarkerFactory`` is used.
        **kwargs: Forwarded unchanged to ``RETextClassificationWithIndicesTaskModule``.
    """

    def __init__(self, use_sharp_marker: bool = False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): set after super().__init__(); this assumes the parent constructor
        # does not call get_marker_factory() itself — confirm.
        self.use_sharp_marker = use_sharp_marker

    def get_marker_factory(self) -> MarkerFactory:
        factory_cls = SharpBracketMarkerFactory if self.use_sharp_marker else MarkerFactory
        return factory_cls(role_to_marker=self.argument_role_to_marker)


def construct_text_document_from_text_pair_coref_document(
    document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
    glue_text: str,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
    add_span_mapping_to_metadata: bool = False,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    """Merge the two texts of a text-pair coref document into a single text document.

    If ``text`` and ``text_pair`` are identical, the new document uses that text once and
    the spans of both span layers are merged; value-equal span copies are kept only once.
    Otherwise the new text is ``text + glue_text + text_pair``, and spans of
    ``labeled_spans_pair`` are shifted by ``len(text) + len(glue_text)``.

    Every coref relation becomes a binary relation with score 1.0. Its label is the
    original label when the relation's score is > 0.0 and ``no_relation_label`` otherwise.
    That label is then looked up in ``relation_label_mapping``; labels missing from the
    mapping are kept.

    Args:
        document: The text-pair document to convert.
        glue_text: Text inserted between ``text`` and ``text_pair`` when they differ.
        no_relation_label: Label for relations whose score is not > 0.0.
        relation_label_mapping: Optional label renaming applied to every relation label.
        add_span_mapping_to_metadata: If True, the dict mapping original span -> new span
            is stored under ``metadata["span_mapping"]`` of the returned document.

    Returns:
        The new single-text document with sorted spans and converted relations.

    Raises:
        KeyError: If a relation's head or tail is not among the document's spans.

    NOTE(review): the span mapping is a dict keyed by annotation value. If LabeledSpan
    equality ignores which text a span is attached to, a ``text_pair`` span with the same
    start/end/label as a ``text`` span would overwrite that entry when the texts differ
    — confirm the annotation equality semantics.
    """
    span_mapping: Dict[LabeledSpan, LabeledSpan] = {}
    if document.text == document.text_pair:
        merged_text = document.text
        # Detached copies may compare equal even if the originals came from different
        # layers; reuse the first such copy so each distinct span is added only once.
        canonical_spans: Dict[LabeledSpan, LabeledSpan] = {}
        for source_span in chain(document.labeled_spans, document.labeled_spans_pair):
            detached = source_span.copy()
            span_mapping[source_span] = canonical_spans.setdefault(detached, detached)
    else:
        merged_text = document.text + glue_text + document.text_pair
        # spans of the second text now start behind the first text and the glue text
        pair_offset = len(document.text) + len(glue_text)
        for source_span in document.labeled_spans:
            span_mapping[source_span] = source_span.copy()
        for source_span in document.labeled_spans_pair:
            span_mapping[source_span] = shift_span(source_span.copy(), pair_offset)

    new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
        text=merged_text,
        id=document.id,
        metadata=copy.deepcopy(document.metadata),
    )

    # sort to make the span order deterministic
    new_doc.labeled_spans.extend(
        sorted(span_mapping.values(), key=lambda span: (span.start, span.end, span.label))
    )

    for coref_rel in document.binary_coref_relations:
        # relations whose score is not > 0.0 are treated as negatives
        if coref_rel.score > 0.0:
            label = coref_rel.label
        else:
            label = no_relation_label
        if relation_label_mapping is not None:
            label = relation_label_mapping.get(label, label)
        converted_rel = coref_rel.copy(
            head=span_mapping[coref_rel.head],
            tail=span_mapping[coref_rel.tail],
            label=label,
            score=1.0,
        )
        new_doc.binary_relations.append(converted_rel)

    if add_span_mapping_to_metadata:
        new_doc.metadata["span_mapping"] = span_mapping
    return new_doc


@TaskModule.register()
class CrossTextBinaryCorefByRETextClassificationTaskModule(
    TaskModuleWithDocumentConverter,
    RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
):
    """Cross-text binary coreference framed as relation classification.

    Each text-pair document is converted with
    ``construct_text_document_from_text_pair_coref_document``, using a glue text decoded
    from the tokenizer's glue token ids, into a single text document. The RE text
    classification task module then handles that document. Predicted relations are
    mapped back onto the original spans as ``BinaryCorefRelation`` predictions with
    label ``"coref"``.

    Args:
        coref_relation_label: Relation label used for coreference in the converted
            documents; the original ``"coref"`` label is mapped to it.
        relation_annotation: Must be ``"binary_relations"``.
        probability_threshold: Only predictions with a score >= this value are kept.
        **kwargs: Forwarded to the parent task modules.

    Raises:
        ValueError: If ``relation_annotation`` is not ``"binary_relations"``.
    """

    def __init__(
        self,
        coref_relation_label: str,
        relation_annotation: str = "binary_relations",
        probability_threshold: float = 0.0,
        **kwargs,
    ):
        # converted documents keep their relations in the "binary_relations" layer
        if relation_annotation != "binary_relations":
            raise ValueError(
                f"{type(self).__name__} requires relation_annotation='binary_relations', "
                f"but it is: {relation_annotation}"
            )
        super().__init__(relation_annotation=relation_annotation, **kwargs)
        self.coref_relation_label = coref_relation_label
        self.probability_threshold = probability_threshold

    @property
    def document_type(self) -> Optional[Type[Document]]:
        return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def _get_glue_text(self) -> str:
        """Return the tokenizer's decoding of the glue token ids."""
        glue_token_ids = self._get_glue_token_ids()
        return self.tokenizer.decode(glue_token_ids)

    def _convert_document(
        self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
    ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
        # the span mapping is required later to integrate predictions back
        return construct_text_document_from_text_pair_coref_document(
            document,
            glue_text=self._get_glue_text(),
            relation_label_mapping={"coref": self.coref_relation_label},
            no_relation_label=self.none_label,
            add_span_mapping_to_metadata=True,
        )

    def _integrate_predictions_from_converted_document(
        self,
        document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
        converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
    ) -> None:
        """Add thresholded relation predictions of the converted document as coref
        predictions to the original document.

        Raises:
            ValueError: If a predicted relation does not carry ``coref_relation_label``.
        """
        span_mapping = converted_document.metadata["span_mapping"]
        converted2original = {converted: original for original, converted in span_mapping.items()}
        for predicted in converted_document.binary_relations.predictions:
            head = converted2original[predicted.head]
            tail = converted2original[predicted.tail]
            if predicted.label != self.coref_relation_label:
                raise ValueError(f"unexpected label: {predicted.label}")
            if predicted.score >= self.probability_threshold:
                document.binary_coref_relations.predictions.append(
                    BinaryCorefRelation(head=head, tail=tail, label="coref", score=predicted.score)
                )

    def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
        """Unbatch the model output as if every entry had been classified as coref.

        Only the coref class matters here, so the ``labels`` field is replaced by the coref
        label id (on a shallow copy, leaving the caller's dict untouched). Presumably the
        parent then reports the coref-class score for every candidate — TODO confirm
        against the parent's ``unbatch_output``.
        """
        coref_label_id = self.label_to_id[self.coref_relation_label]
        patched_output = copy.copy(model_output)
        patched_output["labels"] = torch.full_like(model_output["labels"], coref_label_id)
        return super().unbatch_output(model_output=patched_output)