"""Streamlit app: enhance an uploaded image, describe it with BLIP and CLIP,
and animate it into a short video with LTX-Video."""

import streamlit as st
from transformers import BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
# LTX-Video is distributed as a diffusers pipeline, not a transformers model,
# so it is loaded through diffusers rather than AutoModel/AutoProcessor.
from diffusers import LTXImageToVideoPipeline
from diffusers.utils import export_to_video
import torch
import cv2
import numpy as np
from PIL import Image
import tempfile
import os
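
# Assumed dependencies for this sketch (the original does not pin versions):
# streamlit, torch, transformers, diffusers (recent enough to include the
# LTX-Video pipelines), opencv-python, numpy, pillow.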

@st.cache_resource
def load_models():
    """Load the generation and captioning models once and cache them across reruns."""
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # LTX-Video image-to-video pipeline (bfloat16 on GPU keeps memory usage manageable).
    ltx = LTXImageToVideoPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    )
    ltx.to(device)

    # BLIP for free-form image captioning.
    blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")

    # CLIP for zero-shot attribute scoring.
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    return ltx, blip, blip_processor, clip, clip_processor

def enhance_image(image):
    """Denoise the upload and boost local contrast before captioning and generation."""
    # Force a 3-channel RGB array so RGBA or grayscale uploads do not break the colour conversions.
    img = np.array(image.convert("RGB"))
    denoised = cv2.fastNlMeansDenoisingColored(img)

    # Apply CLAHE to the L (lightness) channel only, so contrast improves without shifting colours.
    lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l = clahe.apply(l)
    enhanced = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)

    return Image.fromarray(enhanced)

def get_descriptions(image, blip_model, blip_processor, clip_model, clip_processor):
    """Combine a BLIP caption with CLIP-detected attributes into one prompt string."""
    # BLIP: free-form caption.
    blip_inputs = blip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        blip_output = blip_model.generate(**blip_inputs, max_length=50)
    blip_desc = blip_processor.decode(blip_output[0], skip_special_tokens=True)

    # CLIP: score the image against a fixed list of attribute words.
    attributes = ["bright", "dark", "colorful", "natural", "indoor", "outdoor"]
    clip_inputs = clip_processor(images=image, return_tensors="pt")
    text_inputs = clip_processor(text=attributes, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**clip_inputs)
        text_features = clip_model.get_text_features(**text_inputs)

    # (1, d) image features broadcast against (n, d) text features: one cosine score per attribute.
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
    detected_attrs = [attr for i, attr in enumerate(attributes) if similarity[i] > 0.2]

    if not detected_attrs:
        return blip_desc
    return f"{blip_desc} The image appears {', '.join(detected_attrs)}."

def generate_video(pipe, image, description):
    """Animate the enhanced image with LTX-Video and return a path to the rendered MP4."""
    # Resolution and frame count below are assumptions, not values from the original script:
    # LTX-Video expects frame counts of the form 8*k + 1 (so 25 rather than 30), and
    # 704x480 matches the resolution used in the model's reference examples.
    result = pipe(
        image=image,
        prompt=description,
        width=704,
        height=480,
        num_frames=25,
        num_inference_steps=50,
        guidance_scale=7.5,
    )
    frames = result.frames[0]  # list of PIL frames for the single generated video

    # Write the frames to a temporary MP4 that Streamlit can serve.
    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")
    export_to_video(frames, temp_path, fps=24)  # LTX-Video targets roughly 24 fps playback

    return temp_path

def main():
    st.title("Enhanced Video Generator")

    ltx, blip, blip_processor, clip, clip_processor = load_models()

    image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if image_file:
        image = Image.open(image_file)
        enhanced_image = enhance_image(image)

        st.image(enhanced_image, caption="Enhanced Image")

        description = get_descriptions(
            enhanced_image, blip, blip_processor, clip, clip_processor
        )
        st.write("Image Analysis:", description)

        if st.button("Generate Video"):
            with st.spinner("Generating video..."):
                video_path = generate_video(ltx, enhanced_image, description)
            st.video(video_path)


if __name__ == "__main__":
    main()
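
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py
# A CUDA GPU is strongly recommended for the LTX-Video generation step.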