developer0hye commited on
Commit
751073d
·
verified ·
1 Parent(s): f370b56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -16
app.py CHANGED
@@ -13,6 +13,7 @@ from PIL import Image
13
  import requests
14
  import yaml
15
  import numpy as np
 
16
 
17
  from src.core import YAMLConfig
18
 
@@ -108,7 +109,7 @@ def download_weights(model_name):
108
  print(f"Downloaded weights to: {weights_path}")
109
  return weights_path
110
 
111
-
112
  def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
113
  """Process image function for Gradio interface"""
114
  if isinstance(image, np.ndarray):
@@ -185,10 +186,37 @@ class ModelWrapper(nn.Module):
185
  return outputs
186
 
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def load_model(model_name):
 
 
 
 
 
 
 
 
189
  cfgfile = model_configs[model_name]["cfgfile"]
190
  weights_path = download_weights(model_name)
191
 
 
192
  cfg = YAMLConfig(cfgfile, resume=weights_path)
193
 
194
  if "HGNetv2" in cfg.yaml_cfg:
@@ -197,7 +225,11 @@ def load_model(model_name):
197
  checkpoint = torch.load(weights_path, map_location="cpu")
198
  state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
199
 
200
- cfg.model.load_state_dict(state)
 
 
 
 
201
 
202
  device = "cuda" if torch.cuda.is_available() else "cpu"
203
  model = ModelWrapper(cfg).to(device)
@@ -205,26 +237,34 @@ def load_model(model_name):
205
 
206
  return model, device
207
 
208
-
209
- # Dictionary to store loaded models
210
- loaded_models = {}
211
-
212
  @spaces.GPU
213
  def process_image(image, model_name, confidence_threshold):
214
  """Main processing function for Gradio interface"""
215
- global loaded_models
216
 
217
- # Load model if not already loaded
218
- if model_name not in loaded_models:
 
 
 
 
 
 
219
  print(f"Loading model: {model_name}")
220
  model, device = load_model(model_name)
221
- loaded_models[model_name] = (model, device)
222
- else:
223
- print(f"Using cached model: {model_name}")
224
- model, device = loaded_models[model_name]
 
 
 
 
 
 
 
 
225
 
226
- # Process the image
227
- return process_image_for_gradio(model, device, image, model_name, confidence_threshold)
228
 
229
 
230
  # Create Gradio interface
@@ -256,4 +296,6 @@ demo = gr.Interface(
256
  ]
257
  )
258
 
259
- demo.launch(debug=True)
 
 
 
13
  import requests
14
  import yaml
15
  import numpy as np
16
+ import gc
17
 
18
  from src.core import YAMLConfig
19
 
 
109
  print(f"Downloaded weights to: {weights_path}")
110
  return weights_path
111
 
112
+ @torch.no_grad()
113
  def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
114
  """Process image function for Gradio interface"""
115
  if isinstance(image, np.ndarray):
 
186
  return outputs
187
 
188
 
189
+ # YAMLConfig ํด๋ž˜์Šค์˜ ๋‚ด๋ถ€ ์ƒํƒœ๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ํ•จ์ˆ˜ ์ถ”๊ฐ€
190
+ def reset_yaml_config():
191
+ """YAMLConfig 클래스의 내부 상태를 초기화"""
192
+ # 클래스 내부에 캐싱된 정보가 있다면 삭제
193
+ if hasattr(YAMLConfig, '_instances'):
194
+ YAMLConfig._instances = {}
195
+ if hasattr(YAMLConfig, '_configs'):
196
+ YAMLConfig._configs = {}
197
+
198
+ # 가능한 다른 모든 모듈 캐시 리셋
199
+ import importlib
200
+ for module_name in list(sys.modules.keys()):
201
+ if module_name.startswith('src.'):
202
+ try:
203
+ importlib.reload(sys.modules[module_name])
204
+ except:
205
+ pass
206
+
207
  def load_model(model_name):
208
+ # ๋ชจ๋ธ ๋กœ๋“œ ์ „์— CUDA ์บ์‹œ์™€ ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰์…˜ ์ •๋ฆฌ
209
+ if torch.cuda.is_available():
210
+ torch.cuda.empty_cache()
211
+ gc.collect()
212
+
213
+ # YAMLConfig 내부 상태 초기화
214
+ reset_yaml_config()
215
+
216
  cfgfile = model_configs[model_name]["cfgfile"]
217
  weights_path = download_weights(model_name)
218
 
219
+ # 완전히 새로운 YAMLConfig 인스턴스 생성
220
  cfg = YAMLConfig(cfgfile, resume=weights_path)
221
 
222
  if "HGNetv2" in cfg.yaml_cfg:
 
225
  checkpoint = torch.load(weights_path, map_location="cpu")
226
  state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
227
 
228
+ # ๋ชจ๋ธ ์ƒ์„ฑ ์ „ ํ•œ๋ฒˆ ๋” ํ™•์ธ
229
+ torch.cuda.empty_cache()
230
+ gc.collect()
231
+
232
+ cfg.model.load_state_dict(state, strict=False)
233
 
234
  device = "cuda" if torch.cuda.is_available() else "cpu"
235
  model = ModelWrapper(cfg).to(device)
 
237
 
238
  return model, device
239
 
 
 
 
 
240
  @spaces.GPU
241
  def process_image(image, model_name, confidence_threshold):
242
  """Main processing function for Gradio interface"""
 
243
 
244
+ # 모든 사용 가능한 CUDA 장치 메모리 확보
245
+ if torch.cuda.is_available():
246
+ torch.cuda.empty_cache()
247
+
248
+ # 모든 Python 객체 가비지 컬렉션
249
+ gc.collect()
250
+
251
+ try:
252
  print(f"Loading model: {model_name}")
253
  model, device = load_model(model_name)
254
+
255
+ # 이미지 처리
256
+ result = process_image_for_gradio(model, device, image, model_name, confidence_threshold)
257
+
258
+ # ๋ชจ๋ธ ๊ฐ์ฒด ๋ฐ ๊ด€๋ จ ๋ฐ์ดํ„ฐ ๋ช…์‹œ์  ์ œ๊ฑฐ
259
+ del model
260
+
261
+ finally:
262
+ # 항상 메모리 정리 보장
263
+ if torch.cuda.is_available():
264
+ torch.cuda.empty_cache()
265
+ gc.collect()
266
 
267
+ return result
 
268
 
269
 
270
  # Create Gradio interface
 
296
  ]
297
  )
298
 
299
+ if __name__ == "__main__":
300
+ # Launch the Gradio app
301
+ demo.launch(share=True)