# videoGen / app.py
import streamlit as st
from transformers import AutoProcessor, AutoModel, BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
import torch
import cv2
import numpy as np
from PIL import Image
import tempfile
import os


@st.cache_resource
def load_models():
    ltx = AutoModel.from_pretrained("Lightricks/LTX-Video", trust_remote_code=True)
    ltx_processor = AutoProcessor.from_pretrained("Lightricks/LTX-Video")
    blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return ltx, ltx_processor, blip, blip_processor, clip, clip_processor
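

# Sketch (assumption, not part of the original app): when a CUDA GPU is available,
# moving the three models onto it keeps captioning and generation times reasonable.
# Processor outputs would then need a matching .to(device) before each forward pass.
def move_models_to_device(ltx, blip, clip, device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    return ltx.to(device), blip.to(device), clip.to(device), device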


def enhance_image(image):
    # Work on an RGB copy so RGBA or grayscale uploads don't break the OpenCV calls.
    img = np.array(image.convert("RGB"))
    # Denoise, then boost local contrast on the luminance channel with CLAHE.
    denoised = cv2.fastNlMeansDenoisingColored(img)
    lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l = clahe.apply(l)
    enhanced = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
    return Image.fromarray(enhanced)


def get_descriptions(image, blip_model, blip_processor, clip_model, clip_processor):
    # Caption the image with BLIP.
    blip_inputs = blip_processor(images=image, return_tensors="pt")
    blip_output = blip_model.generate(**blip_inputs, max_length=50)
    blip_desc = blip_processor.decode(blip_output[0], skip_special_tokens=True)

    # Score a fixed set of attributes against the image with CLIP.
    attributes = ["bright", "dark", "colorful", "natural", "indoor", "outdoor"]
    clip_inputs = clip_processor(images=image, return_tensors="pt")
    text_inputs = clip_processor(text=attributes, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**clip_inputs)
        text_features = clip_model.get_text_features(**text_inputs)
    similarity = torch.nn.functional.cosine_similarity(image_features, text_features)

    detected_attrs = [attr for i, attr in enumerate(attributes) if similarity[i] > 0.2]
    if not detected_attrs:
        return blip_desc
    return f"{blip_desc} The image appears {', '.join(detected_attrs)}."


def generate_video(model, processor, image, description):
    # Prepare the input
    inputs = processor(
        images=image,
        text=description,
        return_tensors="pt"
    )

    # Generate video frames
    with torch.no_grad():
        frames = model.generate(
            **inputs,
            num_frames=30,  # 10 seconds at 3 fps
            num_inference_steps=50,
            guidance_scale=7.5
        )

    # Save video to a temporary file
    temp_dir = tempfile.mkdtemp()
    temp_path = os.path.join(temp_dir, "output.mp4")

    # Convert frames to video. Assumption: the model returns RGB frames; OpenCV's
    # VideoWriter expects 8-bit BGR, so fix the dtype and channel order before writing.
    frames = frames.cpu().numpy()
    if frames.dtype != np.uint8:
        frames = np.clip(frames * 255, 0, 255).astype(np.uint8)
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(temp_path, fourcc, 3, (width, height))
    for frame in frames:
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video_writer.release()

    return temp_path
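

# Alternative sketch (assumption): LTX-Video is published as a diffusers pipeline, so
# image-to-video generation may be more reliable through LTXImageToVideoPipeline than
# through transformers AutoModel. The pipeline name, argument names, and frame count
# below follow the diffusers examples and may need adjusting to the installed version.
def generate_video_diffusers(image, description, output_path="output.mp4"):
    from diffusers import LTXImageToVideoPipeline
    from diffusers.utils import export_to_video

    pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
    pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    result = pipe(
        image=image,
        prompt=description,
        width=704,
        height=480,
        num_frames=121,  # LTX-Video expects 8k + 1 frames
        num_inference_steps=50,
    )
    export_to_video(result.frames[0], output_path, fps=24)
    return output_path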


def main():
    st.title("Enhanced Video Generator")

    models = load_models()
    ltx, ltx_processor, blip, blip_processor, clip, clip_processor = models

    image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
    if image_file:
        image = Image.open(image_file)
        enhanced_image = enhance_image(image)
        st.image(enhanced_image, caption="Enhanced Image")

        description = get_descriptions(
            enhanced_image, blip, blip_processor, clip, clip_processor
        )
        st.write("Image Analysis:", description)

        if st.button("Generate Video"):
            with st.spinner("Generating video..."):
                video_path = generate_video(ltx, ltx_processor, enhanced_image, description)
                st.video(video_path)


if __name__ == "__main__":
    main()