# ScientificArgumentRecommender/src/taskmodules/cross_text_binary_coref.py
# https://github.com/ArneBinder/pie-document-level/pull/312
import logging
from typing import Optional, Sequence, TypeVar, Union

from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
from pie_modules.taskmodules.cross_text_binary_coref import (
    DocumentType,
    SpanDoesNotFitIntoAvailableWindow,
    TaskEncodingType,
)
from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException
from pytorch_ie.annotations import Span
from pytorch_ie.core import TaskEncoding, TaskModule

logger = logging.getLogger(__name__)

S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Return a copy of the span with its start and end moved by offset characters."""
    return span.copy(start=span.start + offset, end=span.end + offset)


@TaskModule.register()
class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
    """Same as CrossTextBinaryCorefTaskModule, but:

    - optionally encodes the coreference arguments without any surrounding context,
      i.e. only the argument span texts themselves are tokenized (without_context=True).
    """

    def __init__(
        self,
        without_context: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.without_context = without_context

    def encode_input(
        self,
        document: DocumentType,
        is_training: bool = False,
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        if self.without_context:
            return self.encode_input_without_context(document)
        else:
            return super().encode_input(document)

    def encode_input_without_context(
        self, document: DocumentType
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
        tokenizer_kwargs = dict(
            padding=False,
            truncation=False,
            add_special_tokens=False,
        )
        task_encodings = []
        for coref_rel in document.binary_coref_relations:
            # TODO: This can miss instances if both texts are the same. We could check that
            #  coref_rel.head is in document.labeled_spans (same for the tail), but would this
            #  slow down the encoding?
            if not (
                coref_rel.head.target == document.text
                or coref_rel.tail.target == document.text_pair
            ):
                raise ValueError(
                    f"It is expected that coref relations go from (head) spans over 'text' "
                    f"to (tail) spans over 'text_pair', but this is not the case for this "
                    f"relation (i.e. it points in the other direction): {coref_rel.resolve()}"
                )
            # Tokenize only the argument span texts, i.e. without any surrounding context.
            encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
            encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)
            try:
                # Since only the span text itself was tokenized, shift each char span so that
                # it starts at offset 0 before mapping it to token offsets.
                current_encoding, token_span = self.truncate_encoding_around_span(
                    encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start)
                )
                current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
                    encoding=encoding_pair,
                    char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
                )
            except SpanNotAlignedWithTokenException as e:
                logger.warning(
                    f"Could not get token offsets for argument ({e.span}) of coref relation: "
                    f"{coref_rel.resolve()}. Skip it."
                )
                self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
                continue
            except SpanDoesNotFitIntoAvailableWindow as e:
                logger.warning(
                    f"Argument span [{e.span}] does not fit into available token window "
                    f"({self.available_window}). Skip it."
                )
                self.collect_relation(
                    kind="skipped_span_does_not_fit_into_window", relation=coref_rel
                )
                continue
            task_encodings.append(
                TaskEncoding(
                    document=document,
                    inputs={
                        "encoding": current_encoding,
                        "encoding_pair": current_encoding_pair,
                        "pooler_start_indices": token_span.start,
                        "pooler_end_indices": token_span.end,
                        "pooler_pair_start_indices": token_span_pair.start,
                        "pooler_pair_end_indices": token_span_pair.end,
                    },
                    metadata={"candidate_annotation": coref_rel},
                )
            )
            self.collect_relation("used", coref_rel)
        return task_encodings
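

# Minimal usage sketch (illustrative only, not part of this module). It assumes the
# base CrossTextBinaryCorefTaskModule accepts a `tokenizer_name_or_path` keyword and
# that `documents` are text-pair documents carrying `binary_coref_relations`, as used
# above; verify the exact constructor and encode() signatures against the installed
# pie_modules / pytorch_ie versions.
#
#     taskmodule = CrossTextBinaryCorefTaskModuleWithOptionalContext(
#         tokenizer_name_or_path="bert-base-uncased",  # assumed parameter name
#         without_context=True,  # encode only the argument spans, no surrounding context
#     )
#     task_encodings = taskmodule.encode(documents)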