from transformers import AutoModelForTokenClassification, AutoTokenizer import torch from typing import List, Tuple import logging from .base_analyzer import BaseAnalyzer logger = logging.getLogger(__name__) class NERAnalyzer(BaseAnalyzer): def __init__(self): self.model_name = "dominguesm/ner-legal-bert-base-cased-ptbr" logger.info(f"Carregando o modelo NER: {self.model_name}") self.model = AutoModelForTokenClassification.from_pretrained(self.model_name) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) logger.info("Modelo NER e tokenizador carregados com sucesso") def extract_entities(self, text: str) -> List[Tuple[str, str]]: logger.debug("Iniciando extração de entidades com NER") inputs = self.tokenizer(text, max_length=512, truncation=True, return_tensors="pt") tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert ids back to tokens with torch.no_grad(): outputs = self.model(**inputs).logits predictions = torch.argmax(outputs, dim=2) entities = [] for token, prediction in zip(tokens, predictions[0].numpy()): entity_label = self.model.config.id2label[prediction] if entity_label != "O": # Ignore non-entity tokens entities.append((token, entity_label)) logger.info(f"Entidades extraídas: {entities}") return entities def extract_representatives(self, entities: List[Tuple[str, str]]) -> List[str]: representatives = [] current_entity = [] current_label = None for token, label in entities: # Se o token começa com "##", é uma continuação da entidade anterior if token.startswith("##"): if current_entity: # Se já houver uma entidade em construção, continue current_entity.append(token[2:]) # Remove "##" e adiciona à entidade elif label.startswith("B-"): # Novo começo de entidade # Se há uma entidade em curso, adiciona à lista if current_entity: representatives.append(" ".join(current_entity)) # Reinicia a entidade current_entity = [token] current_label = label else: # Continuar com "I-" ou "O" if current_entity: current_entity.append(token) # Adicionar a última entidade if current_entity: representatives.append(" ".join(current_entity)) logger.info(f"Representantes extraídos: {representatives}") return representatives def analyze(self, text: str) -> List[str]: entities = self.extract_entities(text) return self.extract_representatives(entities) def format_output(self, representatives: List[str]) -> str: output = "ANÁLISE DO CONTRATO SOCIAL (NER)\n\n" output += "REPRESENTANTES IDENTIFICADOS:\n" for rep in representatives: output += f"- {rep}\n" return output