import os
import re
import sys
import tqdm
import ujson
from pydload import dload
import zipfile
from abc import ABC, abstractmethod
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
from urduhack import normalize as shahmukhi_normalize
from ..utils import *
LANG_WORD_REGEXES = {
lang_name: re.compile(f"[{SCRIPT_CODE_TO_UNICODE_CHARS_RANGE_STR[script_name]}]+")
for lang_name, script_name in LANG_CODE_TO_SCRIPT_CODE.items()
}
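# For illustration (assuming 'hi' maps to the Devanagari script range): each
# compiled pattern extracts runs of native-script characters, e.g.
#   LANG_WORD_REGEXES['hi'].findall('नमस्ते world') -> ['नमस्ते']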
MODEL_FILE = 'transformer/indicxlit.pt'
DICTS_FOLDER = 'word_prob_dicts'
CHARS_FOLDER = 'corpus-bin'
DICT_FILE_FORMAT = '%s_word_prob_dict.json'
LANG_LIST_FILE = '../lang_list.txt'
normalizer_factory = IndicNormalizerFactory()
class BaseEngineTransformer(ABC):
    @property
    @abstractmethod
    def all_supported_langs(self):
        pass

    @property
    @abstractmethod
    def tgt_langs(self):
        pass
def __init__(self, models_path, beam_width, rescore):
# added by yash
print("Initializing Multilingual model for transliteration")
if 'en' in self.tgt_langs:
lang_pairs_csv = ','.join([lang+"-en" for lang in self.all_supported_langs])
else:
lang_pairs_csv = ','.join(["en-"+lang for lang in self.all_supported_langs])
# initialize the model
from .custom_interactive import Transliterator
self.transliterator = Transliterator(
os.path.join(models_path, CHARS_FOLDER),
os.path.join(models_path, MODEL_FILE),
            lang_pairs_csv=lang_pairs_csv,
            lang_list_file=os.path.join(models_path, LANG_LIST_FILE),
            beam=beam_width, batch_size=32,
)
self.beam_width = beam_width
self._rescore = rescore
        if self._rescore:
            # Load the per-language word-probability dictionaries for the rescoring module
            dicts_folder = os.path.join(models_path, DICTS_FOLDER)
            self.word_prob_dicts = {}
            for la in tqdm.tqdm(self.tgt_langs, desc="Loading dicts into RAM"):
                with open(os.path.join(dicts_folder, DICT_FILE_FORMAT % la)) as f:
                    self.word_prob_dicts[la] = ujson.load(f)
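            # Each dictionary is assumed to map a word to its unigram probability
            # in the target language, e.g. (hypothetical values):
            #   {"नमस्ते": 1.2e-05, "नमस्कार": 8.0e-06}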
def download_models(self, models_path, download_url):
'''
Download models from bucket
'''
# added by yash
model_file_path = os.path.join(models_path, MODEL_FILE)
if not os.path.isfile(model_file_path):
print('Downloading Multilingual model for transliteration')
remote_url = download_url
downloaded_zip_path = os.path.join(models_path, 'model.zip')
dload(url=remote_url, save_to_path=downloaded_zip_path, max_time=None)
if not os.path.isfile(downloaded_zip_path):
                sys.exit(f'ERROR: Unable to download model from {remote_url} into {models_path}')
with zipfile.ZipFile(downloaded_zip_path, 'r') as zip_ref:
zip_ref.extractall(models_path)
if os.path.isfile(model_file_path):
os.remove(downloaded_zip_path)
else:
                sys.exit(f'ERROR: Unable to find models in {models_path} after download')
print("Models downloaded to:", models_path)
print("NOTE: When uninstalling this library, REMEMBER to delete the models manually")
return model_file_path
def download_dicts(self, models_path, download_url):
'''
        Download language model probability dictionaries
'''
dicts_folder = os.path.join(models_path, DICTS_FOLDER)
if not os.path.isdir(dicts_folder):
# added by yash
            print('Downloading language model probability dictionaries for rescoring module')
remote_url = download_url
downloaded_zip_path = os.path.join(models_path, 'dicts.zip')
dload(url=remote_url, save_to_path=downloaded_zip_path, max_time=None)
if not os.path.isfile(downloaded_zip_path):
                sys.exit(f'ERROR: Unable to download dictionaries from {remote_url} into {models_path}')
with zipfile.ZipFile(downloaded_zip_path, 'r') as zip_ref:
zip_ref.extractall(models_path)
if os.path.isdir(dicts_folder):
os.remove(downloaded_zip_path)
else:
                sys.exit(f'ERROR: Unable to find dictionaries in {models_path} after download')
return dicts_folder
def indic_normalize(self, words, lang_code):
if lang_code not in ['gom', 'ks', 'ur', 'mai', 'brx', 'mni']:
normalizer = normalizer_factory.get_normalizer(lang_code)
words = [ normalizer.normalize(word) for word in words ]
if lang_code in ['mai', 'brx' ]:
normalizer = normalizer_factory.get_normalizer('hi')
words = [ normalizer.normalize(word) for word in words ]
if lang_code in [ 'ur' ]:
words = [ shahmukhi_normalize(word) for word in words ]
if lang_code == 'gom':
normalizer = normalizer_factory.get_normalizer('kK')
words = [ normalizer.normalize(word) for word in words ]
# normalize and tokenize the words
# words = self.normalize(words)
        # manually mapping certain characters
# words = self.hard_normalizer(words)
return words
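    # Normalization sketch (exact mappings depend on indic-nlp-library and
    # urduhack): canonicalizes visually identical but differently encoded
    # sequences, e.g. a consonant plus combining nukta vs. its precomposed form,
    # so that downstream dictionary lookups are stable.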
def pre_process(self, words, src_lang, tgt_lang):
# TODO: Move normalize outside to efficiently perform at sentence-level
        if src_lang != 'en':
            words = self.indic_normalize(words, src_lang)
# convert the word into sentence which contains space separated chars
words = [' '.join(list(word.lower())) for word in words]
lang_code = tgt_lang if src_lang == 'en' else src_lang
# adding language token
words = ['__'+ lang_code +'__ ' + word for word in words]
return words
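    # Worked example (derived from the steps above): pre_process(['Namaste'], 'en', 'hi')
    # returns ['__hi__ n a m a s t e'] -- lowercased, characters space-separated,
    # prefixed with the target-language token.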
    def rescore(self, res_dict, result_dict, tgt_lang, alpha):
        # Blend the model's hypothesis scores with unigram word probabilities
        word_prob_dict = self.word_prob_dicts[tgt_lang]
candidate_word_prob_norm_dict = {}
candidate_word_result_norm_dict = {}
input_data = {}
for i in res_dict.keys():
input_data[res_dict[i]['S']] = []
for j in range(len(res_dict[i]['H'])):
input_data[res_dict[i]['S']].append( res_dict[i]['H'][j][0] )
        for src_word in input_data.keys():
            candidates = input_data[src_word]
total_score = 0
if src_word.lower() in result_dict.keys():
for candidate_word in candidates:
if candidate_word in result_dict[src_word.lower()].keys():
total_score += result_dict[src_word.lower()][candidate_word]
candidate_word_result_norm_dict[src_word.lower()] = {}
for candidate_word in candidates:
                    candidate_word_result_norm_dict[src_word.lower()][candidate_word] = (
                        result_dict[src_word.lower()][candidate_word] / total_score
                    )
candidates = [''.join(word.split(' ')) for word in candidates ]
total_prob = 0
for candidate_word in candidates:
if candidate_word in word_prob_dict.keys():
total_prob += word_prob_dict[candidate_word]
candidate_word_prob_norm_dict[src_word.lower()] = {}
for candidate_word in candidates:
if candidate_word in word_prob_dict.keys():
                    candidate_word_prob_norm_dict[src_word.lower()][candidate_word] = (
                        word_prob_dict[candidate_word] / total_prob
                    )
output_data = {}
for src_word in input_data.keys():
temp_candidates_tuple_list = []
candidates = input_data[src_word]
candidates = [ ''.join(word.split(' ')) for word in candidates]
            for candidate_word in candidates:
                if candidate_word in word_prob_dict.keys():
                    combined_score = (
                        alpha * candidate_word_result_norm_dict[src_word.lower()][' '.join(list(candidate_word))]
                        + (1 - alpha) * candidate_word_prob_norm_dict[src_word.lower()][candidate_word]
                    )
                    temp_candidates_tuple_list.append((candidate_word, combined_score))
                else:
                    # Candidates unseen in the unigram dictionary sink to the bottom
                    temp_candidates_tuple_list.append((candidate_word, 0))
            temp_candidates_tuple_list.sort(key=lambda x: x[1], reverse=True)
temp_candidates_list = []
            for candidate_tuple in temp_candidates_tuple_list:
                temp_candidates_list.append(' '.join(list(candidate_tuple[0])))
output_data[src_word] = temp_candidates_list
return output_data
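    # Scoring sketch: each candidate's final score is
    #   alpha * P_model(candidate | source) + (1 - alpha) * P_unigram(candidate),
    # with both terms normalized over the n-best list for that source word;
    # candidates absent from the unigram dictionary get score 0.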
def post_process(self, translation_str, tgt_lang):
lines = translation_str.split('\n')
list_s = [line for line in lines if 'S-' in line]
# list_t = [line for line in lines if 'T-' in line]
list_h = [line for line in lines if 'H-' in line]
# list_d = [line for line in lines if 'D-' in line]
        list_s.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
        # list_t.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
        list_h.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
        # list_d.sort(key=lambda x: int(x.split('\t')[0].split('-')[1]))
res_dict = {}
for s in list_s:
s_id = int(s.split('\t')[0].split('-')[1])
res_dict[s_id] = { 'S' : s.split('\t')[1] }
# for t in list_t:
# t_id = int(t.split('\t')[0].split('-')[1])
# if s_id == t_id:
# res_dict[s_id]['T'] = t.split('\t')[1]
res_dict[s_id]['H'] = []
# res_dict[s_id]['D'] = []
for h in list_h:
h_id = int(h.split('\t')[0].split('-')[1])
if s_id == h_id:
                res_dict[s_id]['H'].append((h.split('\t')[2], pow(2, float(h.split('\t')[1]))))
# for d in list_d:
# d_id = int(d.split('\t')[0].split('-')[1])
# if s_id == d_id:
# res_dict[s_id]['D'].append( ( d.split('\t')[2], pow(2,float(d.split('\t')[1])) ) )
        for r in res_dict.keys():
            res_dict[r]['H'].sort(key=lambda x: float(x[1]), reverse=True)
            # res_dict[r]['D'].sort(key=lambda x: float(x[1]), reverse=True)
# for rescoring
result_dict = {}
for i in res_dict.keys():
result_dict[res_dict[i]['S']] = {}
for j in range(len(res_dict[i]['H'])):
result_dict[res_dict[i]['S']][res_dict[i]['H'][j][0]] = res_dict[i]['H'][j][1]
transliterated_word_list = []
if self._rescore:
            output_data = self.rescore(res_dict, result_dict, tgt_lang, alpha=0.9)
            for src_word in output_data.keys():
                for j in range(len(output_data[src_word])):
                    transliterated_word_list.append(output_data[src_word][j])
else:
for i in res_dict.keys():
# transliterated_word_list.append( res_dict[i]['S'] + ' : ' + res_dict[i]['H'][0][0] )
for j in range(len(res_dict[i]['H'])):
transliterated_word_list.append( res_dict[i]['H'][j][0] )
# remove extra spaces
# transliterated_word_list = [''.join(pair.split(':')[0].split(' ')[1:]) + ' : ' + ''.join(pair.split(':')[1].split(' ')) for pair in transliterated_word_list]
transliterated_word_list = [''.join(word.split(' ')) for word in transliterated_word_list]
return transliterated_word_list
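    # Input format sketch: `translation_str` is fairseq-generate-style output,
    # e.g. (hypothetical lines, <tab> denotes a tab character):
    #   S-0 <tab> __hi__ n a m a s t e
    #   H-0 <tab> -0.23 <tab> न म स ् त े
    # Hypothesis log-scores are exponentiated above via pow(2, score) into
    # relative probabilities, and space-separated characters re-joined into words.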
def _transliterate_word(self, text, src_lang, tgt_lang, topk=4, nativize_punctuations=True, nativize_numerals=False):
if not text:
return text
text = text.lower().strip()
if src_lang != 'en':
# Our model does not transliterate native punctuations or numerals
# So process them first so that they are not considered for transliteration
text = text.translate(INDIC_TO_LATIN_PUNCT_TRANSLATOR)
text = text.translate(INDIC_TO_STANDARD_NUMERALS_TRANSLATOR)
else:
# Transliterate punctuations & numerals if tgt_lang is Indic
if nativize_punctuations:
if tgt_lang in RTL_LANG_CODES:
text = text.translate(LATIN_TO_PERSOARABIC_PUNC_TRANSLATOR)
text = nativize_latin_fullstop(text, tgt_lang)
if nativize_numerals:
text = text.translate(LATIN_TO_NATIVE_NUMERALS_TRANSLATORS[tgt_lang])
matches = LANG_WORD_REGEXES[src_lang].findall(text)
if not matches:
return [text]
src_word = matches[-1]
transliteration_list = self.batch_transliterate_words([src_word], src_lang, tgt_lang, topk=topk)[0]
        if tgt_lang not in ('en', 'sa'):
            # If users want to avoid yuktAkshara, this is facilitated by allowing them to type subwords in order to construct a word
# For example, "ଜନ୍‍ସନ୍‍ଙ୍କୁ" can be written by "ଜନ୍‍" + "ସନ୍‍" + "କୁ"
# Not enabled for Sanskrit, as sandhi compounds are generally written word-by-word
for i in range(len(transliteration_list)):
transliteration_list[i] = hardfix_wordfinal_virama(transliteration_list[i])
if src_word == text:
return transliteration_list
return [
rreplace(text, src_word, tgt_word)
for tgt_word in transliteration_list
]
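    # Usage sketch (assuming an initialized engine subclass):
    #   engine._transliterate_word('namaste', 'en', 'hi', topk=4)
    # returns up to `topk` candidate renderings of the last native-script word
    # in `text`, each substituted back into the surrounding string.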
def batch_transliterate_words(self, words, src_lang, tgt_lang, topk=4):
        preprocessed_words = self.pre_process(words, src_lang, tgt_lang)
        translation_str = self.transliterator.translate(preprocessed_words, nbest=topk)
# FIXME: Handle properly in `post_process()` to return results for all words
transliteration_list = self.post_process(translation_str, tgt_lang)
# Lang-specific patches. TODO: Move to indic-nlp-library
if tgt_lang == 'mr':
for i in range(len(transliteration_list)):
transliteration_list[i] = transliteration_list[i].replace("अॅ", 'ॲ')
if tgt_lang == 'or':
for i in range(len(transliteration_list)):
transliteration_list[i] = fix_odia_confusing_ambiguous_yuktakshara(transliteration_list[i])
if tgt_lang == 'sa':
for i in range(len(transliteration_list)):
transliteration_list[i] = explicit_devanagari_wordfinal_schwa_delete(words[0], transliteration_list[i])
# Retain only unique, preserving order
transliteration_list = list(dict.fromkeys(transliteration_list))
return [transliteration_list]
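    # Usage sketch: batch_transliterate_words(['namaste'], 'en', 'hi', topk=4)
    # returns a single deduplicated candidate list; per the FIXME above, results
    # are not yet split out per input word.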
def _transliterate_sentence(self, text, src_lang, tgt_lang, nativize_punctuations=True, nativize_numerals=False):
# TODO: Minimize code redundancy with `_transliterate_word()`
if not text:
return text
text = text.lower().strip()
if src_lang != 'en':
# Our model does not transliterate native punctuations or numerals
# So process them first so that they are not considered for transliteration
text = text.translate(INDIC_TO_LATIN_PUNCT_TRANSLATOR)
text = text.translate(INDIC_TO_STANDARD_NUMERALS_TRANSLATOR)
else:
# Transliterate punctuations & numerals if tgt_lang is Indic
if nativize_punctuations:
if tgt_lang in RTL_LANG_CODES:
text = text.translate(LATIN_TO_PERSOARABIC_PUNC_TRANSLATOR)
text = nativize_latin_fullstop(text, tgt_lang)
if nativize_numerals:
text = text.translate(LATIN_TO_NATIVE_NUMERALS_TRANSLATORS[tgt_lang])
matches = LANG_WORD_REGEXES[src_lang].findall(text)
if not matches:
return text
out_str = text
for match in matches:
result = self.batch_transliterate_words([match], src_lang, tgt_lang)[0][0]
            out_str = re.sub(re.escape(match), result, out_str, count=1)
return out_str
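    # Usage sketch (assuming 'en' maps to a Latin-script range in
    # LANG_WORD_REGEXES): each matched word is transliterated independently and
    # substituted back in place, e.g.
    #   engine._transliterate_sentence('hello duniya', 'en', 'hi')
    # replaces every Latin word with its top-ranked transliteration.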