ScientificArgumentRecommender / rendering_utils.py
ArneBinder's picture
new demo setup with langchain retriever
2cc87ec verified
raw
history blame
12 kB
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Union
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from rendering_utils_displacy import EntityRenderer
# Module-level logger; inject_relation_data uses it to warn about entity text mismatches.
logger = logging.getLogger(__name__)
# Entity template, adjusted from rendering_utils_displacy.TPL_ENT.
# The <mark> element carries data-entity-id and data-slice-idx attributes, which
# inject_relation_data reads back to find the annotation (and, for multi-spans, the slice)
# that belongs to each rendered entity.
# Placeholders: {entity_id}, {slice_idx}, {bg}, {text}, {label}.
TPL_ENT_WITH_ID = """
<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
# JavaScript arrow function (kept as a string) that wires up hover/click behavior for the
# rendered entities:
# - mouseover on an element with class "entity": all elements sharing its data-entity-id are
#   recolored from their data-color-selected attribute, the entities listed in its
#   data-relation-tails / data-relation-heads attributes from data-color-tail / data-color-head,
#   and all remaining entities from data-color-other. The hovered id is also written to
#   '#hover_adu_id textarea' and an 'input' event is dispatched on that element.
# - mouseout: every entity is reset to its data-color-original background.
# - click on the '.entities' container: the current hover id is copied into
#   '#selected_adu_id textarea' and an 'input' event is dispatched there.
# A data-color-* value may be a plain color or a JSON object keyed by label.
# The data-* attributes read here are the ones written by inject_relation_data.
# Hover listeners are attached at most once per entity (guarded by data-has-listener).
# NOTE(review): the click listener on '.entities' has no such guard, so running the script
# repeatedly on the same container would register it again — confirm whether that can happen.
# NOTE(review): the two textarea ids presumably belong to the hosting UI; they are not
# defined in this file.
HIGHLIGHT_SPANS_JS = """
() => {
function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
var color = entity.getAttribute('data-color-' + colorAttributeKey);
// if color is a json string, parse it and use the value at colorDictKey
try {
const colors = JSON.parse(color);
color = colors[colorDictKey];
} catch (e) {}
if (color) {
entity.style.backgroundColor = color;
entity.style.color = '#000';
}
}
function highlightRelationArguments(entityId) {
const entities = document.querySelectorAll('.entity');
// reset all entities
entities.forEach(entity => {
const color = entity.getAttribute('data-color-original');
entity.style.backgroundColor = color;
entity.style.color = '';
});
if (entityId !== null) {
var visitedEntities = new Set();
// highlight selected entity
// get all elements with attribute data-entity-id==entityId
const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
selectedEntityParts.forEach(selectedEntityPart => {
const label = selectedEntityPart.getAttribute('data-label');
maybeSetColor(selectedEntityPart, 'selected', label);
visitedEntities.add(selectedEntityPart);
}); // <-- Corrected closing parenthesis here
// if there is at least one part, get the first one and ...
if (selectedEntityParts.length > 0) {
const selectedEntity = selectedEntityParts[0];
// ... highlight tails and ...
const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
relationTailsAndLabels.forEach(relationTail => {
const tailEntityId = relationTail['entity-id'];
const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
tailEntityParts.forEach(tailEntity => {
const label = relationTail['label'];
maybeSetColor(tailEntity, 'tail', label);
visitedEntities.add(tailEntity);
}); // <-- Corrected closing parenthesis here
}); // <-- Corrected closing parenthesis here
// .. highlight heads
const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
relationHeadsAndLabels.forEach(relationHead => {
const headEntityId = relationHead['entity-id'];
const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
headEntityParts.forEach(headEntity => {
const label = relationHead['label'];
maybeSetColor(headEntity, 'head', label);
visitedEntities.add(headEntity);
}); // <-- Corrected closing parenthesis here
}); // <-- Corrected closing parenthesis here
}
// highlight other entities
entities.forEach(entity => {
if (!visitedEntities.has(entity)) {
const label = entity.getAttribute('data-label');
maybeSetColor(entity, 'other', label);
}
});
}
}
function setHoverAduId(entityId) {
// get the textarea element that holds the reference adu id
let hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
// set the value of the input field
hoverAduIdDiv.value = entityId;
// trigger an input event to update the state
var event = new Event('input');
hoverAduIdDiv.dispatchEvent(event);
}
function setReferenceAduIdFromHover() {
// get the hover adu id
const hoverAduIdDiv = document.querySelector('#hover_adu_id textarea');
// get the value of the input field
const entityId = hoverAduIdDiv.value;
// get the textarea element that holds the reference adu id
let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
// set the value of the input field
referenceAduIdDiv.value = entityId;
// trigger an input event to update the state
var event = new Event('input');
referenceAduIdDiv.dispatchEvent(event);
}
const entities = document.querySelectorAll('.entity');
entities.forEach(entity => {
// make the cursor a pointer
entity.style.cursor = 'pointer';
const alreadyHasListener = entity.getAttribute('data-has-listener');
if (alreadyHasListener) {
return;
}
entity.addEventListener('mouseover', () => {
const entityId = entity.getAttribute('data-entity-id');
highlightRelationArguments(entityId);
setHoverAduId(entityId);
});
entity.addEventListener('mouseout', () => {
highlightRelationArguments(null);
});
entity.setAttribute('data-has-listener', 'true');
});
const entityContainer = document.querySelector('.entities');
if (entityContainer) {
entityContainer.addEventListener('click', () => {
setReferenceAduIdFromHover();
});
// make the cursor a pointer
// entityContainer.style.cursor = 'pointer';
}
}
"""
def render_pretty_table(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    **render_kwargs,
):
    """Render the binary relations as a scrollable HTML table.

    The table has the columns ``head``, ``tail`` and ``relation`` and contains
    exactly one row per entry of ``binary_relations``.

    Args:
        text: Not used by this renderer; accepted so that the signature matches
            ``render_displacy``.
        spans: Not used by this renderer.
        span_id2idx: Not used by this renderer.
        binary_relations: The relations to list. For each one, ``str(head)``,
            ``str(tail)`` and ``label`` are shown.
        **render_kwargs: Accepted and ignored.

    Returns:
        The HTML table wrapped in a ``<div>`` that limits the height and scrolls
        on overflow.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    # Iterate the relations once: concatenating the sequence with itself (as was
    # done before) rendered every relation twice.
    for relation in binary_relations:
        table.add_row([str(relation.head), str(relation.tail), relation.label])

    table_html = table.get_html_string(format=True)
    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>" + table_html + "</div>"
    )
def render_displacy(
    text: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
):
    """Render the text with highlighted entities as displaCy-style HTML.

    Every span referenced by ``span_id2idx`` becomes one entity (a
    ``LabeledMultiSpan`` becomes one entity per slice). Each entity is rendered
    with ``TPL_ENT_WITH_ID`` so that it carries its span id and slice index.

    Args:
        text: The text the span offsets refer to.
        spans: The labeled spans (or multi-spans) to render.
        span_id2idx: Maps a span id to the index of that span in ``spans``.
        binary_relations: Relations between the spans; only used when
            ``inject_relations`` is true.
        inject_relations: Whether to post-process the HTML with
            ``inject_relation_data``.
        colors_hover: Forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options for ``EntityRenderer``. The mapping is copied
            before the ``"template"`` entry is set, so the caller's object is
            never modified. ``None`` (the default) means no extra options.
        **render_kwargs: Accepted and ignored.

    Returns:
        The rendered HTML, wrapped in a height-limited, scrollable ``<div>``.

    Raises:
        ValueError: If a span is neither a ``LabeledSpan`` nor a
            ``LabeledMultiSpan``.
    """
    ents = []
    for entity_id, idx in span_id2idx.items():
        labeled_span = spans[idx]
        # Normalize both span kinds to a list of (start, end) slices so that
        # the entity dicts are built in one place.
        if isinstance(labeled_span, LabeledSpan):
            slices = [(labeled_span.start, labeled_span.end)]
        elif isinstance(labeled_span, LabeledMultiSpan):
            slices = labeled_span.slices
        else:
            raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")

        # pass the ID as a parameter to the entity. The id is required to fetch the
        # entity annotations on hover and to inject the relation data.
        for slice_idx, (start, end) in enumerate(slices):
            ents.append(
                {
                    "start": start,
                    "end": end,
                    "label": labeled_span.label,
                    "params": {"entity_id": entity_id, "slice_idx": slice_idx},
                }
            )

    spacy_doc = {
        "text": text,
        # the ents MUST be sorted by start and end
        "ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
        "title": None,
    }

    # A None default replaces the former mutable `{}` default; copy so the
    # caller's options are never modified.
    entity_options = {} if entity_options is None else entity_options.copy()
    # use the custom template with the entity ID
    entity_options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        html = inject_relation_data(
            html,
            spans=spans,
            span_id2idx=span_id2idx,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def inject_relation_data(
    html: str,
    spans: Union[Sequence[LabeledSpan], Sequence[LabeledMultiSpan]],
    span_id2idx: Dict[str, int],
    binary_relations: Sequence[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach color, label and relation data to every rendered entity element.

    For each element with class ``entity`` in ``html`` the following attributes
    are set, in this order:

    * ``data-color-original``: the background color taken from the inline style,
    * ``data-color-<key>``: one per entry of ``additional_colors`` (dict values
      are JSON-encoded),
    * ``data-label``: the label of the annotation behind the element,
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id", "label"}`` objects for the relations in which the
      annotation is the head / the tail.

    Args:
        html: HTML whose entity elements carry ``data-entity-id`` and
            ``data-slice-idx`` attributes.
        spans: The annotations that were rendered.
        span_id2idx: Maps a span id to the index of that span in ``spans``.
        binary_relations: The relations between the spans.
        additional_colors: Optional extra colors, keyed by the suffix of the
            ``data-color-`` attribute to create.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If an annotation is neither a ``LabeledSpan`` nor a
            ``LabeledMultiSpan``.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index the relations from both ends: who points at an annotation (its
    # heads) and whom it points at (its tails).
    heads_by_tail = defaultdict(list)
    tails_by_head = defaultdict(list)
    for rel in binary_relations:
        heads_by_tail[rel.tail].append((rel.head, rel.label))
        tails_by_head[rel.head].append((rel.tail, rel.label))

    # Reverse lookup: annotation -> span id.
    annotation2id = {spans[idx]: span_id for span_id, idx in span_id2idx.items()}

    def _dump_partners(partners):
        # Serialize the (annotation, label) pairs whose annotation has a known id.
        return json.dumps(
            [
                {"entity-id": annotation2id[partner], "label": rel_label}
                for partner, rel_label in partners
                if partner in annotation2id
            ]
        )

    for element in soup.find_all(class_="entity"):
        # The background color is whatever follows "background:" up to the next ";".
        after_background = element["style"].split("background:")[1]
        element["data-color-original"] = after_background.split(";")[0].strip()

        if additional_colors is not None:
            for key, color in additional_colors.items():
                if isinstance(color, dict):
                    element[f"data-color-{key}"] = json.dumps(color)
                else:
                    element[f"data-color-{key}"] = color

        annotation = spans[span_id2idx[element["data-entity-id"]]]

        # Sanity check: the rendered text should begin with the annotated text.
        if isinstance(annotation, LabeledSpan):
            expected_text = annotation.resolve()[1]
        elif isinstance(annotation, LabeledMultiSpan):
            slice_idx = int(element["data-slice-idx"])
            expected_text = annotation.resolve()[1][slice_idx]
        else:
            raise ValueError(f"Unsupported entity type: {type(annotation)}")
        expected_prefix = expected_text.replace("\n", "")
        # Only the start is compared, because the label is appended to the element text.
        if not element.text.startswith(expected_prefix):
            logger.warning(f"Entity text mismatch: {annotation} != {element.text}")

        element["data-label"] = annotation.label
        element["data-relation-tails"] = _dump_partners(tails_by_head.get(annotation, []))
        element["data-relation-heads"] = _dump_partners(heads_by_tail.get(annotation, []))

    return str(soup)