File size: 4,709 Bytes
3133b5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import logging
from typing import Optional, Sequence, TypeVar, Union

from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
from pie_modules.taskmodules.cross_text_binary_coref import (
    DocumentType,
    SpanDoesNotFitIntoAvailableWindow,
    TaskEncodingType,
)
from pie_modules.utils.tokenization import SpanNotAlignedWithTokenException
from pytorch_ie.annotations import Span
from pytorch_ie.core import TaskEncoding, TaskModule

# Module-level logger; used for the warnings emitted when a coref relation is skipped
# during encoding.
logger = logging.getLogger(__name__)


# Type variable bound to Span, so that `shift_span` is typed to return the same Span
# subtype it is given.
S = TypeVar("S", bound=Span)


def shift_span(span: S, offset: int) -> S:
    """Return ``span.copy(...)`` with both boundaries moved by ``offset``.

    A negative ``offset`` moves the span towards the beginning of its target. Only
    ``start`` and ``end`` are overridden in the copy; everything else is left to
    ``span.copy``.

    Args:
        span: the span to shift.
        offset: amount added to both ``span.start`` and ``span.end``.

    Returns:
        The result of ``span.copy`` with the shifted ``start`` and ``end``.
    """
    shifted_start = span.start + offset
    shifted_end = span.end + offset
    return span.copy(start=shifted_start, end=shifted_end)


@TaskModule.register()
class CrossTextBinaryCorefTaskModuleWithOptionalContext(CrossTextBinaryCorefTaskModule):
    """Variant of ``CrossTextBinaryCorefTaskModule`` that can optionally drop the context.

    With ``without_context=False`` (the default), encoding is delegated to
    ``CrossTextBinaryCorefTaskModule.encode_input``. With ``without_context=True``, each
    coref relation is encoded from the text of its two argument spans only: the head span
    text and the tail span text are tokenized separately, and the rest of
    ``document.text`` / ``document.text_pair`` is not included in the inputs.
    """

    def __init__(
        self,
        without_context: bool = False,
        **kwargs,
    ) -> None:
        """
        Args:
            without_context: if True, encode only the text of the relation arguments
                instead of the full texts.
            **kwargs: forwarded unchanged to ``CrossTextBinaryCorefTaskModule.__init__``.
        """
        super().__init__(**kwargs)
        self.without_context = without_context

    def encode_input(
        self,
        document: DocumentType,
        is_training: bool = False,
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        """Encode the binary coref relations of ``document``.

        Args:
            document: the document to encode.
            is_training: forwarded to the parent implementation when
                ``without_context`` is False; not used by the context-free path.

        Returns:
            The task encodings produced by the selected encoding path.
        """
        if self.without_context:
            return self.encode_input_without_context(document)
        # Forward is_training so the parent sees the caller's value instead of its default.
        return super().encode_input(document, is_training=is_training)

    def encode_input_without_context(
        self, document: DocumentType
    ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
        """Encode each coref relation using only the text of its head and tail spans.

        For every relation in ``document.binary_coref_relations``:

        1. The head span text and the tail span text are tokenized separately (no
           padding, no truncation, no special tokens).
        2. Each encoding is passed through ``truncate_encoding_around_span``.
        3. The relation becomes one ``TaskEncoding`` holding both encodings and the
           token start/end indices of both spans.

        Relations whose arguments cannot be aligned with token boundaries, or do not fit
        into the available window, are logged, recorded via ``collect_relation`` and
        skipped.

        Args:
            document: document providing ``text``, ``text_pair`` and
                ``binary_coref_relations``.

        Returns:
            A list with one ``TaskEncoding`` per successfully encoded relation; it is
            empty if there are no relations or all of them were skipped.

        Raises:
            ValueError: if a relation has neither its head over ``document.text`` nor its
                tail over ``document.text_pair`` (i.e. it points in the reverse direction).
        """
        self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
        tokenizer_kwargs = {
            "padding": False,
            "truncation": False,
            "add_special_tokens": False,
        }

        task_encodings = []
        for coref_rel in document.binary_coref_relations:

            # TODO: This can miss instances if both texts are the same. We could check that
            #   coref_rel.head is in document.labeled_spans (same for the tail), but would this
            #   slow down the encoding?
            # NOTE: `not (A or B)` only rejects relations where *both* arguments are on the
            #   wrong side; a relation with just one misplaced argument passes this check.
            if not (
                coref_rel.head.target == document.text
                or coref_rel.tail.target == document.text_pair
            ):
                raise ValueError(
                    f"It is expected that coref relations go from (head) spans over 'text' "
                    f"to (tail) spans over 'text_pair', but this is not the case for this "
                    f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
                )
            # Tokenize only the argument texts; the surrounding text is intentionally
            # dropped. Assumes str(span) yields the covered text of the span — TODO confirm.
            encoding = self.tokenizer(text=str(coref_rel.head), **tokenizer_kwargs)
            encoding_pair = self.tokenizer(text=str(coref_rel.tail), **tokenizer_kwargs)

            try:
                # Shift each span so it starts at 0: its character offsets are then relative
                # to the argument text that was tokenized above.
                current_encoding, token_span = self.truncate_encoding_around_span(
                    encoding=encoding, char_span=shift_span(coref_rel.head, -coref_rel.head.start)
                )
                current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
                    encoding=encoding_pair,
                    char_span=shift_span(coref_rel.tail, -coref_rel.tail.start),
                )
            except SpanNotAlignedWithTokenException as e:
                logger.warning(
                    f"Could not get token offsets for argument ({e.span}) of coref relation: "
                    f"{coref_rel.resolve()}. Skip it."
                )
                self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
                continue
            except SpanDoesNotFitIntoAvailableWindow as e:
                logger.warning(
                    f"Argument span [{e.span}] does not fit into available token window "
                    f"({self.available_window}). Skip it."
                )
                self.collect_relation(
                    kind="skipped_span_does_not_fit_into_window", relation=coref_rel
                )
                continue

            task_encodings.append(
                TaskEncoding(
                    document=document,
                    inputs={
                        "encoding": current_encoding,
                        "encoding_pair": current_encoding_pair,
                        "pooler_start_indices": token_span.start,
                        "pooler_end_indices": token_span.end,
                        "pooler_pair_start_indices": token_span_pair.start,
                        "pooler_pair_end_indices": token_span_pair.end,
                    },
                    metadata={"candidate_annotation": coref_rel},
                )
            )
            self.collect_relation(kind="used", relation=coref_rel)
        return task_encodings