openfree commited on
Commit
faae4da
ยท
verified ยท
1 Parent(s): abb19a7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +914 -0
app.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import argparse
3
+ import os
4
+ import shutil
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
10
+ import huggingface_hub
11
+ from huggingface_hub import hf_hub_download
12
+ from PIL import Image
13
+ from torchvision.transforms.functional import normalize
14
+ from gradio_client import Client
15
+ import logging
16
+ import time
17
+
18
+ from dreamo.dreamo_pipeline import DreamOPipeline
19
+ from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img, resize_numpy_image_long
20
+ from tools import BEN2
21
+
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('--port', type=int, default=8080)
24
+ parser.add_argument('--no_turbo', action='store_true')
25
+ args = parser.parse_args()
26
+
27
+ huggingface_hub.login(os.getenv('HF_TOKEN'))
28
+
29
+ # Text-to-Image API URL
30
+ TEXT2IMG_API_URL = "http://211.233.58.201:7896"
31
+
32
+ # ๋กœ๊น… ์„ค์ •
33
+ logging.basicConfig(
34
+ level=logging.DEBUG,
35
+ format='%(asctime)s - %(levelname)s - %(message)s')
36
+
37
+ try:
38
+ shutil.rmtree('gradio_cached_examples')
39
+ except FileNotFoundError:
40
+ print("cache folder not exist")
41
+
42
+ class Generator:
43
+ def __init__(self):
44
+ device = torch.device('cuda')
45
+ # preprocessing models
46
+ # background remove model: BEN2
47
+ self.bg_rm_model = BEN2.BEN_Base().to(device).eval()
48
+ hf_hub_download(repo_id='PramaLLC/BEN2', filename='BEN2_Base.pth', local_dir='models')
49
+ self.bg_rm_model.loadcheckpoints('models/BEN2_Base.pth')
50
+ # face crop and align tool: facexlib
51
+ self.face_helper = FaceRestoreHelper(
52
+ upscale_factor=1,
53
+ face_size=512,
54
+ crop_ratio=(1, 1),
55
+ det_model='retinaface_resnet50',
56
+ save_ext='png',
57
+ device=device,
58
+ )
59
+
60
+ # load dreamo
61
+ model_root = 'black-forest-labs/FLUX.1-dev'
62
+ dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16)
63
+ dreamo_pipeline.load_dreamo_model(device, use_turbo=not args.no_turbo)
64
+ self.dreamo_pipeline = dreamo_pipeline.to(device)
65
+
66
+ @torch.no_grad()
67
+ def get_align_face(self, img):
68
+ # the face preprocessing code is same as PuLID
69
+ self.face_helper.clean_all()
70
+ image_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
71
+ self.face_helper.read_image(image_bgr)
72
+ self.face_helper.get_face_landmarks_5(only_center_face=True)
73
+ self.face_helper.align_warp_face()
74
+ if len(self.face_helper.cropped_faces) == 0:
75
+ return None
76
+ align_face = self.face_helper.cropped_faces[0]
77
+
78
+ input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
79
+ input = input.to(torch.device("cuda"))
80
+ parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
81
+ parsing_out = parsing_out.argmax(dim=1, keepdim=True)
82
+ bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
83
+ bg = sum(parsing_out == i for i in bg_label).bool()
84
+ white_image = torch.ones_like(input)
85
+ # only keep the face features
86
+ face_features_image = torch.where(bg, white_image, input)
87
+ face_features_image = tensor2img(face_features_image, rgb2bgr=False)
88
+
89
+ return face_features_image
90
+
91
+
92
+ generator = Generator()
93
+
94
+
95
+ @spaces.GPU
96
+ @torch.inference_mode()
97
+ def generate_image(
98
+ ref_image1,
99
+ ref_image2,
100
+ ref_task1,
101
+ ref_task2,
102
+ prompt,
103
+ seed,
104
+ width=1024,
105
+ height=1024,
106
+ ref_res=512,
107
+ num_steps=12,
108
+ guidance=3.5,
109
+ true_cfg=1,
110
+ cfg_start_step=0,
111
+ cfg_end_step=0,
112
+ neg_prompt='',
113
+ neg_guidance=3.5,
114
+ first_step_guidance=0,
115
+ ):
116
+ print(prompt)
117
+ ref_conds = []
118
+ debug_images = []
119
+
120
+ ref_images = [ref_image1, ref_image2]
121
+ ref_tasks = [ref_task1, ref_task2]
122
+
123
+ for idx, (ref_image, ref_task) in enumerate(zip(ref_images, ref_tasks)):
124
+ if ref_image is not None:
125
+ if ref_task == "id":
126
+ ref_image = resize_numpy_image_long(ref_image, 1024)
127
+ ref_image = generator.get_align_face(ref_image)
128
+ elif ref_task != "style":
129
+ ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image))
130
+ if ref_task != "id":
131
+ ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)
132
+ debug_images.append(ref_image)
133
+ ref_image = img2tensor(ref_image, bgr2rgb=False).unsqueeze(0) / 255.0
134
+ ref_image = 2 * ref_image - 1.0
135
+ ref_conds.append(
136
+ {
137
+ 'img': ref_image,
138
+ 'task': ref_task,
139
+ 'idx': idx + 1,
140
+ }
141
+ )
142
+
143
+ seed = int(seed)
144
+ if seed == -1:
145
+ seed = torch.Generator(device="cpu").seed()
146
+
147
+ image = generator.dreamo_pipeline(
148
+ prompt=prompt,
149
+ width=width,
150
+ height=height,
151
+ num_inference_steps=num_steps,
152
+ guidance_scale=guidance,
153
+ ref_conds=ref_conds,
154
+ generator=torch.Generator(device="cpu").manual_seed(seed),
155
+ true_cfg_scale=true_cfg,
156
+ true_cfg_start_step=cfg_start_step,
157
+ true_cfg_end_step=cfg_end_step,
158
+ negative_prompt=neg_prompt,
159
+ neg_guidance_scale=neg_guidance,
160
+ first_step_guidance_scale=first_step_guidance if first_step_guidance > 0 else guidance,
161
+ ).images[0]
162
+
163
+ return image, debug_images, seed
164
+
165
+
166
+ # Video generation functions
167
+ import requests
168
+ import random
169
+ import tempfile
170
+ import subprocess
171
+ from gradio_client import Client, handle_file
172
+
173
+ REMOTE_ENDPOINT = os.getenv("H100_URL")
174
+
175
+ client = Client(REMOTE_ENDPOINT)
176
+
177
+ def run_process_video_api(image_path: str, prompt: str, video_length: float = 2.0):
178
+ seed_val = random.randint(0, 9999999)
179
+ negative_prompt = ""
180
+ use_teacache = True
181
+
182
+ result = client.predict(
183
+ input_image=handle_file(image_path),
184
+ prompt=prompt,
185
+ n_prompt=negative_prompt,
186
+ seed=seed_val,
187
+ use_teacache=use_teacache,
188
+ video_length=video_length,
189
+ api_name="/process",
190
+ )
191
+ video_dict, preview_dict, md_text, html_text = result
192
+ video_path = video_dict.get("video") if isinstance(video_dict, dict) else None
193
+ return video_path
194
+
195
+ def add_watermark_to_video(input_video_path: str, watermark_text="Ginigen.com") -> str:
196
+ if not os.path.exists(input_video_path):
197
+ raise FileNotFoundError(f"Input video not found: {input_video_path}")
198
+
199
+ base, ext = os.path.splitext(input_video_path)
200
+ watermarked_path = base + "_wm" + ext
201
+ cmd = [
202
+ "ffmpeg", "-y",
203
+ "-i", input_video_path,
204
+ "-vf", f"drawtext=fontsize=20:fontcolor=white:text='{watermark_text}':x=w-tw-10:y=h-th-10:box=1:[email protected]:boxborderw=5",
205
+ "-codec:a", "copy",
206
+ watermarked_path
207
+ ]
208
+ try:
209
+ subprocess.run(cmd, check=True)
210
+ except Exception as e:
211
+ print(f"[WARN] FFmpeg watermark failed: {e}")
212
+ return input_video_path
213
+
214
+ return watermarked_path
215
+
216
+ def generate_video_from_image(image_array: np.ndarray):
217
+ if image_array is None:
218
+ raise gr.Error("์ด๋ฏธ์ง€๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
219
+
220
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
221
+ temp_img_path = fp.name
222
+ Image.fromarray(image_array).save(temp_img_path, format="PNG")
223
+
224
+ default_video_prompt = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
225
+ result_video_path = run_process_video_api(
226
+ image_path=temp_img_path,
227
+ prompt=default_video_prompt,
228
+ video_length=2.0,
229
+ )
230
+ if result_video_path is None:
231
+ raise gr.Error("์˜์ƒ API ํ˜ธ์ถœ ์‹คํŒจ ๋˜๋Š” ๊ฒฐ๊ณผ ์—†์Œ")
232
+
233
+ final_video = add_watermark_to_video(result_video_path, watermark_text="Ginigen.com")
234
+ return final_video
235
+
236
+
237
+ # Text-to-Image functions
238
+ def test_text2img_api_connection() -> str:
239
+ """Text-to-Image API ์„œ๋ฒ„ ์—ฐ๊ฒฐ ํ…Œ์ŠคํŠธ"""
240
+ try:
241
+ client = Client(TEXT2IMG_API_URL)
242
+ return "API ์—ฐ๊ฒฐ ์„ฑ๊ณต: ์ •์ƒ ์ž‘๋™ ์ค‘"
243
+ except Exception as e:
244
+ logging.error(f"API connection test failed: {e}")
245
+ return f"API ์—ฐ๊ฒฐ ์‹คํŒจ: {e}"
246
+
247
+ def generate_text_to_image(prompt: str, width: int, height: int, guidance: float, inference_steps: int, seed: int) -> tuple:
248
+ """ํ…์ŠคํŠธ๋ฅผ ์ด๋ฏธ์ง€๋กœ ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜"""
249
+ if not prompt:
250
+ return None, "์˜ค๋ฅ˜: ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”"
251
+
252
+ try:
253
+ client = Client(TEXT2IMG_API_URL)
254
+ result = client.predict(
255
+ prompt=prompt,
256
+ width=int(width),
257
+ height=int(height),
258
+ guidance=float(guidance),
259
+ inference_steps=int(inference_steps),
260
+ seed=int(seed),
261
+ do_img2img=False,
262
+ init_image=None,
263
+ image2image_strength=0.8,
264
+ resize_img=True,
265
+ api_name="/generate_image"
266
+ )
267
+ return result[0], f"์‚ฌ์šฉ๋œ ์‹œ๋“œ: {result[1]}"
268
+ except Exception as e:
269
+ logging.error(f"Image generation failed: {str(e)}")
270
+ return None, f"์˜ค๋ฅ˜: {str(e)}"
271
+
272
+ # Image size presets
273
+ IMAGE_PRESETS = {
274
+ "์ปค์Šคํ…€": {"width": 1024, "height": 1024, "label": "์ปค์Šคํ…€ ํฌ๊ธฐ"},
275
+ "1:1 ์ •์‚ฌ๊ฐํ˜•": {"width": 1024, "height": 1024, "label": "1:1 (์ •์‚ฌ๊ฐํ˜•)"},
276
+ "4:3 ํ‘œ์ค€": {"width": 1024, "height": 768, "label": "4:3 (ํ‘œ์ค€)"},
277
+ "16:9 ์™€์ด๋“œ์Šคํฌ๋ฆฐ": {"width": 1024, "height": 576, "label": "16:9 (์™€์ด๋“œ์Šคํฌ๋ฆฐ)"},
278
+ "9:16 ์„ธ๋กœํ˜•": {"width": 576, "height": 1024, "label": "9:16 (์„ธ๋กœํ˜•)"},
279
+ "6:19 ํŠน์ˆ˜ ์„ธ๋กœํ˜•": {"width": 324, "height": 1024, "label": "6:19 (ํŠน์ˆ˜ ์„ธ๋กœํ˜•)"},
280
+ "Instagram ์ •์‚ฌ๊ฐํ˜•": {"width": 1080, "height": 1080, "label": "Instagram ์ •์‚ฌ๊ฐํ˜• (1:1)"},
281
+ "Instagram ์Šคํ† ๋ฆฌ": {"width": 1080, "height": 1920, "label": "Instagram ์Šคํ† ๋ฆฌ (9:16)"},
282
+ "Instagram ๊ฐ€๋กœํ˜•": {"width": 1080, "height": 566, "label": "Instagram ๊ฐ€๋กœํ˜• (1.91:1)"},
283
+ "Facebook ์ปค๋ฒ„": {"width": 820, "height": 312, "label": "Facebook ์ปค๋ฒ„ (2.63:1)"},
284
+ "Twitter ํ—ค๋”": {"width": 1500, "height": 500, "label": "Twitter ํ—ค๋” (3:1)"},
285
+ "YouTube ์ธ๋„ค์ผ": {"width": 1280, "height": 720, "label": "YouTube ์ธ๋„ค์ผ (16:9)"},
286
+ "LinkedIn ๋ฐฐ๋„ˆ": {"width": 1584, "height": 396, "label": "LinkedIn ๋ฐฐ๋„ˆ (4:1)"},
287
+ }
288
+
289
+ def update_dimensions(preset):
290
+ """์„ ํƒ๋œ ํ”„๋ฆฌ์…‹์— ๋”ฐ๋ผ width, height ์—…๋ฐ์ดํŠธ"""
291
+ if preset in IMAGE_PRESETS:
292
+ return IMAGE_PRESETS[preset]["width"], IMAGE_PRESETS[preset]["height"]
293
+ return 1024, 1024
294
+
295
+
296
+ # Custom CSS
297
+ _CUSTOM_CSS_ = """
298
+ :root {
299
+ --primary-color: #f8c3cd;
300
+ --secondary-color: #b3e5fc;
301
+ --background-color: #f5f5f7;
302
+ --card-background: #ffffff;
303
+ --text-color: #424242;
304
+ --accent-color: #ffb6c1;
305
+ --success-color: #c8e6c9;
306
+ --warning-color: #fff9c4;
307
+ --shadow-color: rgba(0, 0, 0, 0.1);
308
+ --border-radius: 12px;
309
+ }
310
+
311
+ body {
312
+ background-color: var(--background-color) !important;
313
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
314
+ }
315
+
316
+ .gradio-container {
317
+ max-width: 1200px !important;
318
+ margin: 0 auto !important;
319
+ }
320
+
321
+ h1 {
322
+ color: #9c27b0 !important;
323
+ font-weight: 800 !important;
324
+ text-shadow: 2px 2px 4px rgba(156, 39, 176, 0.2) !important;
325
+ letter-spacing: -0.5px !important;
326
+ }
327
+
328
+ .panel-box {
329
+ border-radius: var(--border-radius) !important;
330
+ box-shadow: 0 8px 16px var(--shadow-color) !important;
331
+ background-color: var(--card-background) !important;
332
+ border: none !important;
333
+ overflow: hidden !important;
334
+ padding: 20px !important;
335
+ margin-bottom: 20px !important;
336
+ }
337
+
338
+ button.gr-button {
339
+ background: linear-gradient(135deg, var(--primary-color), #e1bee7) !important;
340
+ border-radius: var(--border-radius) !important;
341
+ color: #4a148c !important;
342
+ font-weight: 600 !important;
343
+ border: none !important;
344
+ padding: 10px 20px !important;
345
+ transition: all 0.3s ease !important;
346
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
347
+ }
348
+
349
+ button.gr-button:hover {
350
+ transform: translateY(-2px) !important;
351
+ box-shadow: 0 6px 10px rgba(0, 0, 0, 0.15) !important;
352
+ background: linear-gradient(135deg, #e1bee7, var(--primary-color)) !important;
353
+ }
354
+
355
+ input, select, textarea, .gr-input {
356
+ border-radius: 8px !important;
357
+ border: 2px solid #e0e0e0 !important;
358
+ padding: 10px 15px !important;
359
+ transition: all 0.3s ease !important;
360
+ background-color: #fafafa !important;
361
+ }
362
+
363
+ input:focus, select:focus, textarea:focus, .gr-input:focus {
364
+ border-color: var(--primary-color) !important;
365
+ box-shadow: 0 0 0 3px rgba(248, 195, 205, 0.3) !important;
366
+ }
367
+
368
+ .gr-form input[type=range] {
369
+ appearance: none !important;
370
+ width: 100% !important;
371
+ height: 6px !important;
372
+ background: #e0e0e0 !important;
373
+ border-radius: 5px !important;
374
+ outline: none !important;
375
+ }
376
+
377
+ .gr-form input[type=range]::-webkit-slider-thumb {
378
+ appearance: none !important;
379
+ width: 16px !important;
380
+ height: 16px !important;
381
+ background: var(--primary-color) !important;
382
+ border-radius: 50% !important;
383
+ cursor: pointer !important;
384
+ border: 2px solid white !important;
385
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
386
+ }
387
+
388
+ .gr-form select {
389
+ background-color: white !important;
390
+ border: 2px solid #e0e0e0 !important;
391
+ border-radius: 8px !important;
392
+ padding: 10px 15px !important;
393
+ }
394
+
395
+ .gr-image-input {
396
+ border: 2px dashed #b39ddb !important;
397
+ border-radius: var(--border-radius) !important;
398
+ background-color: #f3e5f5 !important;
399
+ padding: 20px !important;
400
+ display: flex !important;
401
+ flex-direction: column !important;
402
+ align-items: center !important;
403
+ justify-content: center !important;
404
+ transition: all 0.3s ease !important;
405
+ }
406
+
407
+ .gr-image-input:hover {
408
+ background-color: #ede7f6 !important;
409
+ border-color: #9575cd !important;
410
+ }
411
+
412
+ body::before {
413
+ content: "" !important;
414
+ position: fixed !important;
415
+ top: 0 !important;
416
+ left: 0 !important;
417
+ width: 100% !important;
418
+ height: 100% !important;
419
+ background:
420
+ radial-gradient(circle at 10% 20%, rgba(248, 195, 205, 0.1) 0%, rgba(245, 245, 247, 0) 20%),
421
+ radial-gradient(circle at 80% 70%, rgba(179, 229, 252, 0.1) 0%, rgba(245, 245, 247, 0) 20%) !important;
422
+ pointer-events: none !important;
423
+ z-index: -1 !important;
424
+ }
425
+
426
+ .gr-gallery {
427
+ grid-gap: 15px !important;
428
+ }
429
+
430
+ .gr-gallery-item {
431
+ border-radius: var(--border-radius) !important;
432
+ overflow: hidden !important;
433
+ box-shadow: 0 4px 8px var(--shadow-color) !important;
434
+ transition: transform 0.3s ease !important;
435
+ }
436
+
437
+ .gr-gallery-item:hover {
438
+ transform: scale(1.02) !important;
439
+ }
440
+
441
+ .gr-form label {
442
+ font-weight: 600 !important;
443
+ color: #673ab7 !important;
444
+ margin-bottom: 5px !important;
445
+ }
446
+
447
+ .gr-padded {
448
+ padding: 20px !important;
449
+ }
450
+
451
+ .gr-compact {
452
+ gap: 15px !important;
453
+ }
454
+
455
+ .gr-form > div {
456
+ margin-bottom: 16px !important;
457
+ }
458
+
459
+ .gr-form h3 {
460
+ color: #7b1fa2 !important;
461
+ margin-top: 5px !important;
462
+ margin-bottom: 15px !important;
463
+ border-bottom: 2px solid #e1bee7 !important;
464
+ padding-bottom: 8px !important;
465
+ }
466
+
467
+ #examples-panel {
468
+ background-color: #f3e5f5 !important;
469
+ border-radius: var(--border-radius) !important;
470
+ padding: 15px !important;
471
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05) !important;
472
+ }
473
+
474
+ #examples-panel h2 {
475
+ color: #7b1fa2 !important;
476
+ font-size: 1.5rem !important;
477
+ margin-bottom: 15px !important;
478
+ }
479
+
480
+ .gr-accordion {
481
+ border: 1px solid #e0e0e0 !important;
482
+ border-radius: var(--border-radius) !important;
483
+ overflow: hidden !important;
484
+ }
485
+
486
+ .gr-accordion summary {
487
+ padding: 12px 16px !important;
488
+ background-color: #f9f9f9 !important;
489
+ cursor: pointer !important;
490
+ font-weight: 600 !important;
491
+ color: #673ab7 !important;
492
+ }
493
+
494
+ #generate-btn, #text2img-generate-btn {
495
+ background: linear-gradient(135deg, #ff9a9e, #fad0c4) !important;
496
+ font-size: 1.1rem !important;
497
+ padding: 12px 24px !important;
498
+ margin-top: 10px !important;
499
+ margin-bottom: 15px !important;
500
+ width: 100% !important;
501
+ }
502
+
503
+ #generate-btn:hover, #text2img-generate-btn:hover {
504
+ background: linear-gradient(135deg, #fad0c4, #ff9a9e) !important;
505
+ }
506
+
507
+ /* Tab styling */
508
+ .gr-tabs {
509
+ border: none !important;
510
+ margin-top: 20px !important;
511
+ }
512
+
513
+ .gr-tab {
514
+ background-color: #f3e5f5 !important;
515
+ border: none !important;
516
+ padding: 12px 24px !important;
517
+ font-weight: 600 !important;
518
+ color: #673ab7 !important;
519
+ transition: all 0.3s ease !important;
520
+ }
521
+
522
+ .gr-tab.selected {
523
+ background: linear-gradient(135deg, var(--primary-color), #e1bee7) !important;
524
+ color: white !important;
525
+ }
526
+
527
+ .gr-tab:hover {
528
+ background-color: #ede7f6 !important;
529
+ }
530
+ """
531
+
532
+ _HEADER_ = '''
533
+ <div style="text-align: center; max-width: 850px; margin: 0 auto; padding: 25px 0;">
534
+ <div style="background: linear-gradient(135deg, #f8c3cd, #e1bee7, #b3e5fc); color: white; padding: 15px; border-radius: 15px; box-shadow: 0 10px 20px rgba(0,0,0,0.1); margin-bottom: 20px;">
535
+ <h1 style="font-size: 3rem; font-weight: 800; margin: 0; color: white; text-shadow: 2px 2px 4px rgba(0,0,0,0.2);">โœจ DreamO Video โœจ</h1>
536
+ <p style="font-size: 1.2rem; margin: 10px 0 0;">Create customized images with advanced AI</p>
537
+ </div>
538
+
539
+ <div style="background: white; padding: 15px; border-radius: 12px; box-shadow: 0 5px 15px rgba(0,0,0,0.05);">
540
+ <p style="font-size: 1rem; margin: 0;">In the current demo version, due to ZeroGPU limitations, video generation is restricted to 2 seconds only. (The full version supports generation of up to 60 seconds)</p>
541
+ </div>
542
+
543
+ </div>
544
+
545
+ <div style="background: #fff9c4; padding: 15px; border-radius: 12px; margin-bottom: 20px; border-left: 5px solid #ffd54f; box-shadow: 0 5px 15px rgba(0,0,0,0.05);">
546
+ <h3 style="margin-top: 0; color: #ff6f00;">๐Ÿšฉ Update Notes:</h3>
547
+ <ul style="margin-bottom: 0; padding-left: 20px;">
548
+ <li><b>2025.05.11:</b> We have updated the model to mitigate over-saturation and plastic-face issues. The new version shows consistent improvements over the previous release.</li>
549
+ <li><b>2025.05.13:</b> 'DreamO Video' Integration version Release</li>
550
+ <li><b>2025.05.28:</b> Added 'Text-to-Image' tab with multiple aspect ratios and SNS presets</li>
551
+ </ul>
552
+ </div>
553
+ '''
554
+
555
+ _CITE_ = r"""
556
+ <div style="background: white; padding: 20px; border-radius: 12px; margin-top: 20px; box-shadow: 0 5px 15px rgba(0,0,0,0.05);">
557
+ <p style="margin: 0; font-size: 1.1rem;">If DreamO is helpful, please help to โญ the <a href='https://discord.gg/openfreeai' target='_blank' style="color: #9c27b0; font-weight: 600;">community</a>. Thanks!</p>
558
+ <hr style="border: none; height: 1px; background-color: #e0e0e0; margin: 15px 0;">
559
+ <h4 style="margin: 0 0 10px; color: #7b1fa2;">๐Ÿ“ง Contact</h4>
560
+ <p style="margin: 0;">If you have any questions or feedback, feel free to open a discussion or contact <b>[email protected]</b></p>
561
+ </div>
562
+ """
563
+
564
+ def create_demo():
565
+ with gr.Blocks(css=_CUSTOM_CSS_) as demo:
566
+ gr.HTML(_HEADER_)
567
+
568
+ with gr.Tabs():
569
+ # DreamO Tab
570
+ with gr.Tab("DreamO (์ฐธ์กฐ ์ด๋ฏธ์ง€ ๊ธฐ๋ฐ˜)"):
571
+ with gr.Row():
572
+ with gr.Column(scale=6):
573
+ with gr.Group(elem_id="input-panel", elem_classes="panel-box"):
574
+ gr.Markdown("### ๐Ÿ“ธ Reference Images")
575
+ with gr.Row():
576
+ with gr.Column():
577
+ ref_image1 = gr.Image(label="Reference Image 1", type="numpy", height=256, elem_id="ref-image-1")
578
+ ref_task1 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Task for Reference Image 1", elem_id="ref-task-1")
579
+
580
+ with gr.Column():
581
+ ref_image2 = gr.Image(label="Reference Image 2", type="numpy", height=256, elem_id="ref-image-2")
582
+ ref_task2 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Task for Reference Image 2", elem_id="ref-task-2")
583
+
584
+ gr.Markdown("### โœ๏ธ Generation Parameters")
585
+ prompt = gr.Textbox(label="Prompt", value="a person playing guitar in the street", elem_id="prompt-input")
586
+
587
+ with gr.Row():
588
+ width = gr.Slider(768, 1024, 1024, step=16, label="Width", elem_id="width-slider")
589
+ height = gr.Slider(768, 1024, 1024, step=16, label="Height", elem_id="height-slider")
590
+
591
+ with gr.Row():
592
+ num_steps = gr.Slider(8, 30, 12, step=1, label="Number of Steps", elem_id="steps-slider")
593
+ guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance Scale", elem_id="guidance-slider")
594
+
595
+ seed = gr.Textbox(label="Seed (-1 for random)", value="-1", elem_id="seed-input")
596
+
597
+ with gr.Accordion("Advanced Options", open=False):
598
+ ref_res = gr.Slider(512, 1024, 512, step=16, label="Resolution for Reference Image")
599
+ neg_prompt = gr.Textbox(label="Negative Prompt", value="")
600
+ neg_guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Negative Guidance")
601
+
602
+ with gr.Row():
603
+ true_cfg = gr.Slider(1, 5, 1, step=0.1, label="True CFG")
604
+ first_step_guidance = gr.Slider(0, 10, 0, step=0.1, label="First Step Guidance")
605
+
606
+ with gr.Row():
607
+ cfg_start_step = gr.Slider(0, 30, 0, step=1, label="CFG Start Step")
608
+ cfg_end_step = gr.Slider(0, 30, 0, step=1, label="CFG End Step")
609
+
610
+ generate_btn = gr.Button("โœจ Generate Image", elem_id="generate-btn")
611
+ gr.HTML(_CITE_)
612
+
613
+ with gr.Column(scale=6):
614
+ with gr.Group(elem_id="output-panel", elem_classes="panel-box"):
615
+ gr.Markdown("### ๐Ÿ–ผ๏ธ Generated Result")
616
+ output_image = gr.Image(label="Generated Image", elem_id="output-image", format='png')
617
+ seed_output = gr.Textbox(label="Used Seed", elem_id="seed-output")
618
+
619
+ generate_video_btn = gr.Button("๐ŸŽฌ Generate Video from Image")
620
+ output_video = gr.Video(label="Generated Video", elem_id="video-output")
621
+
622
+ gr.Markdown("### ๐Ÿ” Preprocessing")
623
+ debug_image = gr.Gallery(
624
+ label="Preprocessing Results (including face crop and background removal)",
625
+ elem_id="debug-gallery",
626
+ )
627
+
628
+ with gr.Group(elem_id="examples-panel", elem_classes="panel-box"):
629
+ gr.Markdown("## ๐Ÿ“š Examples")
630
+ example_inps = [
631
+ [
632
+ 'example_inputs/choi.jpg',
633
+ None,
634
+ 'ip',
635
+ 'ip',
636
+ 'a woman sitting on the cloud, playing guitar',
637
+ 1206523688721442817,
638
+ ],
639
+ [
640
+ 'example_inputs/choi.jpg',
641
+ None,
642
+ 'id',
643
+ 'ip',
644
+ 'a woman holding a sign saying "TOP", on the mountain',
645
+ 10441727852953907380,
646
+ ],
647
+ [
648
+ 'example_inputs/perfume.png',
649
+ None,
650
+ 'ip',
651
+ 'ip',
652
+ 'a perfume under spotlight',
653
+ 116150031980664704,
654
+ ],
655
+ [
656
+ 'example_inputs/choi.jpg',
657
+ None,
658
+ 'id',
659
+ 'ip',
660
+ 'portrait, in alps',
661
+ 5443415087540486371,
662
+ ],
663
+ [
664
+ 'example_inputs/mickey.png',
665
+ None,
666
+ 'style',
667
+ 'ip',
668
+ 'generate a same style image. A rooster wearing overalls.',
669
+ 6245580464677124951,
670
+ ],
671
+ [
672
+ 'example_inputs/mountain.png',
673
+ None,
674
+ 'style',
675
+ 'ip',
676
+ 'generate a same style image. A pavilion by the river, and the distant mountains are endless',
677
+ 5248066378927500767,
678
+ ],
679
+ [
680
+ 'example_inputs/shirt.png',
681
+ 'example_inputs/skirt.jpeg',
682
+ 'ip',
683
+ 'ip',
684
+ 'A girl is wearing a short-sleeved shirt and a short skirt on the beach.',
685
+ 9514069256241143615,
686
+ ],
687
+ [
688
+ 'example_inputs/woman2.png',
689
+ 'example_inputs/dress.png',
690
+ 'id',
691
+ 'ip',
692
+ 'the woman wearing a dress, In the banquet hall',
693
+ 7698454872441022867,
694
+ ],
695
+ [
696
+ 'example_inputs/dog1.png',
697
+ 'example_inputs/dog2.png',
698
+ 'ip',
699
+ 'ip',
700
+ 'two dogs in the jungle',
701
+ 6187006025405083344,
702
+ ],
703
+ ]
704
+ gr.Examples(
705
+ examples=example_inps,
706
+ inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed],
707
+ label='Examples by category: IP task (rows 1-4), ID task (row 5), Style task (rows 6-7), Try-On task (rows 8-9)',
708
+ cache_examples='lazy',
709
+ outputs=[output_image, debug_image, seed_output],
710
+ fn=generate_image,
711
+ )
712
+
713
+ # Event handlers for DreamO tab
714
+ generate_btn.click(
715
+ fn=generate_image,
716
+ inputs=[
717
+ ref_image1,
718
+ ref_image2,
719
+ ref_task1,
720
+ ref_task2,
721
+ prompt,
722
+ seed,
723
+ width,
724
+ height,
725
+ ref_res,
726
+ num_steps,
727
+ guidance,
728
+ true_cfg,
729
+ cfg_start_step,
730
+ cfg_end_step,
731
+ neg_prompt,
732
+ neg_guidance,
733
+ first_step_guidance,
734
+ ],
735
+ outputs=[output_image, debug_image, seed_output],
736
+ )
737
+
738
+ def on_click_generate_video(img):
739
+ if img is None:
740
+ raise gr.Error("๋จผ์ € ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ด์ฃผ์„ธ์š”.")
741
+ video_path = generate_video_from_image(img)
742
+ return video_path
743
+
744
+ generate_video_btn.click(
745
+ fn=on_click_generate_video,
746
+ inputs=[output_image],
747
+ outputs=[output_video],
748
+ )
749
+
750
+ # Text-to-Image Tab
751
+ with gr.Tab("ํ…์ŠคํŠธ to ์ด๋ฏธ์ง€"):
752
+ with gr.Row():
753
+ with gr.Column(scale=6):
754
+ with gr.Group(elem_id="text2img-input-panel", elem_classes="panel-box"):
755
+ gr.Markdown("### ๐Ÿ“ ํ…์ŠคํŠธ๋กœ ์ด๋ฏธ์ง€ ์ƒ์„ฑ")
756
+
757
+ # API ์ƒํƒœ ํ‘œ์‹œ
758
+ text2img_status = gr.Textbox(
759
+ label="API ์ƒํƒœ",
760
+ value="API ์—ฐ๊ฒฐ ํ™•์ธ ์ค‘...",
761
+ interactive=False
762
+ )
763
+
764
+ # ํ”„๋กฌํ”„ํŠธ ์ž…๋ ฅ
765
+ text2img_prompt = gr.Textbox(
766
+ label="ํ”„๋กฌํ”„ํŠธ",
767
+ placeholder="์ƒ์„ฑํ•˜๊ณ  ์‹ถ์€ ์ด๋ฏธ์ง€๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”...",
768
+ lines=3
769
+ )
770
+
771
+ # ์ด๋ฏธ์ง€ ํฌ๊ธฐ ํ”„๋ฆฌ์…‹
772
+ size_preset = gr.Dropdown(
773
+ choices=list(IMAGE_PRESETS.keys()),
774
+ value="1:1 ์ •์‚ฌ๊ฐํ˜•",
775
+ label="์ด๋ฏธ์ง€ ํฌ๊ธฐ ํ”„๋ฆฌ์…‹",
776
+ interactive=True
777
+ )
778
+
779
+ with gr.Row():
780
+ text2img_width = gr.Slider(
781
+ minimum=256,
782
+ maximum=2048,
783
+ value=1024,
784
+ step=64,
785
+ label="๋„ˆ๋น„"
786
+ )
787
+
788
+ text2img_height = gr.Slider(
789
+ minimum=256,
790
+ maximum=2048,
791
+ value=1024,
792
+ step=64,
793
+ label="๋†’์ด"
794
+ )
795
+
796
+ with gr.Row():
797
+ text2img_guidance = gr.Slider(
798
+ minimum=1.0,
799
+ maximum=20.0,
800
+ value=3.5,
801
+ step=0.1,
802
+ label="๊ฐ€์ด๋˜์Šค ์Šค์ผ€์ผ"
803
+ )
804
+
805
+ text2img_steps = gr.Slider(
806
+ minimum=1,
807
+ maximum=50,
808
+ value=30,
809
+ step=1,
810
+ label="์ธํผ๋Ÿฐ์Šค ์Šคํ…"
811
+ )
812
+
813
+ text2img_seed = gr.Number(
814
+ label="์‹œ๋“œ (-1: ๋žœ๋ค)",
815
+ value=-1,
816
+ precision=0
817
+ )
818
+
819
+ text2img_generate_btn = gr.Button("โœจ ์ด๋ฏธ์ง€ ์ƒ์„ฑ", elem_id="text2img-generate-btn")
820
+
821
+ # ์ƒ์„ฑ ์ƒํƒœ ํ‘œ์‹œ
822
+ text2img_generation_status = gr.Textbox(
823
+ label="์ƒ์„ฑ ์ƒํƒœ",
824
+ value="",
825
+ interactive=False,
826
+ visible=False
827
+ )
828
+
829
+ with gr.Column(scale=6):
830
+ with gr.Group(elem_id="text2img-output-panel", elem_classes="panel-box"):
831
+ gr.Markdown("### ๐Ÿ–ผ๏ธ ์ƒ์„ฑ ๊ฒฐ๊ณผ")
832
+ text2img_output = gr.Image(label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€", format='png')
833
+ text2img_used_seed = gr.Textbox(label="์‚ฌ์šฉ๋œ ์‹œ๋“œ")
834
+
835
+ # ๋น„๋””์˜ค ์ƒ์„ฑ ๋ฒ„ํŠผ
836
+ text2img_video_btn = gr.Button("๐ŸŽฌ ์ด๋ฏธ์ง€๋ฅผ ๋น„๋””์˜ค๋กœ ๋ณ€ํ™˜")
837
+ text2img_video = gr.Video(label="์ƒ์„ฑ๋œ ๋น„๋””์˜ค")
838
+
839
+ # Text-to-Image ํƒญ ์˜ˆ์ œ
840
+ with gr.Group(elem_id="text2img-examples-panel", elem_classes="panel-box"):
841
+ gr.Markdown("## ๐Ÿ“š ํ…์ŠคํŠธ to ์ด๋ฏธ์ง€ ์˜ˆ์ œ")
842
+ text2img_examples = [
843
+ ["A serene Japanese garden with cherry blossoms", "1:1 ์ •์‚ฌ๊ฐํ˜•", 3.5, 30, 42],
844
+ ["Futuristic cityscape at sunset, cyberpunk style", "16:9 ์™€์ด๋“œ์Šคํฌ๋ฆฐ", 4.0, 35, 123],
845
+ ["Portrait of a mysterious woman with flowing hair", "Instagram ์Šคํ† ๋ฆฌ", 3.0, 25, 789],
846
+ ["Epic fantasy dragon breathing fire", "YouTube ์ธ๋„ค์ผ", 5.0, 40, 456],
847
+ ["Minimalist logo design for tech company", "LinkedIn ๋ฐฐ๋„ˆ", 3.5, 30, 321],
848
+ ]
849
+ gr.Examples(
850
+ examples=text2img_examples,
851
+ inputs=[text2img_prompt, size_preset, text2img_guidance, text2img_steps, text2img_seed],
852
+ label='์˜ˆ์ œ ํ”„๋กฌํ”„ํŠธ์™€ ์„ค์ •',
853
+ cache_examples=False,
854
+ )
855
+
856
+ # Event handlers for Text-to-Image tab
857
+ size_preset.change(
858
+ fn=update_dimensions,
859
+ inputs=[size_preset],
860
+ outputs=[text2img_width, text2img_height]
861
+ )
862
+
863
+ def on_text2img_generate_click():
864
+ text2img_generation_status.visible = True
865
+ text2img_generation_status.value = "์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘... ์ž ์‹œ๋งŒ ๊ธฐ๋‹ค๋ ค์ฃผ์„ธ์š”"
866
+ return text2img_generation_status
867
+
868
+ def on_text2img_generate_complete():
869
+ text2img_generation_status.value = "์ด๋ฏธ์ง€ ์ƒ์„ฑ ์™„๋ฃŒ!"
870
+ return text2img_generation_status
871
+
872
+ text2img_generate_btn.click(
873
+ fn=on_text2img_generate_click,
874
+ outputs=[text2img_generation_status]
875
+ ).then(
876
+ fn=generate_text_to_image,
877
+ inputs=[text2img_prompt, text2img_width, text2img_height, text2img_guidance, text2img_steps, text2img_seed],
878
+ outputs=[text2img_output, text2img_used_seed]
879
+ ).then(
880
+ fn=on_text2img_generate_complete,
881
+ outputs=[text2img_generation_status]
882
+ )
883
+
884
+ def on_text2img_video_click(img):
885
+ if img is None:
886
+ raise gr.Error("๋จผ์ € ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ด์ฃผ์„ธ์š”.")
887
+ video_path = generate_video_from_image(img)
888
+ return video_path
889
+
890
+ text2img_video_btn.click(
891
+ fn=on_text2img_video_click,
892
+ inputs=[text2img_output],
893
+ outputs=[text2img_video],
894
+ )
895
+
896
+ # API ์ƒํƒœ ํ™•์ธ
897
+ def check_text2img_api_status():
898
+ return test_text2img_api_connection()
899
+
900
+ demo.load(
901
+ fn=check_text2img_api_status,
902
+ outputs=[text2img_status]
903
+ )
904
+
905
+ return demo
906
+
907
+
908
+ if __name__ == '__main__':
909
+ demo = create_demo()
910
+ demo.launch(
911
+ server_name="0.0.0.0",
912
+ share=True,
913
+ ssr_mode=False
914
+ )