Spaces:
Running
Running
from transformers import AutoModelForTokenClassification, AutoTokenizer | |
import torch | |
from typing import List, Tuple | |
import logging | |
from .base_analyzer import BaseAnalyzer | |
logger = logging.getLogger(__name__) | |
class NERAnalyzer(BaseAnalyzer):
    """Extract person and organization names from text with a token-classification model.

    The model and tokenizer named by ``model_name`` are loaded once in
    ``__init__``; ``analyze`` runs the full pipeline on a text and
    ``format_output`` renders the result as a plain-text report.
    """

    # Entity kinds (label without its B-/I- prefix) reported as representatives.
    _REPRESENTATIVE_KINDS = ("PESSOA", "ORGANIZACAO")

    def __init__(self):
        """Load the pretrained NER model and its tokenizer."""
        self.model_name = "dominguesm/ner-legal-bert-base-cased-ptbr"
        logger.info("Carregando o modelo NER: %s", self.model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Modelo NER e tokenizador carregados com sucesso")

    def extract_entities(self, text: str) -> List[Tuple[str, str]]:
        """Return ``(token, label)`` pairs for every token not labelled ``"O"``.

        Tokens are the tokenizer's own pieces, so subword continuations keep
        their ``##`` prefix. The input is truncated to 512 tokens; anything
        beyond that is not analysed.

        Args:
            text: Raw text to tag.

        Returns:
            Pairs in reading order. Special tokens are never included.
        """
        logger.debug("Iniciando extração de entidades com NER")
        inputs = self.tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
        tokens = inputs.tokens()
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # .tolist() yields plain Python ints, matching the int keys of id2label.
        label_ids = torch.argmax(logits, dim=2)[0].tolist()

        id2label = self.model.config.id2label
        # Markers such as [CLS]/[SEP] are added by the tokenizer and are not
        # part of the text, so they must never be reported as entities.
        special_tokens = set(self.tokenizer.all_special_tokens)

        entities = []
        for token, label_id in zip(tokens, label_ids):
            if token in special_tokens:
                continue
            entity_label = id2label[label_id]
            if entity_label != "O":
                entities.append((token, entity_label))
        return entities

    def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]:
        """Merge tagged tokens into person and organization names.

        A ``##`` subword piece is glued to the previous token; any other
        continuation token is joined with a single space. A new mention starts
        on a ``B-`` tag, on a change of entity kind, or after a token of any
        other kind.

        Because the input carries no ``"O"`` tokens, two mentions of the same
        kind can only be told apart when the second one begins with a ``B-``
        tag.

        Args:
            entities: ``(token, label)`` pairs as returned by
                ``extract_entities``.

        Returns:
            Names in the order their mentions end; duplicates are kept.
        """
        representatives = []
        current_kind = None  # kind of the mention being built, or None
        current_text = ""

        for token, label in entities:
            prefix, _, kind = label.partition("-")
            if kind not in self._REPRESENTATIVE_KINDS:
                kind = None
            is_subword = token.startswith("##")
            piece = token[2:] if is_subword else token

            # A subword piece always belongs to the word before it, even if the
            # model tagged it B-; otherwise only a non-B tag continues a mention.
            continues = (
                kind is not None
                and kind == current_kind
                and (is_subword or prefix != "B")
            )
            if continues:
                if is_subword or not current_text:
                    current_text += piece
                else:
                    current_text += " " + piece
                continue

            if current_text:
                representatives.append(current_text)
            current_kind = kind
            current_text = piece if kind is not None else ""

        if current_text:
            representatives.append(current_text)
        return representatives

    def analyze(self, text: str) -> List[str]:
        """Run entity extraction on *text* and return the representative names."""
        entities = self.extract_entities(text)
        return self.extract_representatives(entities)

    def format_output(self, representatives: List[str]) -> str:
        """Render *representatives* as a report with one ``- name`` line each."""
        header = "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n" "REPRESENTANTES IDENTIFICADOS:\n"
        return header + "".join(f"- {rep}\n" for rep in representatives)