jbilcke-hf (HF Staff) committed (verified)
Commit f8498f5 · Parent(s): e0425c1

Upload 5 files

Files changed (5):
  1. README.md +18 -0
  2. handler.py +255 -0
  3. model_index.json +32 -0
  4. requirements.txt +20 -0
  5. teacache.py +146 -0
README.md ADDED
@@ -0,0 +1,18 @@
+ ---
+ language:
+ - en
+ base_model:
+ - tencent/HunyuanVideo
+ pipeline_tag: text-to-video
+ library_name: diffusers
+ tags:
+ - HunyuanVideo
+ - Tencent
+ - Video
+ license: other
+ license_name: tencent-hunyuan-community
+ license_link: LICENSE
+ ---
+
+ This model is [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo) adapted to run on Hugging Face Inference Endpoints.
+
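For reference, a minimal sketch of calling a deployed endpoint with the payload format that handler.py expects; the endpoint URL and token are placeholders, and the response layout follows the dictionary returned by EndpointHandler.__call__:

import base64
import requests

API_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder endpoint URL
headers = {"Authorization": "Bearer hf_xxx"}                   # placeholder token

payload = {
    "inputs": "A cat walks on the grass, realistic style.",
    "parameters": {
        "num_frames": 49,            # adjusted to the 4k + 1 format by the handler
        "width": 576,
        "height": 320,
        "num_inference_steps": 50,
        "fps": 30,
        "seed": -1,                  # -1 lets the handler pick a random seed
    },
}

response = requests.post(API_URL, headers=headers, json=payload, timeout=600)
response.raise_for_status()
result = response.json()

# The handler returns the MP4 under "video", presumably as a base64 data URI.
video_uri = result["video"]
video_bytes = base64.b64decode(video_uri.split(",", 1)[1])
with open("output.mp4", "wb") as f:
    f.write(video_bytes)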
handler.py ADDED
@@ -0,0 +1,255 @@
+ from dataclasses import dataclass
+ from typing import Dict, Any, Optional
+ import base64
+ import logging
+ import random
+ import traceback
+ import torch
+ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+ from varnish import Varnish
+
+ from enhance_a_video import enable_enhance, inject_enhance_for_hunyuanvideo, set_enhance_weight
+ from teacache import enable_teacache, disable_teacache
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ @dataclass
+ class GenerationConfig:
+     """Configuration for video generation"""
+     # Content settings
+     prompt: str
+     negative_prompt: str = ""
+
+     # Model settings
+     num_frames: int = 49  # should follow the 4k + 1 format
+     height: int = 320
+     width: int = 576
+     num_inference_steps: int = 50
+     guidance_scale: float = 7.0
+
+     # Reproducibility
+     seed: int = -1
+
+     # Varnish post-processing settings
+     fps: int = 30
+     double_num_frames: bool = False
+     super_resolution: bool = False
+     grain_amount: float = 0.0
+     quality: int = 18  # CRF scale (0-51, lower is better)
+
+     # Audio settings
+     enable_audio: bool = False
+     audio_prompt: str = ""
+     audio_negative_prompt: str = "voices, voice, talking, speaking, speech"
+
+     # TeaCache settings
+     enable_teacache: bool = True
+     teacache_threshold: float = 0.15  # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
+
+     # Enhance-A-Video settings
+     enable_enhance_a_video: bool = True
+     enhance_a_video_weight: float = 4.0
+
+     def validate_and_adjust(self) -> 'GenerationConfig':
+         """Validate and adjust parameters"""
+         # Ensure num_frames follows the 4k + 1 format
+         k = (self.num_frames - 1) // 4
+         self.num_frames = (k * 4) + 1
+
+         # Set a random seed if not specified
+         if self.seed == -1:
+             self.seed = random.randint(0, 2**32 - 1)
+
+         return self
+
+ class EndpointHandler:
+     """Handles video generation requests using HunyuanVideo and Varnish"""
+
+     def __init__(self, path: str = ""):
+         """Initialize the handler with models
+
+         Args:
+             path: Path to model weights
+         """
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Initialize the transformer with the Enhance-A-Video injection first
+         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+             path,
+             subfolder="transformer",
+             torch_dtype=torch.bfloat16
+         )
+         inject_enhance_for_hunyuanvideo(transformer)
+
+         # Initialize the HunyuanVideo pipeline with the enhanced transformer
+         self.pipeline = HunyuanVideoPipeline.from_pretrained(
+             path,
+             transformer=transformer,
+             torch_dtype=torch.float16,
+         ).to(self.device)
+
+         # Keep the text encoders in float16
+         self.pipeline.text_encoder = self.pipeline.text_encoder.half()
+         self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
+
+         # Keep the transformer in bfloat16
+         self.pipeline.transformer = self.pipeline.transformer.to(torch.bfloat16)
+
+         # Keep the VAE in float16
+         self.pipeline.vae = self.pipeline.vae.half()
+
+         # Initialize Varnish for post-processing
+         self.varnish = Varnish(
+             device=self.device,
+             model_base_dir="/repository/varnish"
+         )
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """Process video generation requests
+
+         Args:
+             data: Request data containing:
+                 - inputs (str): Prompt for video generation
+                 - parameters (dict): Generation parameters
+
+         Returns:
+             Dictionary containing:
+                 - video: Base64-encoded MP4 data URI
+                 - content-type: MIME type
+                 - metadata: Generation metadata
+         """
+         # Extract inputs
+         inputs = data.pop("inputs", data)
+         if isinstance(inputs, dict):
+             prompt = inputs.get("prompt", "")
+         else:
+             prompt = inputs
+
+         params = data.get("parameters", {})
+
+         # Create and validate the config
+         config = GenerationConfig(
+             prompt=prompt,
+             negative_prompt=params.get("negative_prompt", ""),
+             num_frames=params.get("num_frames", 49),
+             height=params.get("height", 320),
+             width=params.get("width", 576),
+             num_inference_steps=params.get("num_inference_steps", 50),
+             guidance_scale=params.get("guidance_scale", 7.0),
+             seed=params.get("seed", -1),
+             fps=params.get("fps", 30),
+             double_num_frames=params.get("double_num_frames", False),
+             super_resolution=params.get("super_resolution", False),
+             grain_amount=params.get("grain_amount", 0.0),
+             quality=params.get("quality", 18),
+             enable_audio=params.get("enable_audio", False),
+             audio_prompt=params.get("audio_prompt", ""),
+             audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
+             enable_teacache=params.get("enable_teacache", True),
+             # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
+             teacache_threshold=params.get("teacache_threshold", 0.15),
+             enable_enhance_a_video=params.get("enable_enhance_a_video", True),
+             enhance_a_video_weight=params.get("enhance_a_video_weight", 4.0)
+         ).validate_and_adjust()
+
+         try:
+             # Set random seeds
+             if config.seed != -1:
+                 torch.manual_seed(config.seed)
+                 random.seed(config.seed)
+                 generator = torch.Generator(device=self.device).manual_seed(config.seed)
+             else:
+                 generator = None
+
+             # Configure TeaCache
+             #if config.enable_teacache:
+             #    enable_teacache(
+             #        self.pipeline.transformer,
+             #        num_inference_steps=config.num_inference_steps,
+             #        rel_l1_thresh=config.teacache_threshold
+             #    )
+             #else:
+             #    disable_teacache(self.pipeline.transformer)
+
+             # Configure the Enhance-A-Video weight if enabled
+             if config.enable_enhance_a_video:
+                 set_enhance_weight(config.enhance_a_video_weight)
+                 enable_enhance()
+             else:
+                 # Reset the enhance weight to 0 to effectively disable it
+                 set_enhance_weight(0)
+
+             # Generate video frames
+             with torch.inference_mode():
+                 output = self.pipeline(
+                     prompt=config.prompt,
+                     # Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
+                     #negative_prompt=config.negative_prompt,
+                     num_frames=config.num_frames,
+                     height=config.height,
+                     width=config.width,
+                     num_inference_steps=config.num_inference_steps,
+                     guidance_scale=config.guidance_scale,
+                     generator=generator,
+                     output_type="pt",
+                 ).frames
+
+             # Process with Varnish
+             import asyncio
+             try:
+                 loop = asyncio.get_event_loop()
+             except RuntimeError:
+                 loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(loop)
+
+             result = loop.run_until_complete(
+                 self.varnish(
+                     input_data=output,
+                     fps=config.fps,
+                     double_num_frames=config.double_num_frames,
+                     super_resolution=config.super_resolution,
+                     grain_amount=config.grain_amount,
+                     enable_audio=config.enable_audio,
+                     audio_prompt=config.audio_prompt,
+                     audio_negative_prompt=config.audio_negative_prompt,
+                 )
+             )
+
+             # Get the video data URI
+             video_uri = loop.run_until_complete(
+                 result.write(
+                     type="data-uri",
+                     quality=config.quality
+                 )
+             )
+
+             return {
+                 "video": video_uri,
+                 "content-type": "video/mp4",
+                 "metadata": {
+                     "width": result.metadata.width,
+                     "height": result.metadata.height,
+                     "num_frames": result.metadata.frame_count,
+                     "fps": result.metadata.fps,
+                     "duration": result.metadata.duration,
+                     "seed": config.seed,
+                     "enable_teacache": config.enable_teacache,
+                     "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
+                     "enable_enhance_a_video": config.enable_enhance_a_video,
+                     "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
+                 }
+             }
+
+         except Exception as e:
+             message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
+             logger.error(message)
+             raise RuntimeError(message)
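For local testing, a minimal sketch of exercising EndpointHandler directly; the "/repository" path is an assumption mirroring the location used for the Varnish weights above, and the prompt is a placeholder:

from handler import EndpointHandler

handler = EndpointHandler(path="/repository")  # assumed weights location on the endpoint

response = handler({
    "inputs": "A cat walks on the grass, realistic style.",
    "parameters": {"num_frames": 49, "width": 576, "height": 320, "fps": 30},
})

print(response["metadata"])    # width, height, num_frames, fps, duration, seed, ...
video_uri = response["video"]  # base64-encoded MP4 data URI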
model_index.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "_class_name": "HunyuanVideoPipeline",
+   "_diffusers_version": "0.32.0.dev0",
+   "scheduler": [
+     "diffusers",
+     "FlowMatchEulerDiscreteScheduler"
+   ],
+   "text_encoder": [
+     "transformers",
+     "LlamaModel"
+   ],
+   "text_encoder_2": [
+     "transformers",
+     "CLIPTextModel"
+   ],
+   "tokenizer": [
+     "transformers",
+     "LlamaTokenizerFast"
+   ],
+   "tokenizer_2": [
+     "transformers",
+     "CLIPTokenizer"
+   ],
+   "transformer": [
+     "diffusers",
+     "HunyuanVideoTransformer3DModel"
+   ],
+   "vae": [
+     "diffusers",
+     "AutoencoderKLHunyuanVideo"
+   ]
+ }
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ diffusers @ git+https://github.com/huggingface/diffusers.git@main
+ varnish @ git+https://github.com/jbilcke-hf/varnish.git@main
+
+ opencv-python>=4.10.0.84
+
+ transformers==4.48.0
+ huggingface_hub==0.27.1
+
+ tokenizers>=0.20.3
+ accelerate>=1.1.1
+ pandas>=2.0.3
+ numpy
+ einops==0.7.0
+ tqdm>=4.66.5
+ loguru>=0.7.2
+ imageio>=2.34.2
+ imageio-ffmpeg>=0.5.1
+ safetensors>=0.4.5
+
+ moviepy==1.0.3
teacache.py ADDED
@@ -0,0 +1,146 @@
+ # teacache.py
+ import torch
+ import numpy as np
+ from typing import Optional, Dict, Union, Any
+ from functools import wraps
+
+ class TeaCacheConfig:
+     """Configuration for TeaCache acceleration"""
+     def __init__(
+         self,
+         rel_l1_thresh: float = 0.15,
+         enable: bool = True
+     ):
+         self.rel_l1_thresh = rel_l1_thresh
+         self.enable = enable
+         self._reset_state()
+
+     def _reset_state(self):
+         """Reset internal state"""
+         self.cnt = 0
+         self.accumulated_rel_l1_distance = 0
+         self.previous_modulated_input = None
+         self.previous_residual = None
+
+ def create_teacache_forward(original_forward):
+     """Factory function to create a TeaCache-enabled forward pass"""
+     @wraps(original_forward)
+     def teacache_forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         pooled_projections: Optional[torch.Tensor] = None,
+         guidance: Optional[torch.Tensor] = None,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ):
+         # Skip TeaCache if not enabled
+         # (original_forward is already bound to the model, so `self` is not passed explicitly)
+         if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
+             return original_forward(
+                 hidden_states=hidden_states,
+                 timestep=timestep,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 pooled_projections=pooled_projections,
+                 guidance=guidance,
+                 attention_kwargs=attention_kwargs,
+                 return_dict=return_dict
+             )
+
+         config = self.teacache_config
+
+         # Prepare modulation vectors similar to the HunyuanVideo implementation
+         vec = None
+         if pooled_projections is not None:
+             vec = self.vector_in(pooled_projections)
+
+         if guidance is not None:
+             if vec is None:
+                 vec = self.guidance_in(guidance)
+             else:
+                 vec = vec + self.guidance_in(guidance)
+
+         # TeaCache optimization logic
+         inp = hidden_states.clone()
+         if hasattr(self.double_blocks[0], 'img_norm1'):
+             # HunyuanVideo-specific modulation
+             img_mod1_shift, img_mod1_scale, _, _, _, _ = self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
+             normed_inp = self.double_blocks[0].img_norm1(inp)
+             modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
+         else:
+             # Fallback modulation
+             normed_inp = self.transformer_blocks[0].norm1(inp)
+             modulated_inp = normed_inp
+
+         # Determine whether to recalculate or reuse the cache
+         should_calc = True
+         if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
+             should_calc = True
+             config.accumulated_rel_l1_distance = 0
+         elif config.previous_modulated_input is not None:
+             coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
+                             -3.14987800e+00, 9.61237896e-02]
+             rescale_func = np.poly1d(coefficients)
+
+             rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
+                       config.previous_modulated_input.abs().mean()).cpu().item()
+             config.accumulated_rel_l1_distance += rescale_func(rel_l1)
+
+             should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
+             if should_calc:
+                 config.accumulated_rel_l1_distance = 0
+
+         config.previous_modulated_input = modulated_inp
+         config.cnt += 1
+         if config.cnt >= self.num_inference_steps:
+             config.cnt = 0
+
+         # Reuse the cached residual or calculate a new result
+         if not should_calc and config.previous_residual is not None:
+             hidden_states += config.previous_residual
+         else:
+             ori_hidden_states = hidden_states.clone()
+
+             # Use the original forward pass
+             out = original_forward(
+                 hidden_states=hidden_states,
+                 timestep=timestep,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 pooled_projections=pooled_projections,
+                 guidance=guidance,
+                 attention_kwargs=attention_kwargs,
+                 return_dict=True
+             )
+             hidden_states = out["sample"]
+
+             # Store the residual for future reuse
+             config.previous_residual = hidden_states - ori_hidden_states
+
+         if not return_dict:
+             return (hidden_states,)
+
+         return {"sample": hidden_states}
+
+     return teacache_forward
+
+ def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
+     """Enable TeaCache acceleration for a model"""
+     if not hasattr(model, '_original_forward'):
+         model._original_forward = model.forward
+
+     model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
+     model.num_inference_steps = num_inference_steps
+     model.forward = create_teacache_forward(model._original_forward).__get__(model)
+
+ def disable_teacache(model: Any):
+     """Disable TeaCache acceleration for a model"""
+     if hasattr(model, '_original_forward'):
+         model.forward = model._original_forward
+         del model._original_forward
+
+     if hasattr(model, 'teacache_config'):
+         del model.teacache_config
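For context, a minimal sketch of how these helpers would be wired in, mirroring the block that is currently commented out in handler.py; the `pipeline` object is assumed to be the HunyuanVideoPipeline built there:

from teacache import enable_teacache, disable_teacache

num_inference_steps = 50

# Patch the transformer's forward pass: steps whose modulated input barely changes
# reuse the previous residual instead of recomputing the full forward.
enable_teacache(
    pipeline.transformer,
    num_inference_steps=num_inference_steps,
    rel_l1_thresh=0.15,  # per the handler's comment: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)
)

frames = pipeline(
    prompt="A cat walks on the grass, realistic style.",
    num_inference_steps=num_inference_steps,
    output_type="pt",
).frames

# Restore the original forward pass when caching is no longer wanted.
disable_teacache(pipeline.transformer)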