import torch
import numpy as np
import cv2
from PIL import Image
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation
from pathlib import Path
# --- Global variables for models and processor (lazy loading) ---
person_processor = None
person_model = None
pose_processor = None
pose_model = None
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Pose Estimator: Using device: {device}")
# --- Constants for color and drawing ---
# Colors are BGR tuples (OpenCV convention).
DEFAULT_MARKER_COLOR = (255, 255, 255)  # White
MIN_PIXELS_FOR_COLOR = 20  # Minimum number of valid pixels in the ROI required to compute a color
CONFIDENCE_THRESHOLD_KEYPOINTS = 0.3  # Threshold above which a keypoint is considered reliable for the ROI and drawing
SKELETON_THICKNESS = 2
# Skeleton edge definitions (COCO keypoint indices 0-16)
# 0:Nose, 1:L_Eye, 2:R_Eye, 3:L_Ear, 4:R_Ear, 5:L_Shoulder, 6:R_Shoulder,
# 7:L_Elbow, 8:R_Elbow, 9:L_Wrist, 10:R_Wrist, 11:L_Hip, 12:R_Hip,
# 13:L_Knee, 14:R_Knee, 15:L_Ankle, 16:R_Ankle
SKELETON_EDGES = [
    # Head
    (0, 1), (0, 2), (1, 3), (2, 4),
    # Torso
    (5, 6), (5, 11), (6, 12), (11, 12),
    # Left arm
    (5, 7), (7, 9),
    # Right arm
    (6, 8), (8, 10),
    # Left leg
    (11, 13), (13, 15),
    # Right leg
    (12, 14), (14, 16),
]
# Keypoint indices for the torso and ankles
TORSO_KP_INDICES = [5, 6, 11, 12]  # Shoulders and hips
LEFT_ANKLE_KP_INDEX = 15
RIGHT_ANKLE_KP_INDEX = 16
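
# A minimal sketch (hypothetical helper, not called by the pipeline below) showing
# how SKELETON_EDGES and SKELETON_THICKNESS are meant to be consumed when rendering
# a pose: draw each edge whose two endpoints are confident enough.
def _draw_skeleton(image_bgr, keypoints, scores, color=DEFAULT_MARKER_COLOR):
    """Draws the COCO skeleton for one person onto image_bgr (in place)."""
    for start_idx, end_idx in SKELETON_EDGES:
        if (scores[start_idx] > CONFIDENCE_THRESHOLD_KEYPOINTS
                and scores[end_idx] > CONFIDENCE_THRESHOLD_KEYPOINTS):
            pt1 = tuple(map(int, keypoints[start_idx]))
            pt2 = tuple(map(int, keypoints[end_idx]))
            cv2.line(image_bgr, pt1, pt2, color, SKELETON_THICKNESS, cv2.LINE_AA)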

def _load_models():
    """Loads the models if they haven't been loaded yet."""
    global person_processor, person_model, pose_processor, pose_model
    if person_processor is None:
        print("Loading RTDetr person detector model...")
        person_processor = AutoProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
        person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365", device_map=device)
        print("✓ RTDetr loaded.")
    if pose_processor is None:
        print("Loading ViTPose pose estimation model...")
        pose_processor = AutoProcessor.from_pretrained("usyd-community/vitpose-base-simple")
        pose_model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple", device_map=device)
        print("✓ ViTPose loaded.")

def _is_color_greenish(bgr_pixel, threshold=10):
    """True if the pixel's green channel dominates both blue and red by more than `threshold`."""
    b, g, r = bgr_pixel
    return g > b + threshold and g > r + threshold

def _is_color_grayscale(bgr_pixel, tolerance=30):
    """True if the pixel is very dark, very light, or has low channel spread (near-grayscale)."""
    b, g, r = bgr_pixel
    min_val, max_val = min(b, g, r), max(b, g, r)
    is_dark = max_val < 50
    is_light = min_val > 200
    is_low_saturation = (max_val - min_val) < tolerance
    return is_dark or is_light or is_low_saturation
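
# Together these two predicates define the pixel filter used by _get_average_color():
# greenish pixels (e.g. the pitch) and near-grayscale pixels (very dark, very
# bright, or low-saturation) are excluded so the average reflects the torso color.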

def _get_average_color(roi_bgr):
    """Computes the average color of a ROI after filtering out greenish and grayscale pixels."""
    if roi_bgr is None or roi_bgr.size == 0:
        return DEFAULT_MARKER_COLOR
    try:
        pixels = roi_bgr.reshape(-1, 3)
        valid_pixels = []
        for pixel in pixels:
            if not _is_color_greenish(pixel) and not _is_color_grayscale(pixel):
                valid_pixels.append(pixel)
        if len(valid_pixels) < MIN_PIXELS_FOR_COLOR:
            return DEFAULT_MARKER_COLOR
        avg_color = np.mean(valid_pixels, axis=0)
        return tuple(map(int, avg_color))
    except Exception as e:
        print(f"  Error computing average color: {e}. Falling back to default color.")
        return DEFAULT_MARKER_COLOR
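
# NOTE: the per-pixel Python loop above is simple but slow on large ROIs. A
# NumPy-vectorized equivalent using the same thresholds would look like this sketch:
#   px = roi_bgr.reshape(-1, 3).astype(np.int16)
#   b, g, r = px[:, 0], px[:, 1], px[:, 2]
#   greenish = (g > b + 10) & (g > r + 10)
#   mx, mn = px.max(axis=1), px.min(axis=1)
#   grayish = (mx < 50) | (mn > 200) | ((mx - mn) < 30)
#   valid_pixels = px[~greenish & ~grayish]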

def get_player_data(image_bgr: np.ndarray) -> list:
    """
    Detects persons, estimates pose, calculates average torso color,
    and returns a list of data for each player.

    Args:
        image_bgr: The input image in BGR format (NumPy array).

    Returns:
        A list of dictionaries, each containing:
        {
            'keypoints': np.ndarray (17, 2),
            'scores': np.ndarray (17,),
            'bbox': np.ndarray (4,) [x1, y1, x2, y2],
            'avg_color': tuple (b, g, r)
        }
        Returns an empty list if no persons are detected or an error occurs.
    """
    _load_models()
    player_list = []
    height, width = image_bgr.shape[:2]
    try:
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_rgb)

        # --- Stage 1: Detect humans ---
        inputs_det = person_processor(images=image_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_det = person_model(**inputs_det)
        results_det = person_processor.post_process_object_detection(
            outputs_det, target_sizes=torch.tensor([(height, width)]), threshold=0.5
        )
        result_det = results_det[0]
        # Label 0 is "person" for this RT-DETR checkpoint.
        person_boxes = result_det["boxes"][result_det["labels"] == 0].cpu().numpy()
        if len(person_boxes) == 0:
            print("No persons detected.")
            return player_list
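
        # Convert boxes from (x1, y1, x2, y2) corners to COCO (x, y, w, h),
        # the format the pose processor expects.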
        person_boxes_coco = person_boxes.copy()
        person_boxes_coco[:, 2] = person_boxes_coco[:, 2] - person_boxes_coco[:, 0]
        person_boxes_coco[:, 3] = person_boxes_coco[:, 3] - person_boxes_coco[:, 1]

        # --- Stage 2: Detect keypoints ---
        inputs_pose = pose_processor(image_pil, boxes=[person_boxes_coco], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs_pose = pose_model(**inputs_pose)
        pose_results = pose_processor.post_process_pose_estimation(outputs_pose, boxes=[person_boxes_coco])
        image_pose_result = pose_results[0]
        if not image_pose_result:
            print("Pose estimation did not return results.")
            return player_list
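
        # image_pose_result is a list of per-person dicts, each holding
        # 'keypoints' and 'scores' tensors for the 17 COCO keypoints.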
        # --- Stage 3: Process each person ---
        for i, person_box_xyxy in enumerate(person_boxes):
            if i >= len(image_pose_result):
                continue
            pose_result = image_pose_result[i]
            xy = pose_result['keypoints'].cpu().numpy()
            scores = pose_result['scores'].cpu().numpy()
            # Ensure xy shape is correct before proceeding
            if xy.shape != (17, 2):
                print(f"Person {i}: Unexpected keypoints shape {xy.shape}, skipping.")
                continue

            # -- Define Torso ROI --
            reliable_torso_keypoints = xy[TORSO_KP_INDICES][scores[TORSO_KP_INDICES] > CONFIDENCE_THRESHOLD_KEYPOINTS]
            x1_box, y1_box, x2_box, y2_box = map(int, person_box_xyxy)
            box_h = y2_box - y1_box
            box_w = x2_box - x1_box
            if len(reliable_torso_keypoints) >= 3:
                # Bound the ROI by the reliable torso keypoints, with a small margin.
                min_x_kp = int(np.min(reliable_torso_keypoints[:, 0]))
                max_x_kp = int(np.max(reliable_torso_keypoints[:, 0]))
                min_y_kp = int(np.min(reliable_torso_keypoints[:, 1]))
                max_y_kp = int(np.max(reliable_torso_keypoints[:, 1]))
                roi_x1 = max(x1_box, min_x_kp - 5)
                roi_y1 = max(y1_box, min_y_kp - 5)
                roi_x2 = min(x2_box, max_x_kp + 5)
                roi_y2 = min(y2_box, max_y_kp + 5)
            else:
                # Fallback: approximate the torso as the 10%-60% vertical band of the bbox.
                roi_x1 = x1_box
                roi_y1 = y1_box + int(0.1 * box_h)
                roi_x2 = x2_box
                roi_y2 = y1_box + int(0.6 * box_h)
            # Clamp the ROI to the image bounds.
            roi_x1 = max(0, roi_x1)
            roi_y1 = max(0, roi_y1)
            roi_x2 = min(width, roi_x2)
            roi_y2 = min(height, roi_y2)
            # -- Extract Average Color --
            avg_color = DEFAULT_MARKER_COLOR
            if roi_y2 > roi_y1 and roi_x2 > roi_x1:
                torso_roi = image_bgr[roi_y1:roi_y2, roi_x1:roi_x2]
                avg_color = _get_average_color(torso_roi)
            # An invalid ROI silently falls back to the default color.

            # -- Store player data --
            player_data = {
                'keypoints': xy,
                'scores': scores,
                'bbox': person_box_xyxy,  # Keep the original xyxy bbox
                'avg_color': avg_color
            }
            player_list.append(player_data)
    except Exception as e:
        print(f"Error during player data extraction: {e}")
        import traceback
        traceback.print_exc()
        # Return an empty list on any major error
        return []

    return player_list

# Example usage (optional, for testing the module directly)
if __name__ == '__main__':
    test_image_path = 'img3.png'
    if not Path(test_image_path).exists():
        print(f"Test image not found: {test_image_path}")
    else:
        print(f"Testing player data extraction with image: {test_image_path}")
        input_img = cv2.imread(test_image_path)
        if input_img is None:
            print(f"Failed to load test image: {test_image_path}")
        else:
            print("Getting player data...")
            players = get_player_data(input_img)
            print(f"✓ Found data for {len(players)} players.")
            # --- Draw markers and info on original image for testing ---
            output_img_test = input_img.copy()
            for idx, p_data in enumerate(players):
                kps = p_data['keypoints']
                scores = p_data['scores']
                bbox = p_data['bbox']
                color = p_data['avg_color']
                # Determine reference point (ankles or bbox bottom mid)
                l_ankle_pt = kps[LEFT_ANKLE_KP_INDEX]
                r_ankle_pt = kps[RIGHT_ANKLE_KP_INDEX]
                l_ankle_score = scores[LEFT_ANKLE_KP_INDEX]
                r_ankle_score = scores[RIGHT_ANKLE_KP_INDEX]
                ref_point = None
                if l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS and r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, (l_ankle_pt + r_ankle_pt) / 2))
                elif l_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, l_ankle_pt))
                elif r_ankle_score > CONFIDENCE_THRESHOLD_KEYPOINTS:
                    ref_point = tuple(map(int, r_ankle_pt))
                else:
                    x1, y1, x2, y2 = map(int, bbox)
                    ref_point = (int((x1 + x2) / 2), y2)
                # Draw marker at reference point
                if ref_point:
                    cv2.circle(output_img_test, ref_point, 8, color, -1, cv2.LINE_AA)
                    cv2.circle(output_img_test, ref_point, 8, (0, 0, 0), 1, cv2.LINE_AA)  # Black outline
                    # Draw player index (black halo under white text for readability)
                    cv2.putText(output_img_test, str(idx), (ref_point[0] + 5, ref_point[1] - 5),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2, cv2.LINE_AA)
                    cv2.putText(output_img_test, str(idx), (ref_point[0] + 5, ref_point[1] - 5),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
            cv2.imshow("Original Image", input_img)
            cv2.imshow("Player Markers Test", output_img_test)
            print("Displaying test results. Press any key to exit.")
            cv2.waitKey(0)
            cv2.destroyAllWindows()