kubinooo's picture
fixing broken predictions - debug prints enabled
01d3f3f
"""
Module used for making predictions while running inference with the kubinooo/convnext-tiny-224-audio-deepfake-classification model
Author: Jakub Polnis
Copyright: Copyright 2025, Jakub Polnis
License: Apache 2.0
Email: [email protected]
"""
import torch
from process_audio import create_mel_spectrograms
def predict_image(image, processor, model):
    """Classify a single spectrogram image as real or fake.

    Args:
        image: Image object exposing ``mode`` and ``convert`` (presumably a
            PIL image produced by ``create_mel_spectrograms`` -- confirm
            against the caller).
        processor: Callable invoked as ``processor(image, return_tensors="pt")``;
            its result must expose ``pixel_values``.
        model: Classification model. Calling it on the pixel values must
            return an object with ``logits``, and it must provide
            ``eval()`` and ``config.id2label``.

    Returns:
        A hard (one-hot) vote: ``{"real": 1.0, "fake": 0.0}`` when the
        predicted label equals "real" (case-insensitive), otherwise
        ``{"real": 0.0, "fake": 1.0}``. The chosen label is also printed
        as a debug line.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # NOTE(review): this function used to call image.resize((224, 224)) and
    # throw the result away. PIL's Image.resize returns a new image instead
    # of resizing in place, so that call never changed what the model saw.
    # It is removed here; sizing stays with `processor`, exactly as before.
    # If an explicit resize is really wanted, assign the result back.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Put the model in evaluation mode and skip gradient tracking: this is
    # inference only.
    model.eval()
    with torch.no_grad():
        logits = model(pixel_values).logits

    predicted_class_idx = logits.argmax(-1).item()
    # Named `label` so it does not shadow the module-level prediction().
    label = model.config.id2label[predicted_class_idx]

    if label.lower() == "real":
        print("real ")
        return {"real": 1.0, "fake": 0.0}  # full vote for "real"

    # Any label other than "real" is counted as "fake".
    print("fake ")
    return {"real": 0.0, "fake": 1.0}
def prediction(file_path, processor, model):
    """Aggregate per-spectrogram votes for one audio file.

    The file is turned into spectrogram images by
    ``create_mel_spectrograms(file_path, 2, 0)`` (the meaning of the two
    numeric arguments is defined in ``process_audio``). Each image is
    classified with ``predict_image`` and the votes are averaged.

    Args:
        file_path: Path of the audio file, forwarded unchanged to
            ``create_mel_spectrograms``.
        processor: Image processor forwarded to ``predict_image``.
        model: Classification model forwarded to ``predict_image``.

    Returns:
        ``{"real": r, "fake": f}`` where each value is the share of images
        voted that way, rounded to two decimals. When no images are
        produced, both values are ``0.0``.
    """
    pil_images = create_mel_spectrograms(file_path, 2, 0)

    total = len(pil_images)
    if total == 0:
        # Nothing to classify: report no evidence either way.
        return {"real": 0.0, "fake": 0.0}

    # A list (not a generator) so every image is classified, in order,
    # before the summary line is printed.
    votes = [predict_image(image, processor, model) for image in pil_images]

    # Compute each share once and reuse it for both the log line and the
    # returned dict.
    real_share = round(sum(vote["real"] for vote in votes) / total, 2)
    fake_share = round(sum(vote["fake"] for vote in votes) / total, 2)

    # The earlier message lacked the ": " after "fake" ("real: 0.5 fake0.5").
    print(f"real: {real_share} fake: {fake_share}")
    return {"real": real_share, "fake": fake_share}