RamziBm's picture
init
bdb955e
raw
history blame
11.8 kB
import gradio as gr
import cv2
import numpy as np
import torch
from pathlib import Path
import time
import traceback
# Importer les éléments nécessaires depuis les autres modules du projet
try:
from tvcalib.infer.module import TvCalibInferModule
# On essaie d'importer la fonction de pré-traitement depuis main.py
# Si main.py n'est pas conçu pour être importé, il faudra peut-être copier/coller cette fonction ici
from main import preprocess_image_tvcalib, IMAGE_SHAPE, SEGMENTATION_MODEL_PATH
from visualizer import (
create_minimap_view,
create_minimap_with_offset_skeletons,
DYNAMIC_SCALE_MIN_MODULATION,
DYNAMIC_SCALE_MAX_MODULATION
)
from pose_estimator import get_player_data
except ImportError as e:
print(f"Erreur d'importation : {e}")
print("Assurez-vous que les modules tvcalib, main, visualizer, pose_estimator sont accessibles.")
# On pourrait mettre des stubs ou lever une exception ici pour Gradio
raise e
# --- Configuration Globale (Modèle, etc.) ---
# Essayer de charger le modèle une seule fois globalement peut améliorer les performances
# mais attention à la gestion de l'état dans les environnements multi-utilisateurs/threads de Spaces
# Pour l'instant, on le chargera dans la fonction de traitement.
# MODEL = None # Optionnel: Charger ici
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Utilisation du device : {DEVICE}")
if not SEGMENTATION_MODEL_PATH.exists():
print(f"AVERTISSEMENT : Modèle de segmentation introuvable : {SEGMENTATION_MODEL_PATH}")
print("L'application risque de ne pas fonctionner. Assurez-vous que le fichier est présent.")
# Gradio peut quand même démarrer, mais le traitement échouera.
# --- Fonction Principale de Traitement ---
def process_image_and_generate_minimaps(input_image_bgr, optim_steps, target_avg_scale):
"""
Prend une image BGR (NumPy), les étapes d'optimisation et l'échelle cible,
retourne les deux minimaps (NumPy BGR).
"""
global DEVICE # Utiliser le device défini globalement
print("\n--- Nouvelle requête ---")
print(f"Paramètres: optim_steps={optim_steps}, target_avg_scale={target_avg_scale}")
# Vérifier si le modèle de segmentation existe (important car on ne peut pas l'afficher dans l'UI facilement)
if not SEGMENTATION_MODEL_PATH.exists():
# Retourner des images noires ou des messages d'erreur
error_msg = f"Erreur: Modèle {SEGMENTATION_MODEL_PATH} introuvable."
print(error_msg)
placeholder = np.zeros((300, 500, 3), dtype=np.uint8) # Placeholder noir
cv2.putText(placeholder, error_msg, (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
return placeholder, placeholder.copy() # Retourner deux placeholders
try:
# 1. Initialisation du modèle TvCalib (peut être lent si fait à chaque fois)
# Pourrait être optimisé en chargeant globalement (voir commentaire plus haut)
print("Initialisation de TvCalibInferModule...")
start_init = time.time()
model = TvCalibInferModule(
segmentation_checkpoint=SEGMENTATION_MODEL_PATH,
image_shape=IMAGE_SHAPE, # Utilise la constante importée
optim_steps=int(optim_steps), # Assurer que c'est un entier
lens_dist=False
)
# Déplacer le modèle sur le bon device ici explicitement si nécessaire
# model.to(DEVICE) # TvCalibInferModule devrait gérer ça en interne ? A vérifier.
print(f"✓ Modèle chargé sur {next(model.model_calib.parameters()).device} en {time.time() - start_init:.3f}s")
model_device = next(model.model_calib.parameters()).device # Vérifier le device réel
# 2. Prétraitement de l'image
print("Prétraitement de l'image...")
start_preprocess = time.time()
# preprocess_image_tvcalib attend BGR, Gradio fournit BGR par défaut avec type="numpy"
# Assurez-vous que preprocess_image_tvcalib déplace bien le tenseur sur le bon device
image_tensor, image_bgr_resized, image_rgb_resized = preprocess_image_tvcalib(input_image_bgr)
# Vérifier/forcer le device du tenseur
image_tensor = image_tensor.to(model_device)
print(f"Temps de prétraitement TvCalib : {time.time() - start_preprocess:.3f}s")
# 3. Exécuter la calibration (Segmentation + Optimisation)
print("Exécution de la segmentation...")
start_segment = time.time()
with torch.no_grad():
keypoints = model._segment(image_tensor)
print(f"Temps de segmentation : {time.time() - start_segment:.3f}s")
print("Exécution de la calibration (optimisation)...")
start_calibrate = time.time()
homography = model._calibrate(keypoints)
print(f"Temps de calibration : {time.time() - start_calibrate:.3f}s")
if homography is None:
print("Aucune homographie n'a pu être calculée.")
# Retourner des placeholders avec message
placeholder = np.zeros((300, 500, 3), dtype=np.uint8)
cv2.putText(placeholder, "Homographie non calculee", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
return placeholder, placeholder.copy()
if isinstance(homography, torch.Tensor):
homography_np = homography.detach().cpu().numpy()
else:
homography_np = np.array(homography) # Assurer que c'est un NumPy array
print("✓ Homographie calculée.")
# 4. Extraction des données joueurs
print("Extraction des données joueurs (pose+couleur)...")
start_pose = time.time()
# get_player_data attend une image BGR
player_list = get_player_data(image_bgr_resized)
print(f"Temps d'extraction données joueurs : {time.time() - start_pose:.3f}s ({len(player_list)} joueurs trouvés)")
# 5. Calcul de l'échelle de base
print("Calcul de l'échelle de base...")
# Reprend la logique de main.py pour estimer l'échelle de base
avg_modulation_expected = DYNAMIC_SCALE_MIN_MODULATION + \
(DYNAMIC_SCALE_MAX_MODULATION - DYNAMIC_SCALE_MIN_MODULATION) * (1.0 - 0.5)
estimated_base_scale = target_avg_scale
if avg_modulation_expected != 0:
estimated_base_scale = target_avg_scale / avg_modulation_expected
print(f" Échelle de base interne estimée pour cible {target_avg_scale:.3f} : {estimated_base_scale:.3f}")
# 6. Génération des minimaps
print("Génération des minimaps...")
start_viz = time.time()
# Minimap avec projection (image RGB attendue par la fonction)
minimap_original = create_minimap_view(image_rgb_resized, homography_np)
# Minimap avec squelettes (utilise l'échelle estimée)
minimap_offset_skeletons, actual_avg_scale = create_minimap_with_offset_skeletons(
player_list,
homography_np,
base_skeleton_scale=estimated_base_scale
)
print(f"Temps de génération des minimaps : {time.time() - start_viz:.3f}s")
if actual_avg_scale is not None:
print(f"Échelle moyenne CIBLE demandée : {target_avg_scale:.3f}")
print(f"Échelle moyenne FINALE RÉELLEMENT appliquée : {actual_avg_scale:.3f}")
# Vérifier si les minimaps ont été créées (peuvent être None en cas d'erreur interne)
if minimap_original is None:
print("Erreur: La minimap originale n'a pas pu être générée.")
minimap_original = np.zeros((300, 500, 3), dtype=np.uint8)
cv2.putText(minimap_original, "Erreur Minimap Originale", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
if minimap_offset_skeletons is None:
print("Erreur: La minimap squelettes n'a pas pu être générée.")
minimap_offset_skeletons = np.zeros((300, 500, 3), dtype=np.uint8)
cv2.putText(minimap_offset_skeletons, "Erreur Minimap Squelettes", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
# Gradio attend des images RGB pour l'affichage, nos fonctions retournent probablement BGR (via OpenCV)
# Conversion BGR -> RGB si nécessaire
if minimap_original.shape[2] == 3: # Assurer que c'est une image couleur
minimap_original = cv2.cvtColor(minimap_original, cv2.COLOR_BGR2RGB)
if minimap_offset_skeletons.shape[2] == 3:
minimap_offset_skeletons = cv2.cvtColor(minimap_offset_skeletons, cv2.COLOR_BGR2RGB)
print("✓ Traitement terminé.")
return minimap_original, minimap_offset_skeletons
except Exception as e:
print(f"Erreur majeure lors du traitement : {e}")
traceback.print_exc()
# Retourner des placeholders avec message d'erreur général
placeholder = np.zeros((300, 500, 3), dtype=np.uint8)
cv2.putText(placeholder, f"Erreur: {e}", (10, 150), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1, cv2.LINE_AA)
return placeholder, placeholder.copy()
# --- Interface Gradio ---
with gr.Blocks() as demo:
gr.Markdown("# Foot Calib Pos Image Processor - Minimap Generator")
gr.Markdown(
"Upload a football pitch image to compute homography (TvCalib), "
"detect players (RT-DETR/ViTPose), and generate two minimap visualizations."
)
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="numpy", label="Input Image (.jpg, .png)")
optim_steps_slider = gr.Slider(
minimum=100, maximum=2000, step=50, value=500,
label="TvCalib Optimization Steps",
info="Number of iterations to refine homography."
)
target_scale_slider = gr.Slider(
minimum=0.1, maximum=2.5, step=0.05, value=0.35,
label="Target Average Skeleton Scale",
info="Adjusts the desired average size of skeletons on the minimap."
)
submit_button = gr.Button("Generate Minimaps", variant="primary")
with gr.Column(scale=2):
output_minimap_orig = gr.Image(type="numpy", label="Minimap with Original Projection", interactive=False)
output_minimap_skel = gr.Image(type="numpy", label="Minimap with Offset Skeletons", interactive=False)
# Connecter le bouton à la fonction de traitement
submit_button.click(
fn=process_image_and_generate_minimaps,
inputs=[input_image, optim_steps_slider, target_scale_slider],
outputs=[output_minimap_orig, output_minimap_skel]
)
# Ajouter des exemples (optionnel mais utile pour Spaces)
gr.Examples(
examples=[
["data/img1.png", 500, 1.35],
["data/img2.png", 1000, 1.5],
["data/img3.png", 500, 0.8],
["data/7.jpg", 500, 1], # Add .jpg examples
["data/15.jpg", 800, 1.35],
],
inputs=[input_image, optim_steps_slider, target_scale_slider],
outputs=[output_minimap_orig, output_minimap_skel], # Outputs won't be pre-calculated here, just to populate inputs
fn=process_image_and_generate_minimaps, # Function will be called if example is clicked
cache_examples=False # Important if processing is long or depends on external models
)
# --- Lancement de l'application ---
if __name__ == "__main__":
# share=True creates a temporary public link (useful for testing outside localhost)
# debug=True shows more Gradio logs in the console
demo.launch(debug=True)