# docling / ner_analyzer.py
# thlinhares's picture
# Create ner_analyzer.py
# 8166d49 verified
# raw
# history blame
# 2.76 kB
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from typing import List, Tuple
import logging
from .base_analyzer import BaseAnalyzer
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class NERAnalyzer(BaseAnalyzer):
    """Extract person and organization names from text with a NER model.

    The analyzer loads the ``dominguesm/ner-legal-bert-base-cased-ptbr``
    token-classification model and its tokenizer, tags every token of the
    input, and rebuilds the names of the entities labelled ``PESSOA`` or
    ``ORGANIZACAO``.
    """

    # Entity kinds (the label part after "B-" / "I-") reported as representatives.
    _REPRESENTATIVE_KINDS = ("PESSOA", "ORGANIZACAO")

    def __init__(self):
        """Load the NER model and its tokenizer via ``from_pretrained``."""
        self.model_name = "dominguesm/ner-legal-bert-base-cased-ptbr"
        logger.info(f"Carregando o modelo NER: {self.model_name}")
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Modelo NER e tokenizador carregados com sucesso")

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Tag ``text`` and return the tokens that received an entity label.

        Args:
            text: Raw text to analyze. Input is truncated to 512 tokens, so
                anything beyond that limit is not analyzed.

        Returns:
            ``(token, label)`` pairs, in input order, for every token whose
            predicted label is not ``"O"``. Tokens are the tokenizer's own
            pieces (sub-word pieces keep their ``"##"`` prefix). Special
            tokens added by the tokenizer are never returned.
        """
        logger.debug("Iniciando extração de entidades com NER")
        inputs = self.tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
        tokens = inputs.tokens()

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # tolist() yields plain Python ints (matching the id2label keys) and
        # works regardless of the device the tensor lives on.
        label_ids = torch.argmax(logits, dim=2)[0].tolist()

        id2label = self.model.config.id2label
        # Markers such as [CLS] / [SEP] are not part of the text; skip them
        # even if the model happens to give them an entity label.
        special_tokens = set(self.tokenizer.all_special_tokens)

        entities = []
        for token, label_id in zip(tokens, label_ids):
            if token in special_tokens:
                continue
            entity_label = id2label[label_id]
            if entity_label != "O":
                entities.append((token, entity_label))
        return entities

    def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
        """Merge tagged tokens into person and organization names.

        Consecutive tokens of the same kind are merged into one name.
        A ``B-`` label opens a new name, so two adjacent entities of the same
        kind are kept apart. Sub-word pieces (``"##"`` prefix) and
        punctuation-only tokens are glued to the preceding word; all other
        tokens are separated by a single space. Any label other than
        ``B-/I-PESSOA`` or ``B-/I-ORGANIZACAO`` closes the name in progress.

        Note: the original spacing cannot be recovered from tokens alone, so
        names joined by punctuation (e.g. hyphenated names) may come back
        with a space after the punctuation mark.

        Args:
            entities: ``(token, label)`` pairs as produced by
                ``extract_entities``.

        Returns:
            The reconstructed names, in order of appearance.
        """
        representatives: List[str] = []
        words: List[str] = []  # words of the name currently being built
        current_kind = None  # "PESSOA", "ORGANIZACAO", or None when idle

        for token, label in entities:
            prefix, _, kind = label.partition("-")
            tracked = prefix in ("B", "I") and kind in self._REPRESENTATIVE_KINDS

            is_subword = token.startswith("##")
            piece = token[2:] if is_subword else token
            # Sub-word pieces and bare punctuation belong to the previous word.
            attaches = is_subword or not any(ch.isalnum() for ch in piece)

            starts_new = (
                not tracked
                or kind != current_kind
                or (prefix == "B" and not attaches)
            )
            if starts_new and words:
                representatives.append(" ".join(words))
                words = []

            if not tracked:
                current_kind = None
                continue

            current_kind = kind
            if attaches and words:
                words[-1] += piece
            elif piece:
                words.append(piece)

        if words:
            representatives.append(" ".join(words))
        return representatives

    def analyze(self, text: str) -> List[str]:
        """Return the person and organization names found in ``text``.

        Runs ``extract_entities`` and feeds its result to
        ``extract_representatives``.
        """
        entities = self.extract_entities(text)
        return self.extract_representatives(entities)

    def format_output(self, representatives: List[str]) -> str:
        """Render the names as a report: a header plus one ``- name`` line each.

        Args:
            representatives: Names as returned by ``analyze``.

        Returns:
            The report text; only the two header lines when the list is empty.
        """
        header = "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n" + "REPRESENTANTES IDENTIFICADOS:\n"
        return header + "".join(f"- {rep}\n" for rep in representatives)