Dhan98 commited on
Commit
aa83b59
·
verified ·
1 Parent(s): 16dc47e
Files changed (1) hide show
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoProcessor, AutoModel, BlipProcessor, BlipForConditionalGeneration, CLIPProcessor, CLIPModel
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ import tempfile
8
+ import os
9
+
10
+ @st.cache_resource
11
+ def load_models():
12
+ ltx = AutoModel.from_pretrained("Lightricks/LTX-Video", trust_remote_code=True)
13
+ ltx_processor = AutoProcessor.from_pretrained("Lightricks/LTX-Video")
14
+
15
+ blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
16
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
17
+
18
+ clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
19
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
20
+
21
+ return ltx, ltx_processor, blip, blip_processor, clip, clip_processor
22
+
23
+ def enhance_image(image):
24
+ img = np.array(image)
25
+ denoised = cv2.fastNlMeansDenoisingColored(img)
26
+ lab = cv2.cvtColor(denoised, cv2.COLOR_RGB2LAB)
27
+ l, a, b = cv2.split(lab)
28
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
29
+ l = clahe.apply(l)
30
+ enhanced = cv2.cvtColor(cv2.merge([l,a,b]), cv2.COLOR_LAB2RGB)
31
+ return Image.fromarray(enhanced)
32
+
33
+ def get_descriptions(image, blip_model, blip_processor, clip_model, clip_processor):
34
+ blip_inputs = blip_processor(images=image, return_tensors="pt")
35
+ blip_output = blip_model.generate(**blip_inputs, max_length=50)
36
+ blip_desc = blip_processor.decode(blip_output[0], skip_special_tokens=True)
37
+
38
+ clip_inputs = clip_processor(images=image, return_tensors="pt", text=None)
39
+ image_features = clip_model.get_image_features(**clip_inputs)
40
+
41
+ attributes = ["bright", "dark", "colorful", "natural", "indoor", "outdoor"]
42
+ text_inputs = clip_processor(text=attributes, return_tensors="pt", images=None)
43
+ text_features = clip_model.get_text_features(**text_inputs)
44
+
45
+ similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
46
+ detected_attrs = [attr for i, attr in enumerate(attributes) if similarity[i] > 0.2]
47
+
48
+ return f"{blip_desc} The image appears {', '.join(detected_attrs)}."
49
+
50
+ def generate_video(model, processor, image, description):
51
+ # Prepare the input
52
+ inputs = processor(
53
+ images=image,
54
+ text=description,
55
+ return_tensors="pt"
56
+ )
57
+
58
+ # Generate video frames
59
+ with torch.no_grad():
60
+ frames = model.generate(
61
+ **inputs,
62
+ num_frames=30, # 10 seconds at 3fps
63
+ num_inference_steps=50,
64
+ guidance_scale=7.5
65
+ )
66
+
67
+ # Save video to temporary file
68
+ temp_dir = tempfile.mkdtemp()
69
+ temp_path = os.path.join(temp_dir, "output.mp4")
70
+
71
+ # Convert frames to video
72
+ frames = frames.cpu().numpy()
73
+ height, width = frames[0].shape[:2]
74
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
75
+ video_writer = cv2.VideoWriter(temp_path, fourcc, 3, (width, height))
76
+
77
+ for frame in frames:
78
+ video_writer.write(frame)
79
+ video_writer.release()
80
+
81
+ return temp_path
82
+
83
+ def main():
84
+ st.title("Enhanced Video Generator")
85
+
86
+ models = load_models()
87
+ ltx, ltx_processor, blip, blip_processor, clip, clip_processor = models
88
+
89
+ image_file = st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'])
90
+ if image_file:
91
+ image = Image.open(image_file)
92
+ enhanced_image = enhance_image(image)
93
+
94
+ st.image(enhanced_image, caption="Enhanced Image")
95
+
96
+ description = get_descriptions(
97
+ enhanced_image, blip, blip_processor, clip, clip_processor
98
+ )
99
+ st.write("Image Analysis:", description)
100
+
101
+ if st.button("Generate Video"):
102
+ with st.spinner("Generating video..."):
103
+ video_path = generate_video(ltx, ltx_processor, enhanced_image, description)
104
+ st.video(video_path)
105
+
106
+ if __name__ == "__main__":
107
+ main()