import os
import tempfile

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    BlipForConditionalGeneration,
    BlipProcessor,
    CLIPModel,
    CLIPProcessor,
)


@st.cache_resource
def load_models():
    # NOTE: loading LTX-Video via AutoModel/AutoProcessor with trust_remote_code
    # is an assumption here; the model is also distributed as a diffusers
    # pipeline, which may be required instead depending on your library versions.
    ltx = AutoModel.from_pretrained("Lightricks/LTX-Video", trust_remote_code=True)
    ltx_processor = AutoProcessor.from_pretrained("Lightricks/LTX-Video", trust_remote_code=True)
    blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return ltx, ltx_processor, blip, blip_processor, clip, clip_processor


def enhance_image(image):
    """Denoise the image and boost local contrast (CLAHE on the L channel)."""
    # Force 3-channel RGB so OpenCV's color denoiser does not choke on RGBA input.
    img = np.array(image.convert("RGB"))
    denoised = cv2.fastNlMeansDenoisingColored(img)
    lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l = clahe.apply(l)
    enhanced = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
    return Image.fromarray(enhanced)


def get_descriptions(image, blip_model, blip_processor, clip_model, clip_processor):
    """Caption the image with BLIP and tag coarse attributes with CLIP."""
    blip_inputs = blip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        blip_output = blip_model.generate(**blip_inputs, max_length=50)
    blip_desc = blip_processor.decode(blip_output[0], skip_special_tokens=True)

    attributes = ["bright", "dark", "colorful", "natural", "indoor", "outdoor"]
    clip_inputs = clip_processor(images=image, return_tensors="pt")
    # padding=True is needed so attribute prompts of different token lengths batch into one tensor.
    text_inputs = clip_processor(text=attributes, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**clip_inputs)
        text_features = clip_model.get_text_features(**text_inputs)

    # image_features is (1, d) and text_features is (n, d); broadcasting yields
    # one cosine similarity per attribute.
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
    detected_attrs = [attr for i, attr in enumerate(attributes) if similarity[i] > 0.2]

    if detected_attrs:
        return f"{blip_desc} The image appears {', '.join(detected_attrs)}."
    return blip_desc


def generate_video(model, processor, image, description):
    # Prepare the input (this interface is assumed from the trust_remote_code model)
    inputs = processor(
        images=image,
        text=description,
        return_tensors="pt"
    )

    # Generate video frames
    with torch.no_grad():
        frames = model.generate(
            **inputs,
            num_frames=30,  # 10 seconds at 3 fps
            num_inference_steps=50,
            guidance_scale=7.5
        )

    # Save video to a temporary file
    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")

    # Convert frames to video; OpenCV's VideoWriter expects uint8 BGR frames
    frames = frames.cpu().numpy()
    if frames.dtype != np.uint8:
        frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(temp_path, fourcc, 3, (width, height))

    for frame in frames:
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video_writer.release()

    return temp_path


def main():
    st.title("Enhanced Video Generator")

    models = load_models()
    ltx, ltx_processor, blip, blip_processor, clip, clip_processor = models

    image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
    if image_file:
        image = Image.open(image_file)
        enhanced_image = enhance_image(image)
        st.image(enhanced_image, caption="Enhanced Image")

        description = get_descriptions(
            enhanced_image, blip, blip_processor, clip, clip_processor
        )
        st.write("Image Analysis:", description)

        if st.button("Generate Video"):
            with st.spinner("Generating video..."):
                video_path = generate_video(ltx, ltx_processor, enhanced_image, description)
                st.video(video_path)


if __name__ == "__main__":
    main()