from __future__ import annotations

import logging
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union

from pie_modules.utils.span import have_overlap
from pytorch_ie import AnnotationLayer
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan, MultiSpan, Span
from pytorch_ie.core import Document
from pytorch_ie.core.document import Annotation, _enumerate_dependencies

from src.document.types import (
    RelatedRelation,
    TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
)
from src.utils import distance, distance_slices
from src.utils.span_utils import get_overlap_len

logger = logging.getLogger(__name__)

D = TypeVar("D", bound=Document)


def _remove_overlapping_entities(
    entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Greedily keep entities (sorted by start) that do not overlap a previously kept one, and
    drop all relations whose head or tail entity was removed."""
    sorted_entities = sorted(entities, key=lambda span: span["start"])
    entities_wo_overlap = []
    skipped_entities = []
    last_end = 0
    for entity_dict in sorted_entities:
        if entity_dict["start"] < last_end:
            skipped_entities.append(entity_dict)
        else:
            entities_wo_overlap.append(entity_dict)
            last_end = entity_dict["end"]
    if len(skipped_entities) > 0:
        logger.warning(f"skipped overlapping entities: {skipped_entities}")
    valid_entity_ids = set(entity_dict["_id"] for entity_dict in entities_wo_overlap)
    valid_relations = [
        relation_dict
        for relation_dict in relations
        if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
    ]
    return entities_wo_overlap, valid_relations


def remove_overlapping_entities(
    doc: D,
    entity_layer_name: str = "entities",
    relation_layer_name: str = "relations",
) -> D:
    """Remove overlapping entities, and any relation involving a removed entity, from a document
    by round-tripping it through its dict representation. Predictions of both layers are
    cleared."""
    document_dict = doc.asdict()
    entities_wo_overlap, valid_relations = _remove_overlapping_entities(
        entities=document_dict[entity_layer_name]["annotations"],
        relations=document_dict[relation_layer_name]["annotations"],
    )

    document_dict[entity_layer_name] = {
        "annotations": entities_wo_overlap,
        "predictions": [],
    }
    document_dict[relation_layer_name] = {
        "annotations": valid_relations,
        "predictions": [],
    }
    new_doc = type(doc).fromdict(document_dict)

    return new_doc
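

def _example_remove_overlapping_entities() -> None:
    """A minimal usage sketch for `remove_overlapping_entities`. The document class defined
    inside is a hypothetical example type (not part of this repo); it assumes `annotation_field`
    from `pytorch_ie.core` and `TextBasedDocument` from `pytorch_ie.documents`. Not called at
    import time."""
    import dataclasses

    from pytorch_ie.annotations import BinaryRelation
    from pytorch_ie.core import annotation_field
    from pytorch_ie.documents import TextBasedDocument

    @dataclasses.dataclass
    class ExampleDocument(TextBasedDocument):
        entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
        relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")

    doc = ExampleDocument(text="New York City")
    doc.entities.extend(
        [
            LabeledSpan(start=0, end=8, label="LOC"),  # "New York"
            LabeledSpan(start=4, end=13, label="LOC"),  # "York City", overlaps the first
        ]
    )
    result = remove_overlapping_entities(doc)
    # the second entity starts before the first one ends, so only one survives
    assert len(result.entities) == 1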


def remove_partitions_by_labels(
    document: D, partition_layer: str, label_blacklist: List[str], span_layer: Optional[str] = None
) -> D:
    """Remove partitions whose labels are in the blacklist from a document.

    Args:
        document: The document to process.
        partition_layer: The name of the partition layer.
        label_blacklist: The labels of partitions to remove.
        span_layer: The name of a span layer to remove spans from if they are no longer fully
            contained in any remaining partition. Any dependent annotations are removed as well.

    Returns:
        The processed document.
    """
    document = document.copy()
    p_layer: AnnotationLayer = document[partition_layer]
    new_partitions = []
    for partition in p_layer.clear():
        if partition.label not in label_blacklist:
            new_partitions.append(partition)
    p_layer.extend(new_partitions)

    if span_layer is not None:
        result = document.copy(with_annotations=False)
        removed_span_ids = set()
        for span in document[span_layer]:
            # keep the span only if it is fully contained in one of the remaining partitions
            if any(
                partition.start <= span.start and span.end <= partition.end
                for partition in new_partitions
            ):
                result[span_layer].append(span.copy())
            else:
                removed_span_ids.add(span._id)

        result.add_all_annotations_from_other(
            document,
            removed_annotations={span_layer: removed_span_ids},
            strict=False,
            verbose=False,
        )
        document = result

    return document
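

def _example_remove_partitions_by_labels() -> None:
    """A minimal usage sketch for `remove_partitions_by_labels`, assuming the document type
    `TextDocumentWithLabeledSpansAndLabeledPartitions` from `pytorch_ie.documents` with its
    layer names `labeled_spans` and `labeled_partitions`. Not called at import time."""
    from pytorch_ie.documents import TextDocumentWithLabeledSpansAndLabeledPartitions

    doc = TextDocumentWithLabeledSpansAndLabeledPartitions(text="Title. Body text.")
    doc.labeled_partitions.extend(
        [
            LabeledSpan(start=0, end=6, label="title"),
            LabeledSpan(start=7, end=17, label="body"),
        ]
    )
    doc.labeled_spans.append(LabeledSpan(start=0, end=5, label="thing"))  # inside "Title."

    result = remove_partitions_by_labels(
        doc,
        partition_layer="labeled_partitions",
        label_blacklist=["title"],
        span_layer="labeled_spans",
    )
    assert [p.label for p in result.labeled_partitions] == ["body"]
    # the span was only contained in the removed "title" partition, so it is gone as well
    assert len(result.labeled_spans) == 0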


D_text = TypeVar("D_text", bound=Document)


def replace_substrings_in_text(
    document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
) -> D_text:
    """Replace substrings in the document text. By default, each replacement must have the same
    length as the string it replaces so that existing annotation offsets remain valid."""
    new_text = document.text
    for old_str, new_str in replacements.items():
        if enforce_same_length and len(old_str) != len(new_str):
            raise ValueError(
                f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
            )
        new_text = new_text.replace(old_str, new_str)
    result_dict = document.asdict()
    result_dict["text"] = new_text
    result = type(document).fromdict(result_dict)
    result.text = new_text
    return result


def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
    """Overwrite each occurrence of the given substrings with spaces of the same length."""
    replacements = {substring: " " * len(substring) for substring in substrings}
    return replace_substrings_in_text(document, replacements=replacements)
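

def _example_replace_substrings_with_spaces() -> None:
    """A minimal usage sketch for `replace_substrings_in_text_with_spaces`, assuming
    `TextBasedDocument` from `pytorch_ie.documents`. Because replacements keep the text length,
    all annotation offsets stay valid. Not called at import time."""
    from pytorch_ie.documents import TextBasedDocument

    doc = TextBasedDocument(text="foo\tbar")
    clean = replace_substrings_in_text_with_spaces(doc, substrings=["\t"])
    assert clean.text == "foo bar"
    assert len(clean.text) == len(doc.text)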


def relabel_annotations(
    document: D,
    label_mapping: Dict[str, Dict[str, str]],
) -> D:
    """Replace annotation labels in a document.

    Args:
        document: The document to process.
        label_mapping: A mapping from layer names to mappings from old labels to new labels.
            For every layer that occurs in the mapping, all labels of that layer must be
            covered, otherwise a ValueError is raised.

    Returns:
        The processed document.
    """
    # process layers in dependency order so that re-created annotations can reference the
    # already re-created annotations they depend on (e.g. relations referencing spans)
    dependency_ordered_fields: List[str] = []
    _enumerate_dependencies(
        dependency_ordered_fields,
        dependency_graph=document._annotation_graph,
        nodes=document._annotation_graph["_artificial_root"],
    )
    result = document.copy(with_annotations=False)
    store: Dict[int, Annotation] = {}

    invalid_annotation_ids: Set[int] = set()
    for field_name in dependency_ordered_fields:
        if field_name in document._annotation_fields:
            layer = document[field_name]
            for is_prediction, anns in [(False, layer), (True, layer.predictions)]:
                for ann in anns:
                    new_ann = ann.copy_with_store(
                        override_annotation_store=store,
                        invalid_annotation_ids=invalid_annotation_ids,
                    )
                    if field_name in label_mapping:
                        if ann.label in label_mapping[field_name]:
                            new_label = label_mapping[field_name][ann.label]
                            new_ann = new_ann.copy(label=new_label)
                        else:
                            raise ValueError(
                                f"Label {ann.label} not found in label mapping for {field_name}"
                            )
                    store[ann._id] = new_ann
                    target_layer = result[field_name]
                    if is_prediction:
                        target_layer.predictions.append(new_ann)
                    else:
                        target_layer.append(new_ann)

    return result
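

def _example_relabel_annotations() -> None:
    """A minimal usage sketch for `relabel_annotations`, assuming the document type
    `TextDocumentWithLabeledSpans` from `pytorch_ie.documents`. Note that the mapping must
    cover every label that occurs in a mapped layer. Not called at import time."""
    from pytorch_ie.documents import TextDocumentWithLabeledSpans

    doc = TextDocumentWithLabeledSpans(text="Berlin")
    doc.labeled_spans.append(LabeledSpan(start=0, end=6, label="LOC"))

    relabeled = relabel_annotations(doc, label_mapping={"labeled_spans": {"LOC": "location"}})
    assert relabeled.labeled_spans[0].label == "location"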


DWithSpans = TypeVar("DWithSpans", bound=Document)


def get_start_end(span: Union[Span, MultiSpan]) -> Tuple[int, int]:
    """Return the (start, end) of a Span, or the overall extent of a MultiSpan."""
    if isinstance(span, Span):
        return span.start, span.end
    elif isinstance(span, MultiSpan):
        starts, ends = zip(*span.slices)
        return min(starts), max(ends)
    else:
        raise ValueError(f"Unsupported span type: {type(span)}")


def _get_aligned_span_mappings(
    gold_spans: Iterable[Span], pred_spans: Iterable[Span], distance_type: str
) -> Tuple[Dict[int, Span], Dict[int, Span]]:
    """For each predicted span, find the closest gold span and, if the match is valid, map the
    predicted span's id to an aligned copy of the predicted span and to the matched gold span."""
    old2new_pred_span = {}
    span_id2gold_span = {}
    for pred_span in pred_spans:
        # compute the distance of the predicted span to every gold span
        gold_spans_with_distance = [
            (
                gold_span,
                distance(
                    start_end=get_start_end(pred_span),
                    other_start_end=get_start_end(gold_span),
                    distance_type=distance_type,
                ),
            )
            for gold_span in gold_spans
        ]
        if len(gold_spans_with_distance) == 0:
            continue

        closest_gold_span, min_distance = min(gold_spans_with_distance, key=lambda x: x[1])
        # if the closest gold span matches exactly, there is nothing to align
        if min_distance == 0.0:
            continue

        pred_start_end = get_start_end(pred_span)
        closest_gold_start_end = get_start_end(closest_gold_span)
        # a match is valid only if the spans overlap by at least half of the longer span's length
        if have_overlap(
            start_end=pred_start_end,
            other_start_end=closest_gold_start_end,
        ):
            overlap_len = get_overlap_len(pred_start_end, closest_gold_start_end)
            l_max = max(
                pred_start_end[1] - pred_start_end[0],
                closest_gold_start_end[1] - closest_gold_start_end[0],
            )
            valid_match = overlap_len >= (l_max / 2)
        else:
            valid_match = False

        if valid_match:
            if isinstance(pred_span, Span):
                aligned_pred_span = pred_span.copy(
                    start=closest_gold_span.start, end=closest_gold_span.end
                )
            elif isinstance(pred_span, MultiSpan):
                aligned_pred_span = pred_span.copy(slices=closest_gold_span.slices)
            else:
                raise ValueError(f"Unsupported span type: {type(pred_span)}")
            old2new_pred_span[pred_span._id] = aligned_pred_span
            span_id2gold_span[pred_span._id] = closest_gold_span

    return old2new_pred_span, span_id2gold_span
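

def _example_overlap_criterion() -> None:
    """A worked example of the validity criterion above: a predicted span may only be snapped
    to its closest gold span if their overlap is at least half the length of the longer span."""
    pred_start_end, gold_start_end = (0, 7), (0, 8)
    overlap_len = get_overlap_len(pred_start_end, gold_start_end)  # 7
    l_max = max(pred_start_end[1] - pred_start_end[0], gold_start_end[1] - gold_start_end[0])  # 8
    assert overlap_len >= l_max / 2  # 7 >= 4.0, so this pair would be aligned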


def get_spans2multi_spans_mapping(multi_spans: Iterable[MultiSpan]) -> Dict[Span, MultiSpan]:
    """Explode each MultiSpan into one (Labeled)Span per slice and map every such single span
    back to its originating MultiSpan."""
    result = {}
    for multi_span in multi_spans:
        for start, end in multi_span.slices:
            span_kwargs = dict(start=start, end=end, score=multi_span.score)
            if isinstance(multi_span, LabeledMultiSpan):
                result[LabeledSpan(label=multi_span.label, **span_kwargs)] = multi_span
            else:
                result[Span(**span_kwargs)] = multi_span

    return result
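

def _example_spans2multi_spans_mapping() -> None:
    """A minimal sketch of `get_spans2multi_spans_mapping`: a two-slice LabeledMultiSpan yields
    two LabeledSpan keys that both map back to the original multi span. Not called at import
    time."""
    multi_span = LabeledMultiSpan(slices=((0, 3), (7, 10)), label="X")
    mapping = get_spans2multi_spans_mapping([multi_span])
    assert {(span.start, span.end) for span in mapping} == {(0, 3), (7, 10)}
    assert all(target is multi_span for target in mapping.values())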


def align_predicted_span_annotations(
    document: DWithSpans,
    span_layer: str,
    distance_type: str = "center",
    simple_multi_span: bool = False,
    verbose: bool = False,
) -> DWithSpans:
    """
    Align predicted span annotations with the closest gold spans in a document.

    First, the distance between each predicted span and each gold span is calculated. Then,
    for each predicted span, the gold span with the smallest distance is selected. If the
    predicted span and the gold span overlap by at least half of the maximum length of the
    two spans, the predicted span is aligned with the gold span.

    This also works for MultiSpan annotations, where the slices of the MultiSpan are used
    to align the predicted spans. If any of the slices is aligned with a gold slice,
    the MultiSpan is aligned with the respective gold MultiSpan. However, this may result in
    the predicted MultiSpan being aligned with multiple gold MultiSpans, in which case the
    closest gold MultiSpan is selected. A simplified version of this alignment can be achieved
    by setting `simple_multi_span=True`, which treats MultiSpan annotations as simple Spans
    by using their minimum start and maximum end indices.

    Args:
        document: The document to process.
        span_layer: The name of the span layer.
        distance_type: The type of distance to calculate. One of: center, inner, outer.
        simple_multi_span: Whether to treat MultiSpan annotations as simple Spans by using
            their minimum start and maximum end indices.
        verbose: Whether to print debug information.

    Returns:
        The processed document.
    """
    gold_spans = document[span_layer]
    if len(gold_spans) == 0:
        return document.copy()

    pred_spans = document[span_layer].predictions
    span_annotation_type = document.annotation_types()[span_layer]
    if issubclass(span_annotation_type, Span) or simple_multi_span:
        old2new_pred_span, span_id2gold_span = _get_aligned_span_mappings(
            gold_spans=gold_spans, pred_spans=pred_spans, distance_type=distance_type
        )
    elif issubclass(span_annotation_type, MultiSpan):
        # explode the multi spans into single spans per slice and align those
        gold_single_spans2multi_spans = get_spans2multi_spans_mapping(gold_spans)
        pred_single_spans2multi_spans = get_spans2multi_spans_mapping(pred_spans)

        single_old2new_pred_span, single_span_id2gold_span = _get_aligned_span_mappings(
            gold_spans=gold_single_spans2multi_spans.keys(),
            pred_spans=pred_single_spans2multi_spans.keys(),
            distance_type=distance_type,
        )

        pred_multi_span2single_spans: Dict[MultiSpan, List[Span]] = defaultdict(list)
        for pred_span, multi_span in pred_single_spans2multi_spans.items():
            pred_multi_span2single_spans[multi_span].append(pred_span)

        old2new_pred_span = {}
        span_id2gold_span = {}
        for pred_multi_span, pred_single_spans in pred_multi_span2single_spans.items():
            # consider a predicted multi span only if at least one of its slices was aligned
            if any(
                pred_single_span._id in single_old2new_pred_span
                for pred_single_span in pred_single_spans
            ):
                # collect all gold multi spans that any of the slices was aligned with
                aligned_gold_multi_spans = set()
                for pred_single_span in pred_single_spans:
                    if pred_single_span._id in single_old2new_pred_span:
                        aligned_gold_single_span = single_span_id2gold_span[pred_single_span._id]
                        aligned_gold_multi_span = gold_single_spans2multi_spans[
                            aligned_gold_single_span
                        ]
                        aligned_gold_multi_spans.add(aligned_gold_multi_span)

                gold_multi_spans_with_distance = [
                    (
                        gold_multi_span,
                        distance_slices(
                            slices=pred_multi_span.slices,
                            other_slices=gold_multi_span.slices,
                            distance_type=distance_type,
                        ),
                    )
                    for gold_multi_span in aligned_gold_multi_spans
                ]

                if len(aligned_gold_multi_spans) > 1:
                    logger.warning(
                        f"Multiple gold multi spans aligned with predicted multi span ({pred_multi_span}): "
                        f"{aligned_gold_multi_spans}"
                    )

                closest_gold_multi_span, min_distance = min(
                    gold_multi_spans_with_distance, key=lambda x: x[1]
                )
                old2new_pred_span[pred_multi_span._id] = pred_multi_span.copy(
                    slices=closest_gold_multi_span.slices
                )
                span_id2gold_span[pred_multi_span._id] = closest_gold_multi_span
    else:
        raise ValueError(f"Unsupported span annotation type: {span_annotation_type}")

    result = document.copy(with_annotations=False)

    # add the (aligned) predicted spans, skipping duplicates that can arise from the alignment
    added_pred_span_ids = dict()
    for pred_span in pred_spans:
        if pred_span._id not in old2new_pred_span:
            # the span was not aligned, keep it as is
            if pred_span._id not in added_pred_span_ids:
                keep_pred_span = pred_span.copy()
                result[span_layer].predictions.append(keep_pred_span)
                added_pred_span_ids[pred_span._id] = keep_pred_span
            elif verbose:
                print(f"Skipping duplicate predicted span. pred_span='{str(pred_span)}'")
        else:
            aligned_pred_span = old2new_pred_span[pred_span._id]
            if aligned_pred_span._id not in added_pred_span_ids:
                result[span_layer].predictions.append(aligned_pred_span)
                added_pred_span_ids[aligned_pred_span._id] = aligned_pred_span
            elif verbose:
                prev_pred_span = added_pred_span_ids[aligned_pred_span._id]
                gold_span = span_id2gold_span[pred_span._id]
                print(
                    f"Skipping duplicate aligned predicted span. aligned gold_span='{str(gold_span)}', "
                    f"prev_pred_span='{str(prev_pred_span)}', current_pred_span='{str(pred_span)}'"
                )

    # re-add the gold spans and all annotations that depend on the span layer
    result[span_layer].extend([span.copy() for span in gold_spans])

    _aligned_spans = result.add_all_annotations_from_other(
        document, override_annotations={span_layer: old2new_pred_span}
    )

    return result
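

def _example_align_predicted_spans() -> None:
    """A minimal usage sketch for `align_predicted_span_annotations`, assuming the document
    type `TextDocumentWithLabeledSpans` from `pytorch_ie.documents`. The prediction misses the
    gold boundary by one character but overlaps it by more than half of the longer span, so it
    is snapped to the gold boundaries. Not called at import time."""
    from pytorch_ie.documents import TextDocumentWithLabeledSpans

    doc = TextDocumentWithLabeledSpans(text="Jane Doe lives in Berlin.")
    doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="person"))  # "Jane Doe"
    doc.labeled_spans.predictions.append(LabeledSpan(start=0, end=7, label="person"))  # "Jane Do"

    aligned = align_predicted_span_annotations(doc, span_layer="labeled_spans")
    prediction = aligned.labeled_spans.predictions[0]
    assert (prediction.start, prediction.end) == (0, 8)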


def add_related_relations_from_binary_relations(
    document: TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations,
    link_relation_label: str,
    link_partition_whitelist: Optional[List[List[str]]] = None,
    relation_label_whitelist: Optional[List[str]] = None,
    reversed_relation_suffix: Optional[str] = "_reversed",
    symmetric_relations: Optional[List[str]] = None,
) -> TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations:
    """Create RelatedRelation annotations by chaining each "link" relation with the binary
    relations attached to its tail: for a link relation head -> tail, every relation
    tail -> x yields a related relation head -> x with the same label, and every relation
    x -> tail yields a related relation head -> x with the label suffixed by
    `reversed_relation_suffix` (unless the label is in `symmetric_relations`). The optional
    whitelists restrict the partition label pairs of the link relation arguments and the
    resulting relation labels, respectively. The document is modified in place and returned.
    """
    # assign each multi span to the partition that fully contains it; if there are no
    # partitions, use a single catch-all partition over the whole text
    span2partition = {}
    for multi_span in document.labeled_multi_spans:
        found_partition = False
        for partition in document.labeled_partitions or [
            LabeledSpan(start=0, end=len(document.text), label="ALL")
        ]:
            starts, ends = zip(*multi_span.slices)
            if partition.start <= min(starts) and max(ends) <= partition.end:
                span2partition[multi_span] = partition
                found_partition = True
                break
        if not found_partition:
            raise ValueError(f"No partition found for multi_span {multi_span}")

    rel_head2rels = defaultdict(list)
    rel_tail2rels = defaultdict(list)
    for rel in document.binary_relations:
        rel_head2rels[rel.head].append(rel)
        rel_tail2rels[rel.tail].append(rel)

    link_partition_whitelist_tuples = None
    if link_partition_whitelist is not None:
        link_partition_whitelist_tuples = {tuple(pair) for pair in link_partition_whitelist}

    skipped_labels = []
    for link_rel in document.binary_relations:
        if link_rel.label == link_relation_label:
            head_partition = span2partition[link_rel.head]
            tail_partition = span2partition[link_rel.tail]
            if link_partition_whitelist_tuples is None or (
                (head_partition.label, tail_partition.label) in link_partition_whitelist_tuples
            ):
                # relations pointing away from the link target: tail -> x
                for rel in rel_head2rels.get(link_rel.tail, []):
                    label = rel.label
                    if relation_label_whitelist is None or label in relation_label_whitelist:
                        new_rel = RelatedRelation(
                            head=link_rel.head,
                            tail=rel.tail,
                            link_relation=link_rel,
                            relation=rel,
                            label=label,
                        )
                        document.related_relations.append(new_rel)
                    else:
                        skipped_labels.append(label)

                # relations pointing towards the link target: x -> tail, added in reversed
                # direction with a suffixed label (unless the relation is symmetric)
                if reversed_relation_suffix is not None:
                    for reversed_rel in rel_tail2rels.get(link_rel.tail, []):
                        label = reversed_rel.label
                        if symmetric_relations is None or label not in symmetric_relations:
                            label = f"{label}{reversed_relation_suffix}"
                        if relation_label_whitelist is None or label in relation_label_whitelist:
                            new_rel = RelatedRelation(
                                head=link_rel.head,
                                tail=reversed_rel.head,
                                link_relation=link_rel,
                                relation=reversed_rel,
                                label=label,
                            )
                            document.related_relations.append(new_rel)
                        else:
                            skipped_labels.append(label)
            else:
                logger.warning(
                    f"Skipping related relation because of partition whitelist "
                    f"({[head_partition.label, tail_partition.label]}): {link_rel.resolve()}"
                )
    if len(skipped_labels) > 0:
        logger.warning(
            f"Skipped relations with labels not in whitelist: {sorted(set(skipped_labels))}"
        )

    return document
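

def _example_add_related_relations() -> None:
    """A minimal usage sketch for `add_related_relations_from_binary_relations`, assuming
    `BinaryRelation` from `pytorch_ie.annotations`. With a link relation a -> b and a regular
    relation b -> c, the call derives a related relation a -> c. The label whitelist filters
    out the reversed variant of the link relation itself. Not called at import time."""
    from pytorch_ie.annotations import BinaryRelation

    document = TextDocumentWithLabeledMultiSpansBinaryRelationsLabeledPartitionsAndRelatedRelations(
        text="A. B. C."
    )
    a = LabeledMultiSpan(slices=((0, 1),), label="claim")
    b = LabeledMultiSpan(slices=((3, 4),), label="claim")
    c = LabeledMultiSpan(slices=((6, 7),), label="claim")
    document.labeled_multi_spans.extend([a, b, c])
    document.binary_relations.extend(
        [
            BinaryRelation(head=a, tail=b, label="semantically_same"),
            BinaryRelation(head=b, tail=c, label="supports"),
        ]
    )

    add_related_relations_from_binary_relations(
        document,
        link_relation_label="semantically_same",
        relation_label_whitelist=["supports"],
    )
    assert [(rel.head, rel.tail, rel.label) for rel in document.related_relations] == [
        (a, c, "supports")
    ]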