import re

import numpy as np
import onnxruntime as ort
from tokenizers import Tokenizer

# Each entry is the mark (or dagesh-plus-mark combination) appended to a
# letter; index 0 is "no nikud" and index 1 is the matres-lectionis class.
NIKUD_CLASSES = [
    "",
    "<MAT_LECT>",
    "\u05bc",  # dagesh
    "\u05b0",  # sheva
    "\u05b1",  # hataf segol
    "\u05b2",  # hataf patah
    "\u05b3",  # hataf qamats
    "\u05b4",  # hiriq
    "\u05b5",  # tsere
    "\u05b6",  # segol
    "\u05b7",  # patah
    "\u05b8",  # qamats
    "\u05b9",  # holam
    "\u05ba",  # holam haser for vav
    "\u05bb",  # qubuts
    "\u05bc\u05b0",  # dagesh + each vowel, in the same order as above
    "\u05bc\u05b1",
    "\u05bc\u05b2",
    "\u05bc\u05b3",
    "\u05bc\u05b4",
    "\u05bc\u05b5",
    "\u05bc\u05b6",
    "\u05bc\u05b7",
    "\u05bc\u05b8",
    "\u05bc\u05b9",
    "\u05bc\u05ba",
    "\u05bc\u05bb",
    "\u05c7",  # qamats qatan
    "\u05bc\u05c7",
]
SHIN_CLASSES = ["\u05c1", "\u05c2"]  # shin dot, sin dot
MAT_LECT_TOKEN = "<MAT_LECT>"
MATRES_LETTERS = list("אוי")
ALEF_ORD = ord("א")
TAF_ORD = ord("ת")
STRESS_CHAR = "\u05ab"  # stress (hatama) mark
MOBILE_SHVA_CHAR = "\u05bd"  # mobile shva (shva na) mark
PREFIX_CHAR = "|"  # separates a prefix particle from the rest of the word


def is_hebrew_letter(char):
    return ALEF_ORD <= ord(char) <= TAF_ORD


def is_matres_letter(char):
    return char in MATRES_LETTERS


# Matches vowel points, dagesh, meteg, shin/sin dots, and qamats qatan.
nikud_pattern = re.compile(r"[\u05B0-\u05BD\u05C1\u05C2\u05C7]")


def remove_nikkud(text):
    return nikud_pattern.sub("", text)


class OnnxModel:
    """Runs an exported ONNX diacritization model and reassembles marked text."""

    def __init__(
        self, model_path, tokenizer_name="dicta-il/dictabert-large-char-menaked"
    ):
        self.tokenizer = Tokenizer.from_pretrained(tokenizer_name)
        self.session = ort.InferenceSession(model_path)
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def _create_inputs(self, sentences: list[str], padding: str):
        encodings = [self.tokenizer.encode(sentence) for sentence in sentences]

        # Only "longest" padding is supported; any other value leaves the
        # sequences unpadded.
        max_len = max(len(enc.ids) for enc in encodings) if padding == "longest" else 0

        input_ids = []
        attention_mask = []
        offset_mapping = []
        for encoding in encodings:
            ids = encoding.ids
            masks = [1] * len(ids)
            offsets = encoding.offsets

            if padding == "longest" and len(ids) < max_len:
                padding_length = max_len - len(ids)
                ids = ids + [self.tokenizer.token_to_id("[PAD]")] * padding_length
                masks = masks + [0] * padding_length
                offsets = offsets + [(0, 0)] * padding_length

            input_ids.append(ids)
            attention_mask.append(masks)
            offset_mapping.append(offsets)

        input_ids = np.array(input_ids, dtype=np.int64)
        return {
            "input_ids": input_ids,
            "attention_mask": np.array(attention_mask, dtype=np.int64),
            "token_type_ids": np.zeros_like(input_ids),
        }, offset_mapping

    def predict(self, sentences, mark_matres_lectionis=None, padding="longest"):
        # Strip any existing nikud so the model predicts from bare letters.
        sentences = [remove_nikkud(sentence) for sentence in sentences]
        inputs, offset_mapping = self._create_inputs(sentences, padding)
        # Only feed inputs the exported graph actually declares.
        inputs = {k: v for k, v in inputs.items() if k in self.input_names}

        outputs = self.session.run(self.output_names, inputs)

        nikud_logits = outputs[self.output_names.index("nikud_logits")]
        shin_logits = outputs[self.output_names.index("shin_logits")]
        additional_logits = outputs[self.output_names.index("additional_logits")]

        nikud_predictions = np.argmax(nikud_logits, axis=-1)
        shin_predictions = np.argmax(shin_logits, axis=-1)
        # The additional head emits three binary logits per character
        # (stress, mobile shva, prefix); a logit above 0 corresponds to a
        # sigmoid probability above 0.5.
        stress_predictions = (additional_logits[..., 0] > 0).astype(np.int32)
        mobile_shva_predictions = (additional_logits[..., 1] > 0).astype(np.int32)
        prefix_predictions = (additional_logits[..., 2] > 0).astype(np.int32)

        ret = []
        for sent_idx, (sentence, sent_offsets) in enumerate(
            zip(sentences, offset_mapping)
        ):
            output = []
            prev_index = 0
            for idx, offsets in enumerate(sent_offsets):
                # Copy through any characters the tokenizer skipped.
                if offsets[0] > prev_index:
                    output.append(sentence[prev_index : offsets[0]])
                # Special and padding tokens span zero characters; skip them.
                if offsets[1] - offsets[0] != 1:
                    continue

                char = sentence[offsets[0] : offsets[1]]
                prev_index = offsets[1]
                if not is_hebrew_letter(char):
                    output.append(char)
                    continue

                nikud = NIKUD_CLASSES[nikud_predictions[sent_idx][idx]]
                shin = (
                    "" if char != "ש" else SHIN_CLASSES[shin_predictions[sent_idx][idx]]
                )

                # A matres-lectionis prediction is only meaningful on א, ו, or
                # י: on other letters it is dropped, and on matres letters it
                # is either replaced by the caller's marker or left unmarked.
                if nikud == MAT_LECT_TOKEN:
                    if not is_matres_letter(char):
                        nikud = ""
                    elif mark_matres_lectionis is not None:
                        nikud = mark_matres_lectionis
                    else:
                        output.append(char)
                        continue

                stress = STRESS_CHAR if stress_predictions[sent_idx][idx] == 1 else ""
                mobile_shva = (
                    MOBILE_SHVA_CHAR
                    if mobile_shva_predictions[sent_idx][idx] == 1
                    else ""
                )
                prefix = PREFIX_CHAR if prefix_predictions[sent_idx][idx] == 1 else ""

                output.append(char + shin + nikud + stress + mobile_shva + prefix)
            output.append(sentence[prev_index:])
            ret.append("".join(output))

        return ret
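

# A minimal usage sketch. Assumptions: "model.onnx" is a hypothetical local
# export of the model named in the default tokenizer above; replace it with
# your actual model path.
if __name__ == "__main__":
    model = OnnxModel("model.onnx")
    # Prints the sentence with nikud plus the stress, mobile-shva, and
    # prefix marks produced by the additional head.
    print(model.predict(["שלום עולם"]))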