|
import logging |
|
from typing import Optional, Sequence, TypeVar, Union |
|
|
|
from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule |
|
from pie_modules.taskmodules.cross_text_binary_coref import ( |
|
DocumentType, |
|
SpanDoesNotFitIntoAvailableWindow, |
|
TaskEncodingType, |
|
) |
|
from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException |
|
from pytorch_ie.annotations import Span |
|
from pytorch_ie.core import TaskEncoding, TaskModule |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
# Generic type variable bound to Span so that shift_span returns the same
# concrete Span subclass it was given.
S = TypeVar("S", bound=Span)
|
|
|
|
|
def shift_span(span: S, offset: int) -> S: |
|
return span.copy(start=span.start + offset, end=span.end + offset) |
|
|
|
|
|
@TaskModule.register()
class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
    """Same as CrossTextBinaryCorefTaskModule, but:

    - optionally without context: when ``without_context=True``, each coreference
      argument is tokenized in isolation (just the span text, no surrounding
      document text).
    """

    def __init__(
        self,
        without_context: bool = False,
        **kwargs,
    ) -> None:
        """
        Args:
            without_context: if True, encode only the argument span texts themselves
                instead of the spans within their surrounding documents.
            **kwargs: forwarded to CrossTextBinaryCorefTaskModule.
        """
        super().__init__(**kwargs)
        self.without_context = without_context

    def encode_input(
        self,
        document: DocumentType,
        is_training: bool = False,
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        # Dispatch: either the context-free encoding implemented below, or the
        # base-class behavior (which encodes the spans with their document context).
        if self.without_context:
            return self.encode_input_without_context(document)
        else:
            return super().encode_input(document)

    def encode_input_without_context(
        self, document: DocumentType
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        """Encode each coref relation from just the argument span texts (no context).

        For every binary coref relation, the head and tail span texts are tokenized
        independently; relations whose arguments cannot be aligned with tokens or do
        not fit into the available token window are skipped (and counted via
        ``collect_relation``).

        Raises:
            ValueError: if a relation does not point from a span over ``text``
                to a span over ``text_pair``.
        """
        self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
        tokenizer_kwargs = dict(
            padding=False,
            truncation=False,
            add_special_tokens=False,
        )

        task_encodings = []
        for coref_rel in document.binary_coref_relations:
            # Both arguments must point at the expected text fields. NOTE: this
            # previously used `or`, which raised only when *both* arguments were
            # mis-directed; a relation with a single wrong-side argument slipped
            # through and was encoded from the wrong text.
            if not (
                coref_rel.head.target == document.text
                and coref_rel.tail.target == document.text_pair
            ):
                raise ValueError(
                    f"It is expected that coref relations go from (head) spans over 'text' "
                    f"to (tail) spans over 'text_pair', but this is not the case for this "
                    f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
                )
            # Tokenize just the span texts, without any surrounding context.
            encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
            encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)

            try:
                # Shift the char spans so they are relative to the encoded span text
                # (which starts at char 0), then locate them in token space.
                current_encoding, token_span = self.truncate_encoding_around_span(
                    encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start)
                )
                current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
                    encoding=encoding_pair,
                    char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
                )
            except SpanNotAlignedWithTokenException as e:
                logger.warning(
                    f"Could not get token offsets for argument ({e.span}) of coref relation: "
                    f"{coref_rel.resolve()}. Skip it."
                )
                self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
                continue
            except SpanDoesNotFitIntoAvailableWindow as e:
                logger.warning(
                    f"Argument span [{e.span}] does not fit into available token window "
                    f"({self.available_window}). Skip it."
                )
                self.collect_relation(
                    kind="skipped_span_does_not_fit_into_window", relation=coref_rel
                )
                continue

            task_encodings.append(
                TaskEncoding(
                    document=document,
                    inputs={
                        "encoding": current_encoding,
                        "encoding_pair": current_encoding_pair,
                        "pooler_start_indices": token_span.start,
                        "pooler_end_indices": token_span.end,
                        "pooler_pair_start_indices": token_span_pair.start,
                        "pooler_pair_end_indices": token_span_pair.end,
                    },
                    metadata={"candidate_annotation": coref_rel},
                )
            )
            self.collect_relation("used", coref_rel)
        return task_encodings
|
|