# ScientificArgumentRecommender/src/taskmodules/re_text_classification_with_indices.py
# File-viewer residue: ArneBinder, commit 3133b5e (verified), 7.57 kB
# https://github.com/ArneBinder/pie-document-level/pull/312
import copy
from itertools import chain
from typing import Dict, Optional, Sequence, Type
import torch
from pie_modules.annotations import BinaryCorefRelation
from pie_modules.document.processing.text_pair import shift_span
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
from pie_modules.taskmodules.re_text_classification_with_indices import MarkerFactory
from pie_modules.taskmodules.re_text_classification_with_indices import (
ModelTargetType as REModelTargetType,
)
from pie_modules.taskmodules.re_text_classification_with_indices import (
TaskOutputType as RETaskOutputType,
)
from pytorch_ie import Document, TaskModule
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations
class SharpBracketMarkerFactory(MarkerFactory):
    """Marker factory that wraps argument markers in angle brackets.

    With ``ROLE`` being the value returned by ``self._get_role_marker(role)``:

    * inline markers are ``<ROLE>`` / ``</ROLE>``, or ``<ROLE:LABEL>`` /
      ``</ROLE:LABEL>`` when a label is given;
    * append markers are ``<ROLE>``, or ``<ROLE=LABEL>`` when a label is given.
    """

    def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
        """Build the inline marker for an argument boundary.

        Args:
            role: Argument role, translated via ``self._get_role_marker``.
            is_start: ``True`` for an opening marker; ``False`` adds a ``/``
                right after the opening bracket.
            label: Optional label, appended as ``:<label>`` when not ``None``.

        Returns:
            The marker string, e.g. ``<ROLE>`` or ``</ROLE:LABEL>``.
        """
        slash = "" if is_start else "/"
        role_marker = self._get_role_marker(role)
        label_suffix = "" if label is None else f":{label}"
        return f"<{slash}{role_marker}{label_suffix}>"

    def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
        """Build the append marker for a role and an optional label.

        Args:
            role: Argument role, translated via ``self._get_role_marker``.
            label: Optional label, appended as ``=<label>`` when not ``None``.

        Returns:
            ``<ROLE>`` without a label, ``<ROLE=LABEL>`` otherwise.
        """
        role_marker = self._get_role_marker(role)
        return f"<{role_marker}>" if label is None else f"<{role_marker}={label}>"
@TaskModule.register()
class RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers(
    RETextClassificationWithIndicesTaskModule
):
    """Relation classification task module with switchable marker style.

    Args:
        use_sharp_marker: If true, ``get_marker_factory`` returns a
            ``SharpBracketMarkerFactory``; otherwise a plain ``MarkerFactory``.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, use_sharp_marker: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.use_sharp_marker = use_sharp_marker

    def get_marker_factory(self) -> MarkerFactory:
        """Return the marker factory selected by ``self.use_sharp_marker``.

        Both variants are created with ``role_to_marker=self.argument_role_to_marker``.
        """
        factory_cls = SharpBracketMarkerFactory if self.use_sharp_marker else MarkerFactory
        return factory_cls(role_to_marker=self.argument_role_to_marker)
def construct_text_document_from_text_pair_coref_document(
    document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
    glue_text: str,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
    add_span_mapping_to_metadata: bool = False,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    """Merge a text-pair coref document into a single-text relation document.

    If ``document.text`` equals ``document.text_pair``, the text is used once
    and the spans of both layers are copied, with equal copies collapsed into
    a single span. Otherwise the new text is ``text + glue_text + text_pair``
    and the spans of ``labeled_spans_pair`` are shifted by
    ``len(text) + len(glue_text)``.

    Every entry of ``document.binary_coref_relations`` becomes a relation with
    score ``1.0``. Its label is the original label if the original score is
    greater than ``0.0``, else ``no_relation_label``; the result is then
    translated through ``relation_label_mapping`` if one is given.

    Args:
        document: The text-pair document to convert.
        glue_text: Text inserted between ``text`` and ``text_pair`` when they differ.
        no_relation_label: Label used for relations whose score is not above ``0.0``.
        relation_label_mapping: Optional label translation; unmapped labels are kept.
        add_span_mapping_to_metadata: If true, the mapping from original spans to
            new spans is stored in the result's metadata as ``"span_mapping"``.

    Returns:
        The new document with a deep copy of the metadata, the new spans in
        ``labeled_spans`` and the converted relations in ``binary_relations``.
    """
    texts_are_identical = document.text == document.text_pair
    if texts_are_identical:
        combined_text = document.text
    else:
        combined_text = document.text + glue_text + document.text_pair
    new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
        id=document.id,
        metadata=copy.deepcopy(document.metadata),
        text=combined_text,
    )

    span_map: Dict[LabeledSpan, LabeledSpan] = {}
    if texts_are_identical:
        # Copies of spans coming from the two layers may compare equal; keep
        # exactly one canonical copy per distinct span.
        canonical_spans: Dict[LabeledSpan, LabeledSpan] = {}
        for source_span in chain(document.labeled_spans, document.labeled_spans_pair):
            copied_span = source_span.copy()
            span_map[source_span] = canonical_spans.setdefault(copied_span, copied_span)
    else:
        # NOTE(review): if LabeledSpan equality ignores the layer a span belongs
        # to, a span of text_pair with the same fields as a span of text would
        # share its dict key here and replace the first entry - confirm.
        for source_span in document.labeled_spans:
            span_map[source_span] = source_span.copy()
        pair_offset = len(document.text) + len(glue_text)
        for source_span in document.labeled_spans_pair:
            span_map[source_span] = shift_span(source_span.copy(), pair_offset)

    # sort to make order deterministic
    ordered_spans = list(span_map.values())
    ordered_spans.sort(key=lambda span: (span.start, span.end, span.label))
    new_doc.labeled_spans.extend(ordered_spans)

    for source_rel in document.binary_coref_relations:
        if source_rel.score > 0.0:
            rel_label = source_rel.label
        else:
            rel_label = no_relation_label
        if relation_label_mapping is not None:
            rel_label = relation_label_mapping.get(rel_label, rel_label)
        converted_rel = source_rel.copy(
            head=span_map[source_rel.head],
            tail=span_map[source_rel.tail],
            label=rel_label,
            score=1.0,
        )
        new_doc.binary_relations.append(converted_rel)

    if add_span_mapping_to_metadata:
        new_doc.metadata["span_mapping"] = span_map
    return new_doc
@TaskModule.register()
class CrossTextBinaryCorefByRETextClassificationTaskModule(
    TaskModuleWithDocumentConverter,
    RETextClassificationWithIndicesTaskModuleAndWithSharpBracketMarkers,
):
    """Binary coreference across a text pair, handled as relation classification.

    Input documents are text-pair documents with coref relations. They are
    converted to single-text documents with binary relations (see
    ``_convert_document``), and predictions made on the converted document are
    mapped back to the original one (see
    ``_integrate_predictions_from_converted_document``).

    Args:
        coref_relation_label: Relation label that stands for "coref" in the
            converted documents.
        relation_annotation: Must be ``"binary_relations"``.
        probability_threshold: Predictions with a score below this value are
            not added to the original document.
        **kwargs: Forwarded to the parent constructors.

    Raises:
        ValueError: If ``relation_annotation`` is not ``"binary_relations"``.
    """

    def __init__(
        self,
        coref_relation_label: str,
        relation_annotation: str = "binary_relations",
        probability_threshold: float = 0.0,
        **kwargs,
    ):
        # The document conversion always writes into "binary_relations".
        if relation_annotation != "binary_relations":
            raise ValueError(
                f"{type(self).__name__} requires relation_annotation='binary_relations', "
                f"but it is: {relation_annotation}"
            )
        super().__init__(relation_annotation=relation_annotation, **kwargs)
        self.probability_threshold = probability_threshold
        self.coref_relation_label = coref_relation_label

    @property
    def document_type(self) -> Optional[Type[Document]]:
        """The document type this task module takes as input."""
        return TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def _get_glue_text(self) -> str:
        """Decode the glue token ids into the text placed between the two texts."""
        return self.tokenizer.decode(self._get_glue_token_ids())

    def _convert_document(
        self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
    ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
        """Convert a text-pair coref document into a single-text relation document.

        The span mapping is stored in the converted document's metadata so that
        predictions can be mapped back later.
        """
        return construct_text_document_from_text_pair_coref_document(
            document,
            glue_text=self._get_glue_text(),
            no_relation_label=self.none_label,
            relation_label_mapping={"coref": self.coref_relation_label},
            add_span_mapping_to_metadata=True,
        )

    def _integrate_predictions_from_converted_document(
        self,
        document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
        converted_document: TextDocumentWithLabeledSpansAndBinaryRelations,
    ) -> None:
        """Copy predicted relations from the converted document to the original.

        Each predicted relation is mapped back to the original head and tail
        spans via the ``"span_mapping"`` metadata entry and, if its score
        reaches ``self.probability_threshold``, appended to the predictions of
        ``document.binary_coref_relations`` with the label ``"coref"``.

        Raises:
            ValueError: If a predicted relation has a label other than
                ``self.coref_relation_label``.
        """
        # Invert the original -> converted mapping stored during conversion.
        converted2original_span = {}
        for original_span, converted_span in converted_document.metadata["span_mapping"].items():
            converted2original_span[converted_span] = original_span

        for predicted_rel in converted_document.binary_relations.predictions:
            head_in_original = converted2original_span[predicted_rel.head]
            tail_in_original = converted2original_span[predicted_rel.tail]
            if predicted_rel.label != self.coref_relation_label:
                raise ValueError(f"unexpected label: {predicted_rel.label}")
            if not (predicted_rel.score >= self.probability_threshold):
                continue
            document.binary_coref_relations.predictions.append(
                BinaryCorefRelation(
                    head=head_in_original,
                    tail=tail_in_original,
                    label="coref",
                    score=predicted_rel.score,
                )
            )

    def unbatch_output(self, model_output: REModelTargetType) -> Sequence[RETaskOutputType]:
        """Unbatch the model output with every label forced to the coref class.

        Works on a shallow copy of ``model_output``; the caller's mapping keeps
        its original ``"labels"`` entry.
        """
        coref_relation_idx = self.label_to_id[self.coref_relation_label]
        # we are just concerned with the coref class, so we overwrite the labels field
        # (presumably so the parent reports the score of that class - confirm)
        patched_output = copy.copy(model_output)
        patched_output["labels"] = torch.ones_like(patched_output["labels"]) * coref_relation_idx
        return super().unbatch_output(model_output=patched_output)