Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -13,6 +13,7 @@ from PIL import Image
 import requests
 import yaml
 import numpy as np
+import gc
 
 from src.core import YAMLConfig
 
@@ -108,7 +109,7 @@ def download_weights(model_name):
     print(f"Downloaded weights to: {weights_path}")
     return weights_path
 
-
+@torch.no_grad()
 def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
     """Process image function for Gradio interface"""
     if isinstance(image, np.ndarray):
@@ -185,10 +186,37 @@ class ModelWrapper(nn.Module):
         return outputs
 
 
+# Helper that resets the YAMLConfig class's internal state
+def reset_yaml_config():
+    """Reset YAMLConfig's internal state"""
+    # Drop any information cached on the class itself
+    if hasattr(YAMLConfig, '_instances'):
+        YAMLConfig._instances = {}
+    if hasattr(YAMLConfig, '_configs'):
+        YAMLConfig._configs = {}
+
+    # Reset the caches of any other related modules
+    import importlib
+    for module_name in list(sys.modules.keys()):  # assumes sys is imported earlier in app.py
+        if module_name.startswith('src.'):
+            try:
+                importlib.reload(sys.modules[module_name])
+            except Exception:
+                pass
+
 def load_model(model_name):
+    # Clear the CUDA cache and garbage-collect before loading a model
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+
+    # Reset YAMLConfig's internal state
+    reset_yaml_config()
+
     cfgfile = model_configs[model_name]["cfgfile"]
     weights_path = download_weights(model_name)
 
+    # Build a completely fresh YAMLConfig instance
     cfg = YAMLConfig(cfgfile, resume=weights_path)
 
     if "HGNetv2" in cfg.yaml_cfg:
@@ -197,7 +225,11 @@ def load_model(model_name):
     checkpoint = torch.load(weights_path, map_location="cpu")
     state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
 
-    cfg.model.load_state_dict(state)
+    # Free memory once more before building the model
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    cfg.model.load_state_dict(state, strict=False)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = ModelWrapper(cfg).to(device)
@@ -205,26 +237,34 @@ def load_model(model_name):
 
     return model, device
 
-
-# Dictionary to store loaded models
-loaded_models = {}
-
 @spaces.GPU
 def process_image(image, model_name, confidence_threshold):
     """Main processing function for Gradio interface"""
-    global loaded_models
 
-    #
-    if
+    # Free as much CUDA device memory as possible up front
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Garbage-collect all Python objects
+    gc.collect()
+
+    try:
         print(f"Loading model: {model_name}")
         model, device = load_model(model_name)
-
-
-
-
+
+        # Process the image
+        result = process_image_for_gradio(model, device, image, model_name, confidence_threshold)
+
+        # Explicitly drop the model object and its associated data
+        del model
+
+    finally:
+        # Always guarantee memory cleanup
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
 
-
-    return process_image_for_gradio(model, device, image, model_name, confidence_threshold)
+    return result
 
 
 # Create Gradio interface
@@ -256,4 +296,6 @@ demo = gr.Interface(
     ]
 )
 
-
+if __name__ == "__main__":
+    # Launch the Gradio app
+    demo.launch(share=True)
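Note on the pattern: the central change above drops the module-level loaded_models cache and instead loads a fresh model per request inside try/finally, so GPU memory is released even when loading or inference raises. A minimal, self-contained sketch of the same idiom, assuming only PyTorch; build_model and run_inference are illustrative stand-ins, not names from app.py:

import gc

import torch
import torch.nn as nn

def run_once(build_model, run_inference, *inputs):
    """Load a model, run one inference, and always release its memory."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        model = build_model().to(device).eval()
        with torch.no_grad():  # same effect as the @torch.no_grad() added above
            result = run_inference(model, *inputs)
    finally:
        # Runs even if loading or inference raises, so a failed request
        # cannot leave a dead model pinned in GPU memory.
        if "model" in locals():
            del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return result

# One throwaway linear model, one forward pass, memory freed afterwards.
out = run_once(lambda: nn.Linear(4, 2),
               lambda m, x: m(x.to(next(m.parameters()).device)),
               torch.randn(1, 4))

The trade-off is latency for predictability: every call pays the full load cost, which suits a ZeroGPU Space where @spaces.GPU attaches the device only for the duration of the call.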
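Note on reset_yaml_config: it probes for _instances and _configs with hasattr precisely because it cannot know which class-level caches YAMLConfig actually defines; both attribute names are speculative. A self-contained illustration of that defensive reset, where this toy YAMLConfig is hypothetical and not the class from src.core:

class YAMLConfig:
    _instances = {}  # hypothetical class-level cache, the kind app.py probes for

    def __init__(self, path):
        self.path = path
        YAMLConfig._instances[path] = self  # every construction is remembered

cfg = YAMLConfig("model.yml")
assert YAMLConfig._instances  # the cache now pins cfg alive

# The reset: wipe class-level dicts only if they exist, so the code stays
# safe against versions of the class that never defined them.
if hasattr(YAMLConfig, "_instances"):
    YAMLConfig._instances = {}
if hasattr(YAMLConfig, "_configs"):
    YAMLConfig._configs = {}

assert not YAMLConfig._instances  # the next YAMLConfig(...) starts clean

The importlib.reload loop in the same helper goes one step further and re-executes every loaded src. module; reloading rebinds module-level names but does not retroactively update objects created from the old class definitions.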