Ananya Uppal committed
Commit 84d12f2
1 Parent(s): 51ed8c6

pushing app.py file

Files changed (1): app.py (+325 −0)
app.py ADDED
@@ -0,0 +1,325 @@
# -*- coding: utf-8 -*-
"""VTON_GarmentMasker.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Y22abu3jZQ5qCKP7DTR6kYvXdQbHnJCu

Uses YOLO for person detection and YOLOS-Fashionpedia for clothing
classification.
"""

# !pip install gradio
# !pip install ultralytics
# !pip install segment-anything

# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

import os
import urllib.request

import gradio as gr
import numpy as np
import torch
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from transformers import YolosForObjectDetection, YolosImageProcessor
from ultralytics import YOLO


class GarmentMaskingPipeline:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        (self.yolo_model, self.sam_predictor,
         self.classification_model, self.processor) = self.load_models()

        # Body regions that each garment type should cover
        self.clothing_to_body_parts = {
            'shirt': ['torso', 'arms'],
            't-shirt': ['torso', 'upper_arms'],
            'blouse': ['torso', 'arms'],
            'dress': ['torso', 'legs'],
            'skirt': ['lower_torso', 'legs'],
            'pants': ['legs'],
            'shorts': ['upper_legs'],
            'jacket': ['torso', 'arms'],
            'coat': ['torso', 'arms']
        }

        # Vertical extent of each body part as (top, bottom) fractions of the
        # image height; each band is later intersected with the person mask
        self.body_parts_positions = {
            'face': (0.0, 0.2),
            'torso': (0.2, 0.5),
            'arms': (0.2, 0.5),
            'upper_arms': (0.2, 0.35),
            'lower_torso': (0.4, 0.6),
            'legs': (0.5, 0.9),
            'upper_legs': (0.5, 0.7),
            'feet': (0.9, 1.0)
        }
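    # How these bands are used (illustrative numbers, not from the code):
    # for a 1000-pixel-tall image, 'torso' = (0.2, 0.5) selects rows 200-499,
    # and the kept region is that horizontal band ANDed with the SAM person
    # mask:
    #
    #   band = np.zeros_like(person_mask)
    #   band[200:500, :] = 1
    #   torso_region = np.logical_and(band, person_mask)
    #
    # This is a coarse positional heuristic rather than pose estimation, so
    # it assumes a roughly upright, full-body photo of the person.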

    def load_models(self):
        print("Loading models...")
        # Download model weights if they don't exist locally
        self.download_models()

        # YOLO is used only to find the person (COCO class 0)
        yolo_model = YOLO('yolov8n.pt')

        # SAM, wrapped in a predictor for box-prompted segmentation
        sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        sam.to(self.device)
        predictor = SamPredictor(sam)

        # YOLOS-Fashionpedia model and processor for clothing classification
        print("Loading YOLOS-Fashionpedia model...")
        model_name = "valentinafeve/yolos-fashionpedia"
        processor = YolosImageProcessor.from_pretrained(model_name)
        classification_model = YolosForObjectDetection.from_pretrained(model_name)
        classification_model.to(self.device)
        classification_model.eval()

        print("Models loaded successfully!")
        return yolo_model, predictor, classification_model, processor

    def download_models(self):
        """Download required model files if they don't exist."""
        models = {
            "yolov8n.pt": "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt",
            "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
        }

        for filename, url in models.items():
            if not os.path.exists(filename):
                print(f"Downloading {filename}...")
                urllib.request.urlretrieve(url, filename)
                print(f"Downloaded {filename}")
            else:
                print(f"{filename} already exists")

        # The YOLOS-Fashionpedia weights are fetched automatically by transformers

    def classify_clothing(self, clothing_image):
        if not isinstance(clothing_image, Image.Image):
            clothing_image = Image.fromarray(clothing_image)

        # Preprocess with the YOLOS processor loaded in load_models()
        inputs = self.processor(images=clothing_image, return_tensors="pt").to(self.device)

        # Run inference
        with torch.no_grad():
            outputs = self.classification_model(**inputs)

        # Post-process to labels/scores at the original image size
        target_sizes = torch.tensor([clothing_image.size[::-1]]).to(self.device)
        results = self.processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.1
        )[0]

        labels = results["labels"]
        scores = results["scores"]

        # Class-index-to-name mapping from the model config
        id2label = self.classification_model.config.id2label

        # Map Fashionpedia labels onto this pipeline's garment categories
        fashionpedia_to_clothing = {
            'shirt': 'shirt',
            'blouse': 'shirt',
            'top': 't-shirt',
            't-shirt': 't-shirt',
            'sweater': 'shirt',
            'jacket': 'jacket',
            'cardigan': 'jacket',
            'coat': 'coat',
            'jumper': 'shirt',
            'dress': 'dress',
            'skirt': 'skirt',
            'shorts': 'shorts',
            'pants': 'pants',
            'jeans': 'pants',
            'leggings': 'pants',
            'jumpsuit': 'dress'
        }

        # Walk detections from highest to lowest confidence and return the
        # first one whose label contains a known clothing keyword
        if len(labels) > 0:
            detections = [(id2label[label.item()].lower(), score.item())
                          for label, score in zip(labels, scores)]
            detections.sort(key=lambda x: x[1], reverse=True)

            for label, score in detections:
                for keyword, category in fashionpedia_to_clothing.items():
                    if keyword in label:
                        return category

            # Detections exist but none matched a known keyword
            return 't-shirt'

        # Default to t-shirt if nothing was detected
        return 't-shirt'
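    # Matching example (illustrative values, not real model output): if
    # post-processing yields labels ["sleeve", "jeans"] with scores
    # [0.80, 0.65], "sleeve" matches no keyword, but "jeans" maps to "pants",
    # so classify_clothing returns "pants". The substring test also lets
    # compound labels such as "shirt, blouse" resolve to a category.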

    def create_garment_mask(self, person_image, garment_image, clothing_type=None):
        # Accept a precomputed clothing type so callers that already ran
        # classify_clothing (e.g. process) don't trigger a second inference
        if clothing_type is None:
            clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])

        # Detect people (COCO class 0) and keep the largest box
        results = self.yolo_model(person_image, classes=[0])
        mask = np.zeros(person_image.shape[:2], dtype=np.uint8)

        if results and len(results[0].boxes.data) > 0:
            person_boxes = results[0].boxes.data
            person_areas = [(box[2] - box[0]) * (box[3] - box[1]) for box in person_boxes]
            largest_person_index = np.argmax(person_areas)
            person_box = person_boxes[largest_person_index][:4].cpu().numpy().astype(int)

            # Prompt SAM with the person box; masks has shape (1, H, W)
            self.sam_predictor.set_image(person_image)
            masks, _, _ = self.sam_predictor.predict(box=person_box, multimask_output=False)
            person_mask = masks[0].astype(np.uint8)

            # Union of the horizontal bands for each body part, clipped to the person
            h, w = person_mask.shape
            for part in parts_to_mask:
                if part in self.body_parts_positions:
                    top_ratio, bottom_ratio = self.body_parts_positions[part]
                    top_px, bottom_px = int(h * top_ratio), int(h * bottom_ratio)

                    part_mask = np.zeros_like(person_mask)
                    part_mask[top_px:bottom_px, :] = 1
                    part_mask = np.logical_and(part_mask, person_mask).astype(np.uint8)

                    mask = np.logical_or(mask, part_mask).astype(np.uint8)

            # Remove the face region from the mask
            face_top_px, face_bottom_px = int(h * 0.0), int(h * 0.2)
            face_mask = np.zeros_like(person_mask)
            face_mask[face_top_px:face_bottom_px, :] = 1
            face_mask = np.logical_and(face_mask, person_mask).astype(np.uint8)
            mask = np.logical_and(mask, np.logical_not(face_mask)).astype(np.uint8)

            # Remove the feet region from the mask
            feet_top_px, feet_bottom_px = int(h * 0.9), int(h * 1.0)
            feet_mask = np.zeros_like(person_mask)
            feet_mask[feet_top_px:feet_bottom_px, :] = 1
            feet_mask = np.logical_and(feet_mask, person_mask).astype(np.uint8)
            mask = np.logical_and(mask, np.logical_not(feet_mask)).astype(np.uint8)

        # Scale the binary mask to 0/255 for display and saving
        return mask * 255
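    # Mask composition on a toy single-column example (1 = masked):
    #   person_mask           = [0, 1, 1, 1, 1, 0]
    #   torso band (rows 1-2) = [0, 1, 1, 0, 0, 0] after ANDing with the person
    #   mask |= band; face/feet rows are then cleared with
    #   np.logical_and(mask, np.logical_not(face_mask)).
    # The 0/255 output matches the 8-bit mask convention most inpainting and
    # try-on models expect.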

    def process(self, person_image_pil, garment_image_pil, mask_color_hex="#00FF00", opacity=0.5):
        """Process the input images and return the masked result."""
        # Convert PIL images to numpy arrays
        person_image = np.array(person_image_pil)
        garment_image = np.array(garment_image_pil)

        # Drop the alpha channel if present (RGBA -> RGB)
        if person_image.ndim == 3 and person_image.shape[2] == 4:
            person_image = person_image[:, :, :3]
        if garment_image.ndim == 3 and garment_image.shape[2] == 4:
            garment_image = garment_image[:, :, :3]

        # Classify the garment once and reuse the result below
        clothing_type = self.classify_clothing(garment_image)
        parts_to_mask = self.clothing_to_body_parts.get(clothing_type, [])

        # Create the garment mask (values 0 or 255)
        garment_mask = self.create_garment_mask(person_image, garment_image, clothing_type)

        # Convert hex color (e.g. "#00FF00") to an RGB tuple
        r = int(mask_color_hex[1:3], 16)
        g = int(mask_color_hex[3:5], 16)
        b = int(mask_color_hex[5:7], 16)
        color = (r, g, b)

        # Colored version of the mask for the overlay
        colored_mask = np.zeros_like(person_image)
        for i in range(3):
            colored_mask[:, :, i] = garment_mask * (color[i] / 255.0)

        # Three-channel copy of the binary mask for visualization
        binary_mask = np.stack([garment_mask, garment_mask, garment_mask], axis=2)

        # Alpha-blend the colored mask onto the original image
        mask_3d = garment_mask[:, :, np.newaxis] / 255.0
        overlay = person_image * (1 - opacity * mask_3d) + colored_mask * opacity
        overlay = overlay.astype(np.uint8)

        return overlay, binary_mask, f"Detected garment: {clothing_type}\nBody parts to mask: {', '.join(parts_to_mask)}"
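
# Blend sanity check for GarmentMaskingPipeline.process (illustrative values):
# with opacity = 0.5 and mask color #00FF00, a masked pixel (200, 100, 50)
# blends to 0.5 * (200, 100, 50) + 0.5 * (0, 255, 0) = (100, 177, 25), while
# pixels where mask_3d == 0 are left unchanged.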

# A single pipeline instance is created lazily and reused across requests;
# building it per call would reload SAM and YOLOS every time
_PIPELINE = None

def process_images(person_img, garment_img, mask_color, opacity):
    """Gradio processing function."""
    global _PIPELINE
    try:
        if _PIPELINE is None:
            _PIPELINE = GarmentMaskingPipeline()
        return _PIPELINE.process(person_img, garment_img, mask_color, opacity)
    except Exception as e:
        import traceback
        error_msg = f"Error processing images: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, None, error_msg
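
# Minimal standalone usage sketch, without the Gradio UI; the file names are
# hypothetical placeholders:
#
#   pipeline = GarmentMaskingPipeline()
#   overlay, mask, info = pipeline.process(
#       Image.open("person.jpg"), Image.open("shirt.jpg"), "#FF0000", 0.6)
#   Image.fromarray(mask).save("garment_mask.png")
#   print(info)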

def create_gradio_interface():
    """Create and return the Gradio interface."""
    with gr.Blocks(title="VTON SAM Garment Masking Pipeline") as interface:
        gr.Markdown("""
        # Virtual Try-On Garment Masking Pipeline with SAM and YOLOS-Fashionpedia

        Upload a person image and a garment image to generate a mask for a virtual try-on application.
        The system will:
        1. Detect the person using YOLO
        2. Create a high-quality segmentation using SAM (Segment Anything Model)
        3. Classify the garment type using YOLOS-Fashionpedia
        4. Generate a mask of the area where the garment should be placed

        **Note**: This system uses state-of-the-art AI segmentation and fashion detection models for accurate results.
        """)

        with gr.Row():
            with gr.Column():
                person_input = gr.Image(label="Person Image (Image A)", type="pil")
                garment_input = gr.Image(label="Garment Image (Image B)", type="pil")

                with gr.Row():
                    mask_color = gr.ColorPicker(label="Mask Color", value="#00FF00")
                    opacity = gr.Slider(label="Mask Opacity", minimum=0.1, maximum=0.9, value=0.5, step=0.1)

                submit_btn = gr.Button("Generate Mask")

            with gr.Column():
                masked_output = gr.Image(label="Person with Masked Region")
                mask_output = gr.Image(label="Standalone Mask")
                result_text = gr.Textbox(label="Detection Results", lines=3)

        # Wire the button to the processing function
        submit_btn.click(
            fn=process_images,
            inputs=[person_input, garment_input, mask_color, opacity],
            outputs=[masked_output, mask_output, result_text]
        )

        gr.Markdown("""
        ## How It Works

        1. **Person Detection**: Uses YOLO to detect and locate the person in the image
        2. **Segmentation**: Uses SAM (Segment Anything Model) to create a high-quality segmentation mask
        3. **Garment Classification**: Uses YOLOS-Fashionpedia to identify the garment type with fashion-specific detection
        4. **Mask Generation**: Creates a mask based on the garment type and body-part mapping

        ## Supported Garment Types

        - Shirts, Blouses, Tops, and T-shirts
        - Sweaters and Cardigans
        - Dresses and Jumpsuits
        - Skirts
        - Pants, Jeans, and Leggings
        - Shorts
        - Jackets and Coats
        """)

    return interface


if __name__ == "__main__":
    # Launch the Gradio app; share=True also creates a temporary public link
    interface = create_gradio_interface()
    interface.launch(debug=True, share=True)