from __future__ import annotations

import logging
from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, TypeVar, Union

import networkx as nx
from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pie_modules.documents import TextDocumentWithLabeledMultiSpansAndBinaryRelations
from pytorch_ie import AnnotationLayer
from pytorch_ie.core import Document

logger = logging.getLogger(__name__)

D = TypeVar("D", bound=Document)


def _remove_overlapping_entities(
    entities: Iterable[Dict[str, Any]], relations: Iterable[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Drop entities that overlap an earlier kept entity, and relations touching dropped ones.

    Entities are visited in order of their ``"start"`` value. An entity is kept if its
    ``"start"`` is not smaller than the ``"end"`` of the most recently kept entity;
    otherwise it is skipped. Skipped entities are reported via a single warning.

    Args:
        entities: Entity dicts providing the keys ``"start"``, ``"end"`` and ``"_id"``.
        relations: Relation dicts providing the keys ``"head"`` and ``"tail"``, whose
            values are compared against the ``"_id"`` values of the kept entities.

    Returns:
        A tuple ``(kept_entities, kept_relations)``. ``kept_entities`` is sorted by
        ``"start"``; ``kept_relations`` preserves the input order and contains only
        relations whose head and tail are both kept entities.
    """
    sorted_entities = sorted(entities, key=lambda span: span["start"])
    entities_wo_overlap = []
    skipped_entities = []
    last_end = 0
    for entity_dict in sorted_entities:
        if entity_dict["start"] < last_end:
            skipped_entities.append(entity_dict)
        else:
            entities_wo_overlap.append(entity_dict)
            # only kept entities move the boundary forward
            last_end = entity_dict["end"]
    if skipped_entities:
        logger.warning(f"skipped overlapping entities: {skipped_entities}")

    valid_entity_ids = {entity_dict["_id"] for entity_dict in entities_wo_overlap}
    valid_relations = [
        relation_dict
        for relation_dict in relations
        if relation_dict["head"] in valid_entity_ids and relation_dict["tail"] in valid_entity_ids
    ]
    return entities_wo_overlap, valid_relations


def remove_overlapping_entities(
    doc: D,
    entity_layer_name: str = "entities",
    relation_layer_name: str = "relations",
) -> D:
    """Return a new document of the same type without overlapping entities.

    The document is converted to a dict, the gold annotations of the entity and relation
    layers are filtered with ``_remove_overlapping_entities``, and a new document is
    built from the modified dict. The predictions of both layers are set to empty lists
    in the result.

    Args:
        doc: The document to process.
        entity_layer_name: Name of the layer holding the entities.
        relation_layer_name: Name of the layer holding the relations.

    Returns:
        A new document created via ``type(doc).fromdict``.
    """
    # TODO: use document.add_all_annotations_from_other()
    document_dict = doc.asdict()
    entities_wo_overlap, valid_relations = _remove_overlapping_entities(
        entities=document_dict[entity_layer_name]["annotations"],
        relations=document_dict[relation_layer_name]["annotations"],
    )

    document_dict[entity_layer_name] = {
        "annotations": entities_wo_overlap,
        "predictions": [],
    }
    document_dict[relation_layer_name] = {
        "annotations": valid_relations,
        "predictions": [],
    }
    new_doc = type(doc).fromdict(document_dict)

    return new_doc


def _merge_spans_via_relation(
    spans: Sequence[LabeledSpan],
    relations: Sequence[BinaryRelation],
    link_relation_label: str,
    create_multi_spans: bool = True,
) -> Tuple[Union[Set[LabeledSpan], Set[LabeledMultiSpan]], Set[BinaryRelation]]:
    """Merge spans that are connected via relations with a given label.

    Relations labeled ``link_relation_label`` are treated as links: their head and tail
    become connected nodes in an undirected graph, and every connected component is
    replaced by one new span. Link relations themselves are not part of the returned
    relations. All other relations are re-created with head and tail mapped to the new
    spans.

    Args:
        spans: The spans to process. Spans that are not part of any merged component
            are re-created individually, keeping their label and score.
        relations: The relations over the spans.
        link_relation_label: Label of the relations that trigger merging.
        create_multi_spans: If True, a merged component becomes a ``LabeledMultiSpan``
            with one slice per original span (ordered by start). If False, it becomes a
            ``LabeledSpan`` covering the minimal start to the maximal end.

    Returns:
        A tuple ``(new_spans, new_relations)`` of sets. Merged spans are created without
        an explicit score.
    """
    # convert list of relations to a graph to easily calculate connected components to merge
    g = nx.Graph()
    other_relations = []
    for rel in relations:
        if rel.label == link_relation_label:
            # never merge spans that have not the same label
            if (
                not (isinstance(rel.head, LabeledSpan) or isinstance(rel.tail, LabeledSpan))
                or rel.head.label == rel.tail.label
            ):
                g.add_edge(rel.head, rel.tail)
            else:
                logger.debug(
                    f"spans to merge do not have the same label, do not merge them: {rel.head}, {rel.tail}"
                )
        else:
            other_relations.append(rel)

    span_mapping = {}
    connected_components: Set[LabeledSpan]
    for connected_components in nx.connected_components(g):
        # all spans in a connected component have the same label
        label = next(iter(connected_components)).label
        connected_components_sorted = sorted(connected_components, key=lambda span: span.start)
        if create_multi_spans:
            new_span = LabeledMultiSpan(
                slices=tuple((span.start, span.end) for span in connected_components_sorted),
                label=label,
            )
        else:
            new_span = LabeledSpan(
                start=min(span.start for span in connected_components_sorted),
                end=max(span.end for span in connected_components_sorted),
                label=label,
            )
        for span in connected_components_sorted:
            span_mapping[span] = new_span

    # spans that were not merged are re-created one-to-one
    for span in spans:
        if span not in span_mapping:
            if create_multi_spans:
                span_mapping[span] = LabeledMultiSpan(
                    slices=((span.start, span.end),), label=span.label, score=span.score
                )
            else:
                span_mapping[span] = LabeledSpan(
                    start=span.start, end=span.end, label=span.label, score=span.score
                )

    new_spans = set(span_mapping.values())
    new_relations = {
        BinaryRelation(
            head=span_mapping[rel.head],
            tail=span_mapping[rel.tail],
            label=rel.label,
            score=rel.score,
        )
        for rel in other_relations
    }

    return new_spans, new_relations


def merge_spans_via_relation(
    document: D,
    relation_layer: str,
    link_relation_label: str,
    use_predicted_spans: bool = False,
    process_predictions: bool = True,
    create_multi_spans: bool = False,
) -> D:
    """Create a document in which spans linked by a certain relation are merged.

    Gold spans and relations are always processed with ``_merge_spans_via_relation``.
    The span layer is the target layer of the relation layer.

    Args:
        document: The document to process.
        relation_layer: Name of the relation layer.
        link_relation_label: Label of the relations that trigger merging.
        use_predicted_spans: If True, merge the predicted spans (instead of the gold
            spans) when processing predicted relations. Requires
            ``process_predictions=True``.
        process_predictions: If True, predicted relations are processed like the gold
            ones. If False, the predictions are taken over without merging.
        create_multi_spans: If True, the result is a
            ``TextDocumentWithLabeledMultiSpansAndBinaryRelations`` built from the
            non-annotation fields of the input, with merged spans represented as
            multi-spans. If False, the result is a copy of the input (without
            annotations) that holds the merged spans and relations in the original
            layers.

    Returns:
        The new document holding the merged spans and the remaining relations.
    """
    rel_layer = document[relation_layer]
    span_layer = rel_layer.target_layer
    new_gold_spans, new_gold_relations = _merge_spans_via_relation(
        spans=span_layer,
        relations=rel_layer,
        link_relation_label=link_relation_label,
        create_multi_spans=create_multi_spans,
    )
    if process_predictions:
        new_pred_spans, new_pred_relations = _merge_spans_via_relation(
            spans=span_layer.predictions if use_predicted_spans else span_layer,
            relations=rel_layer.predictions,
            link_relation_label=link_relation_label,
            create_multi_spans=create_multi_spans,
        )
    else:
        assert (
            not use_predicted_spans
        ), "use_predicted_spans requires process_predictions to be enabled"
        # NOTE(review): clear() is called on the prediction layers of the *input*
        # document, so this branch appears to remove the predictions from `document`
        # as a side effect - confirm that this is intended.
        new_pred_spans = set(span_layer.predictions.clear())
        new_pred_relations = set(rel_layer.predictions.clear())

    relation_layer_name = relation_layer
    span_layer_name = document[relation_layer].target_name
    if create_multi_spans:
        # keep only the non-annotation fields, the annotation layers are re-populated below
        doc_dict = document.asdict()
        for f in document.annotation_fields():
            doc_dict.pop(f.name)

        result = TextDocumentWithLabeledMultiSpansAndBinaryRelations.fromdict(doc_dict)
        result.labeled_multi_spans.extend(new_gold_spans)
        result.labeled_multi_spans.predictions.extend(new_pred_spans)
        result.binary_relations.extend(new_gold_relations)
        result.binary_relations.predictions.extend(new_pred_relations)
    else:
        result = document.copy(with_annotations=False)
        result[span_layer_name].extend(new_gold_spans)
        result[span_layer_name].predictions.extend(new_pred_spans)
        result[relation_layer_name].extend(new_gold_relations)
        result[relation_layer_name].predictions.extend(new_pred_relations)

    return result


def remove_partitions_by_labels(
    document: D, partition_layer: str, label_blacklist: List[str]
) -> D:
    """Return a copy of the document without partitions that have a blacklisted label.

    Args:
        document: The document to process. The filtering is applied to a copy.
        partition_layer: Name of the layer holding the partitions.
        label_blacklist: Partitions whose ``label`` is contained in this list are
            removed.

    Returns:
        The copied document whose partition layer holds only the remaining partitions.
    """
    document = document.copy()
    layer: AnnotationLayer = document[partition_layer]
    # clear() hands back the removed annotations, so the kept ones can be re-added
    new_partitions = [
        partition for partition in layer.clear() if partition.label not in label_blacklist
    ]
    layer.extend(new_partitions)
    return document


D_text = TypeVar("D_text", bound=Document)


def replace_substrings_in_text(
    document: D_text, replacements: Dict[str, str], enforce_same_length: bool = True
) -> D_text:
    """Return a new document in which substrings of the text are replaced.

    The replacements are applied one after another, in the iteration order of
    ``replacements``, each via ``str.replace`` on the result of the previous one.

    Args:
        document: The document to process. Its ``text`` attribute is read.
        replacements: Mapping from the substring to replace to its replacement.
        enforce_same_length: If True, every replacement needs to have the same length
            as the substring it replaces.

    Returns:
        A new document of the same type, built from the dict representation of the
        input with the modified text.

    Raises:
        ValueError: If ``enforce_same_length`` is True and a substring and its
            replacement differ in length.
    """
    new_text = document.text
    for old_str, new_str in replacements.items():
        if enforce_same_length and len(old_str) != len(new_str):
            raise ValueError(
                f'Replacement strings must have the same length, but got "{old_str}" -> "{new_str}"'
            )
        new_text = new_text.replace(old_str, new_str)
    result_dict = document.asdict()
    result_dict["text"] = new_text
    result = type(document).fromdict(result_dict)
    # NOTE(review): looks redundant because the dict already carries the new text;
    # kept to preserve the original behavior - confirm whether fromdict() needs it.
    result.text = new_text
    return result


def replace_substrings_in_text_with_spaces(document: D_text, substrings: Iterable[str]) -> D_text:
    """Return a new document in which the given substrings are blanked out.

    Each substring is replaced by a run of spaces of the same length via
    ``replace_substrings_in_text``.

    Args:
        document: The document to process.
        substrings: The substrings to replace with spaces.

    Returns:
        The new document as returned by ``replace_substrings_in_text``.
    """
    replacements = {substring: " " * len(substring) for substring in substrings}
    return replace_substrings_in_text(document, replacements=replacements)