File size: 11,981 Bytes
bc6f57a 1681237 bc6f57a 2cc87ec bc6f57a efae5be bfcba2d bc6f57a 1681237 a8df5fb efae5be a8df5fb efae5be 2cc87ec efae5be 2cc87ec efae5be 2cc87ec efae5be 2cc87ec efae5be bc6f57a 2cc87ec efae5be bc6f57a 2cc87ec bc6f57a 5003662 2cc87ec bc6f57a 5003662 bc6f57a efae5be 2cc87ec efae5be bc6f57a 2cc87ec efae5be bc6f57a a8df5fb 5003662 bc6f57a 2cc87ec bc6f57a 2cc87ec bc6f57a 2cc87ec bc6f57a a8df5fb bc6f57a 2cc87ec efae5be a8df5fb efae5be a8df5fb 1681237 efae5be bc6f57a 2cc87ec bc6f57a 2cc87ec bc6f57a 2cc87ec bc6f57a 2cc87ec bc6f57a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Union
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from rendering_utils_displacy import EntityRenderer
# Module-level logger; inject_relation_data uses it to warn when the rendered entity text
# does not match the annotation text.
logger = logging.getLogger(__name__)
# Entity template adjusted from rendering_utils_displacy.TPL_ENT. Every rendered <mark> carries:
#   - data-entity-id: the entity id (a key of span_id2idx in render_displacy); HIGHLIGHT_SPANS_JS
#     and inject_relation_data use it to find all rendered parts of an entity,
#   - data-slice-idx: the slice index (always 0 for a LabeledSpan, the index into
#     LabeledMultiSpan.slices otherwise).
# The {entity_id}/{slice_idx} placeholders are presumably filled from each entity's "params"
# dict by EntityRenderer (TODO confirm in rendering_utils_displacy).
# NOTE: inject_relation_data recovers the original color by parsing the "background: ...;"
# part of the style attribute, so keep that format when editing this template.
TPL_ENT_WITH_ID = """
<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
# JavaScript hook (an arrow function; presumably passed to the UI framework as a `js=`
# callback - TODO confirm against the caller) that adds interactivity to the entities
# rendered by render_displacy:
#   - mouseover on an entity colors all of its parts with the "selected" color, its relation
#     tails with the "tail" color, its relation heads with the "head" color and every other
#     entity with the "other" color. It then writes the entity id into `#hover_adu_id textarea`
#     and dispatches an `input` event;
#   - mouseout restores the original colors (data-color-original);
#   - a click inside the `.entities` container copies the hovered id into
#     `#selected_adu_id textarea` and dispatches an `input` event.
# It reads the data-* attributes written by inject_relation_data. The data-color-* values are
# either plain colors or JSON dicts keyed by label. If the relation attributes are missing
# (HTML rendered with inject_relations=False), they are treated as empty lists. Listeners are
# marked with data-has-listener so that running the hook again does not register duplicates.
HIGHLIGHT_SPANS_JS = """
() => {
    function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
        var color = entity.getAttribute('data-color-' + colorAttributeKey);
        // if color is a json string, parse it and use the value at colorDictKey
        try {
            const colors = JSON.parse(color);
            color = colors[colorDictKey];
        } catch (e) {}
        if (color) {
            entity.style.backgroundColor = color;
            entity.style.color = '#000';
        }
    }

    function highlightRelationArguments(entityId) {
        const entities = document.querySelectorAll('.entity');
        // reset all entities
        entities.forEach(entity => {
            const color = entity.getAttribute('data-color-original');
            entity.style.backgroundColor = color;
            entity.style.color = '';
        });

        if (entityId !== null) {
            var visitedEntities = new Set();
            // highlight the selected entity: all elements with data-entity-id == entityId
            const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
            selectedEntityParts.forEach(selectedEntityPart => {
                const label = selectedEntityPart.getAttribute('data-label');
                maybeSetColor(selectedEntityPart, 'selected', label);
                visitedEntities.add(selectedEntityPart);
            });
            // if there is at least one part, get the first one and ...
            if (selectedEntityParts.length > 0) {
                const selectedEntity = selectedEntityParts[0];
                // ... highlight tails and ...
                // (fall back to an empty list if no relation data was injected)
                const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails') || '[]');
                relationTailsAndLabels.forEach(relationTail => {
                    const tailEntityId = relationTail['entity-id'];
                    const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
                    tailEntityParts.forEach(tailEntity => {
                        const label = relationTail['label'];
                        maybeSetColor(tailEntity, 'tail', label);
                        visitedEntities.add(tailEntity);
                    });
                });
                // ... highlight heads
                const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads') || '[]');
                relationHeadsAndLabels.forEach(relationHead => {
                    const headEntityId = relationHead['entity-id'];
                    const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
                    headEntityParts.forEach(headEntity => {
                        const label = relationHead['label'];
                        maybeSetColor(headEntity, 'head', label);
                        visitedEntities.add(headEntity);
                    });
                });
            }
            // highlight other entities
            entities.forEach(entity => {
                if (!visitedEntities.has(entity)) {
                    const label = entity.getAttribute('data-label');
                    maybeSetColor(entity, 'other', label);
                }
            });
        }
    }

    function setHoverAduId(entityId) {
        // get the textarea element that holds the hover adu id
        let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
        // set the value of the input field
        hoverAduIdDiv.value = entityId;
        // trigger an input event to update the state
        var event = new Event('input');
        hoverAduIdDiv.dispatchEvent(event);
    }

    function setReferenceAduIdFromHover() {
        // get the hover adu id
        const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
        // get the value of the input field
        const entityId = hoverAduIdDiv.value;
        // get the textarea element that holds the reference adu id
        let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
        // set the value of the input field
        referenceAduIdDiv.value = entityId;
        // trigger an input event to update the state
        var event = new Event('input');
        referenceAduIdDiv.dispatchEvent(event);
    }

    const entities = document.querySelectorAll('.entity');
    entities.forEach(entity => {
        // make the cursor a pointer
        entity.style.cursor = 'pointer';
        // do not register the listeners twice if the hook runs again
        const alreadyHasListener = entity.getAttribute('data-has-listener');
        if (alreadyHasListener) {
            return;
        }
        entity.addEventListener('mouseover', () => {
            const entityId = entity.getAttribute('data-entity-id');
            highlightRelationArguments(entityId);
            setHoverAduId(entityId);
        });
        entity.addEventListener('mouseout', () => {
            highlightRelationArguments(null);
        });
        entity.setAttribute('data-has-listener', 'true');
    });

    const entityContainer = document.querySelector('.entities');
    // guard like the entity listeners above so that the click handler is registered only once
    if (entityContainer && !entityContainer.getAttribute('data-has-listener')) {
        entityContainer.addEventListener('click', () => {
            setReferenceAduIdFromHover();
        });
        entityContainer.setAttribute('data-has-listener', 'true');
    }
}
"""
def render_pretty_table(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    **render_kwargs,
) -> str:
    """Render the binary relations as a scrollable HTML table.

    Each relation becomes exactly one row containing the string representations of its
    head and tail and its label.

    Args:
        text: Unused. It is presumably accepted so that this function is interchangeable
            with render_displacy, which has the same leading parameters.
        spans: Unused (see ``text``).
        span_id2idx: Unused (see ``text``).
        binary_relations: The relations to list, one row per relation.
        **render_kwargs: Unused.

    Returns:
        The HTML table wrapped in a ``div`` with a maximum height of 360px that scrolls
        on overflow.
    """
    # imported lazily so that prettytable is only required when this renderer is used
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for relation in binary_relations:
        table.add_row([str(relation.head), str(relation.tail), relation.label])
    html = table.get_html_string(format=True)
    return "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
def render_displacy(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the spans as displaCy-style entity HTML, optionally with relation data.

    Every span referenced by ``span_id2idx`` is rendered with ``TPL_ENT_WITH_ID``. A
    ``LabeledMultiSpan`` is rendered as one entity per slice, and all of these entities
    share the same entity id.

    Args:
        text: The text the spans refer to.
        spans: The annotations to render.
        span_id2idx: Maps entity ids to indices into ``spans``. Only the spans listed here
            are rendered.
        binary_relations: Relations attached to the entities if ``inject_relations`` is True.
        inject_relations: If True, post-process the HTML with ``inject_relation_data`` so
            that it carries the data attributes read by ``HIGHLIGHT_SPANS_JS``.
        colors_hover: Passed to ``inject_relation_data`` as ``additional_colors``.
        entity_options: Options for ``EntityRenderer``. The caller's dict is not modified,
            and its ``"template"`` entry is always overridden with ``TPL_ENT_WITH_ID``.
            ``None`` means no options.
        **render_kwargs: Unused.

    Returns:
        The rendered HTML wrapped in a ``div`` with a maximum height of 360px that scrolls
        on overflow.

    Raises:
        ValueError: If a span is neither a ``LabeledSpan`` nor a ``LabeledMultiSpan``.
    """
    ents = []
    for entity_id, idx in span_id2idx.items():
        labeled_span = spans[idx]
        # A LabeledSpan is treated as a single slice; a LabeledMultiSpan is rendered per slice.
        if isinstance(labeled_span, LabeledSpan):
            slices = [(labeled_span.start, labeled_span.end)]
        elif isinstance(labeled_span, LabeledMultiSpan):
            slices = labeled_span.slices
        else:
            raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
        for slice_idx, (start, end) in enumerate(slices):
            ents.append(
                {
                    "start": start,
                    "end": end,
                    "label": labeled_span.label,
                    # The ID is required to fetch the entity annotations on hover and to
                    # inject the relation data.
                    "params": {"entity_id": entity_id, "slice_idx": slice_idx},
                }
            )

    spacy_doc = {
        "text": text,
        # the ents MUST be sorted by start and end
        "ents": sorted(ents, key=lambda ent: (ent["start"], ent["end"])),
        "title": None,
    }

    # copy so that the caller's options are never modified
    options = {} if entity_options is None else entity_options.copy()
    # use the custom template with the entity ID
    options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        html = inject_relation_data(
            html,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def _relation_arguments_to_json(arguments_and_labels, annotation2id):
    """Serialize ``(annotation, label)`` pairs as a JSON list of ``{"entity-id", "label"}`` dicts.

    Pairs whose annotation has no id in ``annotation2id`` are skipped.
    """
    return json.dumps(
        [
            {"entity-id": annotation2id[annotation], "label": label}
            for annotation, label in arguments_and_labels
            if annotation in annotation2id
        ]
    )


def inject_relation_data(
    html: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Add the data attributes used by the hover JavaScript to the rendered entities.

    For every element with class ``entity`` this sets the following attributes:
      - ``data-color-original``: the color from the ``background: ...;`` inline style,
      - ``data-color-<key>`` for each entry of ``additional_colors``: dict values are stored
        as JSON strings, other values unchanged,
      - ``data-label``: the label of the annotation,
      - ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
        ``{"entity-id": ..., "label": ...}`` for the relations in which the annotation is the
        head / tail. Arguments that do not appear in ``span_id2idx`` are left out.

    Args:
        html: HTML whose entity elements carry ``data-entity-id`` (a key of ``span_id2idx``),
            ``data-slice-idx`` and a ``background:`` inline style, as rendered with
            ``TPL_ENT_WITH_ID``.
        spans: The annotations referenced by ``span_id2idx``. They must be hashable.
        span_id2idx: Maps entity ids to indices into ``spans``.
        binary_relations: Relations whose heads/tails are matched against ``spans``.
        additional_colors: Optional extra colors, stored on every entity as
            ``data-color-<key>``.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If an annotation is neither a ``LabeledSpan`` nor a ``LabeledMultiSpan``.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    # map relation arguments (annotations) back to their entity ids
    annotation2id = {spans[span_idx]: span_id for span_id, span_idx in span_id2idx.items()}

    # the extra color attributes are identical for every entity, so serialize them only once
    color_attributes = {}
    if additional_colors is not None:
        for key, color in additional_colors.items():
            color_attributes[f"data-color-{key}"] = (
                json.dumps(color) if isinstance(color, dict) else color
            )

    for entity in soup.find_all(class_="entity"):
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attribute, value in color_attributes.items():
            entity[attribute] = value

        entity_annotation = spans[span_id2idx[entity["data-entity-id"]]]

        # sanity check: the rendered text should match the annotated text
        if isinstance(entity_annotation, LabeledSpan):
            annotation_text = entity_annotation.resolve()[1]
        elif isinstance(entity_annotation, LabeledMultiSpan):
            slice_idx = int(entity["data-slice-idx"])
            annotation_text = entity_annotation.resolve()[1][slice_idx]
        else:
            raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
        # newlines are presumably rendered as line breaks and thus missing from the element
        # text (TODO confirm against EntityRenderer)
        annotation_text_without_newline = annotation_text.replace("\n", "")
        # Just check the start, because the rendered text has the label attached to the end
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning("Entity text mismatch: %s != %s", entity_annotation, entity.text)

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = _relation_arguments_to_json(
            entity2tails.get(entity_annotation, []), annotation2id
        )
        entity["data-relation-heads"] = _relation_arguments_to_json(
            entity2heads.get(entity_annotation, []), annotation2id
        )

    return str(soup)
|