# kolcontrl / app.py
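"""Gradio demo for Kwai-Kolors ControlNet image-to-image generation.

Three control modes are exposed (Canny edges, MiDaS depth, DWpose skeleton).
Each button preprocesses the uploaded image into the matching condition map
and runs the corresponding Kolors ControlNet pipeline.
"""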
import spaces
import random
import torch
import cv2
import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download
from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.controlnet import ControlNetModel
from diffusers import AutoencoderKL
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import EulerDiscreteScheduler
from PIL import Image
from annotator.midas import MidasDetector
from annotator.dwpose import DWposeDetector
from annotator.util import resize_image, HWC3
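
# Model setup: download the Kolors base checkpoint and the three ControlNet
# checkpoints (Depth / Canny / Pose) from the Hub, then load every component
# in fp16 on the GPU.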
device = "cuda"
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
ckpt_dir_depth = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Depth")
ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")
ckpt_dir_pose = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Pose")
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
controlnet_depth = ControlNetModel.from_pretrained(f"{ckpt_dir_depth}", revision=None).half().to(device)
controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)
controlnet_pose = ControlNetModel.from_pretrained(f"{ckpt_dir_pose}", revision=None).half().to(device)
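
# One ControlNet img2img pipeline per condition type; the VAE, text encoder,
# tokenizer, UNet and scheduler are shared across all three.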
pipe_depth = StableDiffusionXLControlNetImg2ImgPipeline(
vae=vae,
        controlnet=controlnet_depth,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=False
)
pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
vae=vae,
        controlnet=controlnet_canny,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=False
)
pipe_pose = StableDiffusionXLControlNetImg2ImgPipeline(
vae=vae,
        controlnet=controlnet_pose,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=False
)
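
# Condition preprocessors: each turns the input image into the control map
# expected by its ControlNet (Canny edge map, MiDaS depth map, DWpose skeleton).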
@spaces.GPU
def process_canny_condition(image, canny_thresholds=(100, 200)):
    # Run Canny edge detection and replicate the single channel into an RGB image.
    np_image = image.copy()
    np_image = cv2.Canny(np_image, canny_thresholds[0], canny_thresholds[1])
    np_image = np_image[:, :, None]
    np_image = np.concatenate([np_image, np_image, np_image], axis=2)
    np_image = HWC3(np_image)
    return Image.fromarray(np_image)
model_midas = MidasDetector()
@spaces.GPU
def process_depth_condition_midas(img, res=1024):
    # Estimate a MiDaS depth map and resize it back to the original resolution.
    h, w, _ = img.shape
    img = resize_image(HWC3(img), res)
    result = HWC3(model_midas(img))
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)
model_dwpose = DWposeDetector()
@spaces.GPU
def process_dwpose_condition(image, res=1024):
    # Detect a DWpose skeleton on the resized input, then scale the pose map
    # back to the original image size.
    h, w, _ = image.shape
    img = resize_image(HWC3(image), res)
    out_res, out_img = model_dwpose(img)
    result = HWC3(out_img)
    result = cv2.resize(result, (w, h))
    return Image.fromarray(result)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
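
# Inference entry points: one function per ControlNet. Each one resizes the input,
# builds its condition image, runs the pipeline, and returns
# [condition_map, generated_image] together with the seed that was used.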
@spaces.GPU
def infer_depth(prompt,
                image=None,
                negative_prompt="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, poor quality, dark face, neon lights",
                seed=397886929,
                randomize_seed=False,
                guidance_scale=6.0,
                num_inference_steps=50,
                controlnet_conditioning_scale=0.7,
                control_guidance_end=0.9,
                strength=1.0
                ):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    init_image = resize_image(image, MAX_IMAGE_SIZE)
    pipe = pipe_depth.to("cuda")
    condi_img = process_depth_condition_midas(np.array(init_image), MAX_IMAGE_SIZE)
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return [condi_img, image], seed
@spaces.GPU
def infer_canny(prompt,
                image=None,
                negative_prompt="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, poor quality, dark face, neon lights",
                seed=397886929,
                randomize_seed=False,
                guidance_scale=6.0,
                num_inference_steps=50,
                controlnet_conditioning_scale=0.7,
                control_guidance_end=0.9,
                strength=1.0
                ):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    init_image = resize_image(image, MAX_IMAGE_SIZE)
    pipe = pipe_canny.to("cuda")
    condi_img = process_canny_condition(np.array(init_image))
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return [condi_img, image], seed
@spaces.GPU
def infer_pose(prompt,
               image=None,
               negative_prompt="nsfw, facial shadows, low resolution, jpeg artifacts, blurry, poor quality, dark face, neon lights",
               seed=66,
               randomize_seed=False,
               guidance_scale=6.0,
               num_inference_steps=50,
               controlnet_conditioning_scale=0.7,
               control_guidance_end=0.9,
               strength=1.0
               ):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    init_image = resize_image(image, MAX_IMAGE_SIZE)
    pipe = pipe_pose.to("cuda")
    condi_img = process_dwpose_condition(np.array(init_image), MAX_IMAGE_SIZE)
    image = pipe(
        prompt=prompt,
        image=init_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        control_guidance_end=control_guidance_end,
        strength=strength,
        control_image=condi_img,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]
    return [condi_img, image], seed
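
# Example prompt / image pairs shown in the UI for each control mode.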
canny_examples = [
    ["A beautiful girl, high quality, ultra sharp, vivid colors, ultra-high resolution, best quality, 8k, HD, 4K",
     "image/woman_1.png"],
    ["Panorama, a cute white puppy sitting in a cup and looking at the camera, animation style, 3D rendering, Octane render",
     "image/dog.png"]
]
depth_examples = [
    ["Makoto Shinkai style, rich colors, a woman in a green shirt standing in a field, beautiful scenery, clear and bright, dappled light and shadow, best quality, ultra-detailed, 8K",
     "image/woman_2.png"],
    ["A small bird in brilliant colors, high quality, ultra sharp, vivid colors, ultra-high resolution, best quality, 8k, HD, 4K",
     "image/bird.png"]
]
pose_examples = [
    ["A girl in a purple puff-sleeve dress, wearing a crown and white lace gloves, cupping her face with both hands, high quality, ultra sharp, vivid colors, ultra-high resolution, best quality, 8k, HD, 4K",
     "image/woman_3.png"],
    ["A woman in a black sports jacket and white inner top, wearing a necklace, standing on a street with red buildings and green trees in the background, high quality, ultra sharp, vivid colors, ultra-high resolution, best quality, 8k, HD, 4K",
     "image/woman_4.png"]
]
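
# Hide the Gradio footer via CSS.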
css = """
footer {
visibility: hidden;
}
"""
def load_description(fp):
with open(fp, 'r', encoding='utf-8') as f:
content = f.read()
return content
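
# Gradio UI: prompt / image inputs, advanced generation settings, and one button
# per control mode. Results show the condition map next to the generated image.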
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as Kolors:
with gr.Row():
with gr.Column(elem_id="col-left"):
with gr.Row():
prompt = gr.Textbox(
label="ν”„λ‘¬ν”„νŠΈ",
placeholder="ν”„λ‘¬ν”„νŠΈλ₯Ό μž…λ ₯ν•˜μ„Έμš”",
lines=2
)
with gr.Row():
                image = gr.Image(label="Image", type="pil")
            with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Textbox(
label="λ„€κ±°ν‹°λΈŒ ν”„λ‘¬ν”„νŠΈ",
placeholder="λ„€κ±°ν‹°λΈŒ ν”„λ‘¬ν”„νŠΈλ₯Ό μž…λ ₯ν•˜μ„Έμš”",
visible=True,
value="nsfw, μ–Όκ΅΄ 그림자, 저해상도, jpeg μ•„ν‹°νŒ©νŠΈ, 흐릿함, 열악함, 검은 μ–Όκ΅΄, λ„€μ˜¨ μ‘°λͺ…"
)
seed = gr.Slider(
label="μ‹œλ“œ",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
guidance_scale = gr.Slider(
label="κ°€μ΄λ˜μŠ€ μŠ€μΌ€μΌ",
minimum=0.0,
maximum=10.0,
step=0.1,
value=6.0,
)
num_inference_steps = gr.Slider(
label="μΆ”λ‘  단계 수",
minimum=10,
maximum=50,
step=1,
value=30,
)
with gr.Row():
controlnet_conditioning_scale = gr.Slider(
label="μ»¨νŠΈλ‘€λ„· 컨디셔닝 μŠ€μΌ€μΌ",
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.7,
)
control_guidance_end = gr.Slider(
label="컨트둀 κ°€μ΄λ˜μŠ€ μ’…λ£Œ",
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.9,
)
with gr.Row():
strength = gr.Slider(
label="강도",
minimum=0.0,
maximum=1.0,
step=0.1,
value=1.0,
)
with gr.Row():
                canny_button = gr.Button("Canny", elem_id="button")
                depth_button = gr.Button("Depth", elem_id="button")
                pose_button = gr.Button("Pose", elem_id="button")
with gr.Column(elem_id="col-right"):
            result = gr.Gallery(label="Result", show_label=False, columns=2)
            seed_used = gr.Number(label="Seed used")
with gr.Row():
gr.Examples(
                fn=infer_canny,
                examples=canny_examples,
                inputs=[prompt, image],
                outputs=[result, seed_used],
                label="Canny"
)
with gr.Row():
gr.Examples(
                fn=infer_depth,
                examples=depth_examples,
                inputs=[prompt, image],
                outputs=[result, seed_used],
                label="Depth"
)
with gr.Row():
gr.Examples(
                fn=infer_pose,
                examples=pose_examples,
                inputs=[prompt, image],
                outputs=[result, seed_used],
                label="Pose"
)
    canny_button.click(
        fn=infer_canny,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
        outputs=[result, seed_used]
    )
    depth_button.click(
        fn=infer_depth,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
        outputs=[result, seed_used]
    )
    pose_button.click(
        fn=infer_pose,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength],
        outputs=[result, seed_used]
    )
Kolors.queue().launch(debug=True)