|
import copy |
|
from itertools import chain |
|
from typing import Dict, Optional, Sequence, Type |
|
|
|
import torch |
|
from pie_modules.annotations import BinaryCorefRelation |
|
from pie_modules.document.processing.text_pair import shift_span |
|
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations |
|
from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule |
|
from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter |
|
from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory |
|
from pie_modules.taskmodules.re_text_classification_with_indices import ( |
|
ModelTargetType as REModelTargetType, |
|
) |
|
from pie_modules.taskmodules.re_text_classification_with_indices import ( |
|
TaskOutputType as RETaskOutputType, |
|
) |
|
from pytorch_ie import Document, TaskModule |
|
from pytorch_ie.annotations import LabeledSpan |
|
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations |
|
|
|
|
|
class SharpBracketMarkerFactory(MarkerFactory):
    """Marker factory whose markers are delimited by ``<`` and ``>``."""

    def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
        """Build the start or end marker for an argument role.

        Args:
            role: Argument role; translated with ``self._get_role_marker``.
            is_start: ``True`` for a start marker. For ``False`` a ``/`` is
                placed directly after the opening bracket.
            label: Optional label, appended after a ``:`` separator.

        Returns:
            A string of the form ``<{role_marker}>``, ``</{role_marker}>``,
            ``<{role_marker}:{label}>`` or ``</{role_marker}:{label}>``.
        """
        closing_slash = "" if is_start else "/"
        label_suffix = "" if label is None else f":{label}"
        return f"<{closing_slash}{self._get_role_marker(role)}{label_suffix}>"

    def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
        """Build the append marker for an argument role.

        Args:
            role: Argument role; translated with ``self._get_role_marker``.
            label: Optional label, appended after a ``=`` separator.

        Returns:
            ``<{role_marker}={label}>`` if a label is given, otherwise
            ``<{role_marker}>``.
        """
        role_marker = self._get_role_marker(role)
        if label is not None:
            return f"<{role_marker}={label}>"
        return f"<{role_marker}>"
|
|
|
|
|
@TaskModule.register()
class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
    RETextClassificationWithIndicesTaskModule
):
    """RE text classification task module with optional sharp-bracket markers.

    Args:
        use_sharp_marker: If ``True``, ``get_marker_factory`` returns a
            ``SharpBracketMarkerFactory``; otherwise a plain ``MarkerFactory``.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, use_sharp_marker: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.use_sharp_marker = use_sharp_marker

    def get_marker_factory(self) -> MarkerFactory:
        """Return the marker factory selected by ``self.use_sharp_marker``.

        The factory is created with ``self.argument_role_to_marker`` as its
        ``role_to_marker`` argument.
        """
        factory_cls = SharpBracketMarkerFactory if self.use_sharp_marker else MarkerFactory
        return factory_cls(role_to_marker=self.argument_role_to_marker)
|
|
|
|
|
def construct_text_document_from_text_pair_coref_document(
    document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
    glue_text: str,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
    add_span_mapping_to_metadata: bool = False,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    """Convert a text-pair coref document into a single-text relation document.

    If ``document.text`` equals ``document.text_pair``, the text is used once
    and copies of the spans from both span layers are merged, so that equal
    copies resolve to one shared span. Otherwise the new text is
    ``text + glue_text + text_pair`` and the copies of the spans from
    ``labeled_spans_pair`` are shifted by ``len(text) + len(glue_text)``.

    Args:
        document: The text-pair document to convert.
        glue_text: Text inserted between ``text`` and ``text_pair`` when they
            differ.
        no_relation_label: Label given to relations whose score is not
            greater than ``0.0``.
        relation_label_mapping: Optional mapping applied to each relation
            label; labels missing from the mapping are kept as they are.
        add_span_mapping_to_metadata: If ``True``, the original-to-new span
            mapping is stored under ``metadata["span_mapping"]`` of the result.

    Returns:
        The new document. Its spans are sorted by ``(start, end, label)`` and
        every relation carries a score of ``1.0``.
    """
    # NOTE(review): spans serve as dict keys here, so spans that compare equal
    # share a single entry -- confirm this is intended for the differing-text case.
    span_mapping: Dict[LabeledSpan, LabeledSpan] = {}

    if document.text == document.text_pair:
        merged_text = document.text
        # Canonical instance for each distinct copied span.
        canonical_spans: Dict[LabeledSpan, LabeledSpan] = {}
        for source_span in chain(document.labeled_spans, document.labeled_spans_pair):
            candidate = source_span.copy()
            span_mapping[source_span] = canonical_spans.setdefault(candidate, candidate)
    else:
        merged_text = document.text + glue_text + document.text_pair
        pair_offset = len(document.text) + len(glue_text)
        for source_span in document.labeled_spans:
            span_mapping[source_span] = source_span.copy()
        for source_span in document.labeled_spans_pair:
            span_mapping[source_span] = shift_span(source_span.copy(), pair_offset)

    converted = TextDocumentWithLabeledSpansAndBinaryRelations(
        text=merged_text,
        id=document.id,
        metadata=copy.deepcopy(document.metadata),
    )

    # Sort the spans so that their order in the new document is deterministic.
    ordered_spans = list(span_mapping.values())
    ordered_spans.sort(key=lambda span: (span.start, span.end, span.label))
    converted.labeled_spans.extend(ordered_spans)

    for coref_rel in document.binary_coref_relations:
        if coref_rel.score > 0.0:
            new_label = coref_rel.label
        else:
            new_label = no_relation_label
        if relation_label_mapping is not None:
            new_label = relation_label_mapping.get(new_label, new_label)
        converted.binary_relations.append(
            coref_rel.copy(
                head=span_mapping[coref_rel.head],
                tail=span_mapping[coref_rel.tail],
                label=new_label,
                score=1.0,
            )
        )

    if add_span_mapping_to_metadata:
        converted.metadata["span_mapping"] = span_mapping
    return converted
|
|
|
|
|
@TaskModule.register()
class CrossTextBinaryCorefByRETextClassificationTaskModule(
    TaskModuleWithDocumentConverter,
    RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
):
    """Task module for binary coreference over text pairs, framed as RE classification.

    Text-pair documents are converted to single-text documents with binary
    relations via ``construct_text_document_from_text_pair_coref_document``.
    Predicted relations are mapped back onto the original document as
    ``BinaryCorefRelation`` predictions.

    Args:
        coref_relation_label: Relation label used for coreference in the
            converted documents.
        relation_annotation: Must be ``"binary_relations"``.
        probability_threshold: Minimum score a predicted relation needs in
            order to be added to the original document.
        **kwargs: Forwarded to the parent constructors.

    Raises:
        ValueError: If ``relation_annotation`` is not ``"binary_relations"``.
    """

    def __init__(
        self,
        coref_relation_label: str,
        relation_annotation: str = "binary_relations",
        probability_threshold: float = 0.0,
        **kwargs,
    ):
        if relation_annotation != "binary_relations":
            message = (
                f"{type(self).__name__} requires relation_annotation='binary_relations', "
                f"but it is: {relation_annotation}"
            )
            raise ValueError(message)
        super().__init__(relation_annotation=relation_annotation, **kwargs)
        self.coref_relation_label = coref_relation_label
        self.probability_threshold = probability_threshold

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """The document type returned here is the text-pair coref document class."""
        return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def _get_glue_text(self) -> str:
        """Decode the glue token ids with the tokenizer and return the text."""
        return self.tokenizer.decode(self._get_glue_token_ids())

    def _convert_document(
        self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
    ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
        """Convert a text-pair document into a single-text relation document.

        The label ``"coref"`` is mapped to ``self.coref_relation_label`` and
        the span mapping is stored in the metadata of the converted document.
        """
        return construct_text_document_from_text_pair_coref_document(
            document,
            glue_text=self._get_glue_text(),
            no_relation_label=self.none_label,
            relation_label_mapping={"coref": self.coref_relation_label},
            add_span_mapping_to_metadata=True,
        )

    def _integrate_predictions_from_converted_document(
        self,
        document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
        converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
    ) -> None:
        """Copy predicted relations from the converted to the original document.

        Relation arguments are translated back with the inverted
        ``metadata["span_mapping"]`` of the converted document. Only
        predictions whose score is at least ``self.probability_threshold``
        are added, each with the label ``"coref"``.

        Raises:
            ValueError: If a predicted relation has a label other than
                ``self.coref_relation_label``.
        """
        span_mapping = converted_document.metadata["span_mapping"]
        converted2original_span = {
            converted_span: original_span
            for original_span, converted_span in span_mapping.items()
        }

        for predicted in converted_document.binary_relations.predictions:
            # Resolve both arguments first, before the label is validated.
            head_in_original = converted2original_span[predicted.head]
            tail_in_original = converted2original_span[predicted.tail]
            if predicted.label != self.coref_relation_label:
                raise ValueError(f"unexpected label: {predicted.label}")
            passes_threshold = predicted.score >= self.probability_threshold
            if passes_threshold:
                document.binary_coref_relations.predictions.append(
                    BinaryCorefRelation(
                        head=head_in_original,
                        tail=tail_in_original,
                        label="coref",
                        score=predicted.score,
                    )
                )

    def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
        """Unbatch the model output with every label set to the coref label id.

        A shallow copy of ``model_output`` is made and its ``"labels"`` entry
        is replaced by a tensor of the same shape filled with the id of
        ``self.coref_relation_label``. The copy is then passed to the parent
        implementation; the caller's mapping is not modified.
        """
        coref_label_id = self.label_to_id[self.coref_relation_label]

        patched_output = copy.copy(model_output)
        all_ones = torch.ones_like(patched_output["labels"])
        patched_output["labels"] = all_ones * coref_label_id
        return super().unbatch_output(model_output=patched_output)
|
|