fffiloni committed on
Commit fc0a183 · verified · 1 Parent(s): ce67ba5

Migrated from GitHub

Files changed (40)
  1. .gitattributes +5 -0
  2. .pre-commit-config.yaml +25 -0
  3. LICENSE.txt +38 -0
  4. ORIGINAL_README.md +717 -0
  5. assets/logo2.png +0 -0
  6. assets/main_pipeline.jpg +3 -0
  7. generate_video.py +161 -0
  8. generate_video_df.py +186 -0
  9. requirements.txt +16 -0
  10. skycaptioner_v1/README.md +262 -0
  11. skycaptioner_v1/examples/data/1.mp4 +3 -0
  12. skycaptioner_v1/examples/data/2.mp4 +3 -0
  13. skycaptioner_v1/examples/data/3.mp4 +3 -0
  14. skycaptioner_v1/examples/data/4.mp4 +3 -0
  15. skycaptioner_v1/examples/test.csv +5 -0
  16. skycaptioner_v1/examples/test_result.csv +5 -0
  17. skycaptioner_v1/infer_fusion_caption.sh +9 -0
  18. skycaptioner_v1/infer_struct_caption.sh +8 -0
  19. skycaptioner_v1/requirements.txt +3 -0
  20. skycaptioner_v1/scripts/utils.py +19 -0
  21. skycaptioner_v1/scripts/vllm_fusion_caption.py +250 -0
  22. skycaptioner_v1/scripts/vllm_struct_caption.py +153 -0
  23. skyreels_v2_infer/__init__.py +1 -0
  24. skyreels_v2_infer/distributed/__init__.py +0 -0
  25. skyreels_v2_infer/distributed/xdit_context_parallel.py +286 -0
  26. skyreels_v2_infer/modules/__init__.py +69 -0
  27. skyreels_v2_infer/modules/attention.py +179 -0
  28. skyreels_v2_infer/modules/clip.py +525 -0
  29. skyreels_v2_infer/modules/t5.py +454 -0
  30. skyreels_v2_infer/modules/tokenizers.py +78 -0
  31. skyreels_v2_infer/modules/transformer.py +839 -0
  32. skyreels_v2_infer/modules/vae.py +639 -0
  33. skyreels_v2_infer/modules/xlm_roberta.py +165 -0
  34. skyreels_v2_infer/pipelines/__init__.py +5 -0
  35. skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +427 -0
  36. skyreels_v2_infer/pipelines/image2video_pipeline.py +156 -0
  37. skyreels_v2_infer/pipelines/prompt_enhancer.py +65 -0
  38. skyreels_v2_infer/pipelines/text2video_pipeline.py +110 -0
  39. skyreels_v2_infer/scheduler/__init__.py +0 -0
  40. skyreels_v2_infer/scheduler/fm_solvers_unipc.py +759 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/main_pipeline.jpg filter=lfs diff=lfs merge=lfs -text
37
+ skycaptioner_v1/examples/data/1.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ skycaptioner_v1/examples/data/2.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ skycaptioner_v1/examples/data/3.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ skycaptioner_v1/examples/data/4.mp4 filter=lfs diff=lfs merge=lfs -text
.pre-commit-config.yaml ADDED
@@ -0,0 +1,25 @@
+ repos:
+ - repo: https://github.com/asottile/reorder-python-imports.git
+   rev: v3.8.3
+   hooks:
+   - id: reorder-python-imports
+     name: Reorder Python imports
+     types: [file, python]
+ - repo: https://github.com/psf/black.git
+   rev: 22.8.0
+   hooks:
+   - id: black
+     additional_dependencies: ['click==8.0.4']
+     args: [--line-length=120]
+     types: [file, python]
+ - repo: https://github.com/pre-commit/pre-commit-hooks.git
+   rev: v4.3.0
+   hooks:
+   - id: check-byte-order-marker
+     types: [file, python]
+   - id: trailing-whitespace
+     types: [file, python]
+   - id: end-of-file-fixer
+     types: [file, python]
+
LICENSE.txt ADDED
@@ -0,0 +1,38 @@
1
+ ---
2
+ language:
3
+ - en
4
+ - zh
5
+ license: other
6
+ tasks:
7
+ - text-generation
8
+
9
+ ---
10
+
11
+ <!-- markdownlint-disable first-line-h1 -->
12
+ <!-- markdownlint-disable html -->
13
+
14
+ # <span id="Terms">声明与协议/Terms and Conditions</span>
15
+
16
+ ## 声明
17
+
18
+ 我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
19
+
20
+ 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
21
+
22
+ We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
23
+
24
+ We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
25
+
26
+ ## 协议
27
+
28
+ 社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
29
+
30
+
31
+ The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
32
+
33
+
34
+
35
+ [《Skywork 模型社区许可协议》]:https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
36
+
37
+
38
ORIGINAL_README.md ADDED
@@ -0,0 +1,717 @@
1
+ <p align="center">
2
+ <img src="assets/logo2.png" alt="SkyReels Logo" width="50%">
3
+ </p>
4
+
5
+ <h1 align="center">SkyReels V2: Infinite-Length Film Generative Model</h1>
6
+
7
+ <p align="center">
8
+ 📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9" target="_blank">Hugging Face</a> · 🤖 <a href="https://www.modelscope.cn/collections/SkyReels-V2-f665650130b144" target="_blank">ModelScope</a>
9
+ </p>
10
+
11
+ ---
12
+ Welcome to the **SkyReels V2** repository! Here you'll find the model weights and inference code for our infinite-length film generative models. To the best of our knowledge, SkyReels V2 is the first open-source video generative model to employ an **AutoRegressive Diffusion-Forcing architecture**, and it achieves **SOTA performance** among publicly available models.
13
+
14
+
15
+ ## 🔥🔥🔥 News!!
16
+ * Apr 24, 2025: 🔥 We release the 720P models, [SkyReels-V2-DF-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P) and [SkyReels-V2-I2V-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P). The former facilitates infinite-length autoregressive video generation, and the latter focuses on Image2Video synthesis.
17
+ * Apr 21, 2025: 👋 We release the inference code and model weights of the [SkyReels-V2](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) series models and the video captioning model [SkyCaptioner-V1](https://huggingface.co/Skywork/SkyCaptioner-V1).
18
+ * Apr 3, 2025: 🔥 We also release [SkyReels-A2](https://github.com/SkyworkAI/SkyReels-A2). This is an open-sourced controllable video generation framework capable of assembling arbitrary visual elements.
19
+ * Feb 18, 2025: 🔥 We released [SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1). This is an open-sourced and effective framework for portrait image animation.
20
+ * Feb 18, 2025: 🔥 We released [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
21
+
22
+ ## 🎥 Demos
23
+ <table>
24
+ <tr>
25
+ <td align="center">
26
+ <video src="https://github.com/user-attachments/assets/f6f9f9a7-5d5f-433c-9d73-d8d593b7ad25" width="100%"></video>
27
+ </td>
28
+ <td align="center">
29
+ <video src="https://github.com/user-attachments/assets/0eb13415-f4d9-4aaf-bcd3-3031851109b9" width="100%"></video>
30
+ </td>
31
+ <td align="center">
32
+ <video src="https://github.com/user-attachments/assets/dcd16603-5bf4-4786-8e4d-1ed23889d07a" width="100%"></video>
33
+ </td>
34
+ </tr>
35
+ </table>
36
+ The demos above showcase 30-second videos generated using our SkyReels-V2 Diffusion Forcing model.
37
+
38
+
39
+ ## 📑 TODO List
40
+
41
+ - [x] <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a>
42
+ - [x] Checkpoints of the 14B and 1.3B Models Series
43
+ - [x] Single-GPU & Multi-GPU Inference Code
44
+ - [x] <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a>: A Video Captioning Model
45
+ - [x] Prompt Enhancer
46
+ - [ ] Diffusers integration
47
+ - [ ] Checkpoints of the 5B Models Series
48
+ - [ ] Checkpoints of the Camera Director Models
49
+ - [ ] Checkpoints of the Step & Guidance Distill Model
50
+
51
+
52
+ ## 🚀 Quickstart
53
+
54
+ #### Installation
55
+ ```shell
56
+ # clone the repository.
57
+ git clone https://github.com/SkyworkAI/SkyReels-V2
58
+ cd SkyReels-V2
59
+ # Install dependencies. Test environment uses Python 3.10.12.
60
+ pip install -r requirements.txt
61
+ ```
62
+
63
+ #### Model Download
64
+ You can download our models from Hugging Face:
65
+ <table>
66
+ <thead>
67
+ <tr>
68
+ <th>Type</th>
69
+ <th>Model Variant</th>
70
+ <th>Recommended Height/Width/Frame</th>
71
+ <th>Link</th>
72
+ </tr>
73
+ </thead>
74
+ <tbody>
75
+ <tr>
76
+ <td rowspan="5">Diffusion Forcing</td>
77
+ <td>1.3B-540P</td>
78
+ <td>544 * 960 * 97f</td>
79
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-1.3B-540P">ModelScope</a></td>
80
+ </tr>
81
+ <tr>
82
+ <td>5B-540P</td>
83
+ <td>544 * 960 * 97f</td>
84
+ <td>Coming Soon</td>
85
+ </tr>
86
+ <tr>
87
+ <td>5B-720P</td>
88
+ <td>720 * 1280 * 121f</td>
89
+ <td>Coming Soon</td>
90
+ </tr>
91
+ <tr>
92
+ <td>14B-540P</td>
93
+ <td>544 * 960 * 97f</td>
94
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-14B-540P">ModelScope</a></td>
95
+ </tr>
96
+ <tr>
97
+ <td>14B-720P</td>
98
+ <td>720 * 1280 * 121f</td>
99
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-14B-720P">ModelScope</a></td>
100
+ </tr>
101
+ <tr>
102
+ <td rowspan="5">Text-to-Video</td>
103
+ <td>1.3B-540P</td>
104
+ <td>544 * 960 * 97f</td>
105
+ <td>Coming Soon</td>
106
+ </tr>
107
+ <tr>
108
+ <td>5B-540P</td>
109
+ <td>544 * 960 * 97f</td>
110
+ <td>Coming Soon</td>
111
+ </tr>
112
+ <tr>
113
+ <td>5B-720P</td>
114
+ <td>720 * 1280 * 121f</td>
115
+ <td>Coming Soon</td>
116
+ </tr>
117
+ <tr>
118
+ <td>14B-540P</td>
119
+ <td>544 * 960 * 97f</td>
120
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-T2V-14B-540P">ModelScope</a></td>
121
+ </tr>
122
+ <tr>
123
+ <td>14B-720P</td>
124
+ <td>720 * 1280 * 121f</td>
125
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-T2V-14B-720P">ModelScope</a></td>
126
+ </tr>
127
+ <tr>
128
+ <td rowspan="5">Image-to-Video</td>
129
+ <td>1.3B-540P</td>
130
+ <td>544 * 960 * 97f</td>
131
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-1.3B-540P">ModelScope</a></td>
132
+ </tr>
133
+ <tr>
134
+ <td>5B-540P</td>
135
+ <td>544 * 960 * 97f</td>
136
+ <td>Coming Soon</td>
137
+ </tr>
138
+ <tr>
139
+ <td>5B-720P</td>
140
+ <td>720 * 1280 * 121f</td>
141
+ <td>Coming Soon</td>
142
+ </tr>
143
+ <tr>
144
+ <td>14B-540P</td>
145
+ <td>544 * 960 * 97f</td>
146
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-14B-540P">ModelScope</a></td>
147
+ </tr>
148
+ <tr>
149
+ <td>14B-720P</td>
150
+ <td>720 * 1280 * 121f</td>
151
+ <td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-14B-720P">ModelScope</a></td>
152
+ </tr>
153
+ <tr>
154
+ <td rowspan="3">Camera Director</td>
155
+ <td>5B-540P</td>
156
+ <td>544 * 960 * 97f</td>
157
+ <td>Coming Soon</td>
158
+ </tr>
159
+ <tr>
160
+ <td>5B-720P</td>
161
+ <td>720 * 1280 * 121f</td>
162
+ <td>Coming Soon</td>
163
+ </tr>
164
+ <tr>
165
+ <td>14B-720P</td>
166
+ <td>720 * 1280 * 121f</td>
167
+ <td>Coming Soon</td>
168
+ </tr>
169
+ </tbody>
170
+ </table>
171
+
172
+ After downloading, set the model path in your generation commands:
173
+
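+ As a reference, here is a minimal sketch (not part of the official scripts) that fetches a checkpoint with `huggingface_hub` — a dependency already pulled in via `diffusers` — and prints a local path you can pass as `--model_id`; the inference scripts can also resolve the Hugging Face repo id directly:
+
+ ```python
+ # Sketch: download a SkyReels-V2 checkpoint and obtain a local path for --model_id.
+ from huggingface_hub import snapshot_download
+
+ local_dir = snapshot_download(
+     repo_id="Skywork/SkyReels-V2-DF-14B-540P",      # any released variant from the table above
+     local_dir="./checkpoints/SkyReels-V2-DF-14B-540P",
+ )
+ print(local_dir)  # pass this path (or the repo id) as --model_id
+ ```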
174
+
175
+ #### Single GPU Inference
176
+
177
+ - **Diffusion Forcing for Long Video Generation**
178
+
179
+ The <a href="https://arxiv.org/abs/2407.01392">**Diffusion Forcing**</a> version model allows us to generate Infinite-Length videos. This model supports both **text-to-video (T2V)** and **image-to-video (I2V)** tasks, and it can perform inference in both synchronous and asynchronous modes. Here we demonstrate 2 running scripts as examples for long video generation. If you want to adjust the inference parameters, e.g., the duration of video, inference mode, read the Note below first.
180
+
181
+ synchronous generation for 10s video
182
+ ```shell
183
+ model_id=Skywork/SkyReels-V2-DF-14B-540P
184
+ # synchronous inference
185
+ python3 generate_video_df.py \
186
+ --model_id ${model_id} \
187
+ --resolution 540P \
188
+ --ar_step 0 \
189
+ --base_num_frames 97 \
190
+ --num_frames 257 \
191
+ --overlap_history 17 \
192
+ --prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
193
+ --addnoise_condition 20 \
194
+ --offload \
195
+ --teacache \
196
+ --use_ret_steps \
197
+ --teacache_thresh 0.3
198
+ ```
199
+
200
+ asynchronous generation for 30s video
201
+ ```shell
202
+ model_id=Skywork/SkyReels-V2-DF-14B-540P
203
+ # asynchronous inference
204
+ python3 generate_video_df.py \
205
+ --model_id ${model_id} \
206
+ --resolution 540P \
207
+ --ar_step 5 \
208
+ --causal_block_size 5 \
209
+ --base_num_frames 97 \
210
+ --num_frames 737 \
211
+ --overlap_history 17 \
212
+ --prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
213
+ --addnoise_condition 20 \
214
+ --offload
215
+ ```
216
+
+ > **Note**:
+ > - To run the **image-to-video (I2V)** task, add `--image ${image_path}` to your command. It also helps to use a **text-to-video (T2V)**-style prompt that includes a description of the first-frame image.
+ > - For long video generation, simply change `--num_frames`, e.g., `--num_frames 257` for a 10s video, `--num_frames 377` for 15s, `--num_frames 737` for 30s, `--num_frames 1457` for 60s. These numbers are not strictly tied to the logical frame count of the target duration, but they are aligned with certain training parameters, so they tend to perform better. When using asynchronous inference with `causal_block_size > 1`, `--num_frames` should be set carefully (see the next point and the sketch after this note).
+ > - Use `--ar_step 5` to enable asynchronous inference. For asynchronous inference, `--causal_block_size 5` is recommended; it should not be set for synchronous generation. Remember that the number of frame latents fed to the model in each iteration must be divisible by `causal_block_size`: for example, the base latent count is (97-1)//4+1=25 for base_num_frames=97, and the last iteration uses (237-97-(97-17)x1+17-1)//4+1=20 latents for base_num_frames=97, num_frames=237, overlap_history=17. If these values are hard to work out, just use the recommended settings above. Asynchronous inference takes more steps to diffuse the whole sequence, so it is slower than synchronous mode, but in our experiments it may improve instruction following and visual consistency.
+ > - To reduce peak VRAM, lower `--base_num_frames`, e.g., to 77 or 57, while keeping the target length `--num_frames` unchanged. This may slightly reduce video quality, and it should not be set too small.
+ > - `--addnoise_condition` helps smooth long video generation by adding a small amount of noise to the clean condition frames. Too much noise also causes inconsistency; 20 is a recommended value, and while you may try larger values, it is best not to exceed 50.
+ > - Generating a 540P video with the 1.3B model requires approximately 14.7GB of peak VRAM, while the 14B model at the same resolution requires around 51.2GB.
+
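+ Following the note above, here is a small illustrative helper (it mirrors the formulas in the note, not the pipeline internals) for checking that every iteration's latent count is divisible by `causal_block_size`:
+
+ ```python
+ # Sketch: per-iteration frame-latent counts for Diffusion Forcing long-video inference,
+ # assuming 4x temporal compression and overlap_history frames reused between segments.
+ def latent_counts(num_frames, base_num_frames=97, overlap_history=17):
+     counts = [(base_num_frames - 1) // 4 + 1]        # e.g. 25 latents for 97 frames
+     remaining = num_frames - base_num_frames
+     step = base_num_frames - overlap_history         # new frames added per extra iteration
+     while remaining > 0:
+         chunk = min(step, remaining)
+         counts.append((chunk + overlap_history - 1) // 4 + 1)
+         remaining -= chunk
+     return counts
+
+ causal_block_size = 5
+ for n in latent_counts(num_frames=237):              # -> [25, 25, 20], all divisible by 5
+     assert n % causal_block_size == 0, f"{n} latents not divisible by {causal_block_size}"
+ ```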
225
+ - **Text To Video & Image To Video**
226
+
227
+ ```shell
228
+ # run Text-to-Video Generation
229
+ model_id=Skywork/SkyReels-V2-T2V-14B-540P
230
+ python3 generate_video.py \
231
+ --model_id ${model_id} \
232
+ --resolution 540P \
233
+ --num_frames 97 \
234
+ --guidance_scale 6.0 \
235
+ --shift 8.0 \
236
+ --fps 24 \
237
+ --prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
238
+ --offload \
239
+ --teacache \
240
+ --use_ret_steps \
241
+ --teacache_thresh 0.3
242
+ ```
243
+ > **Note**:
244
+ > - When using an **image-to-video (I2V)** model, you must provide an input image via the `--image ${image_path}` parameter. `--guidance_scale 5.0` and `--shift 3.0` are recommended for I2V models.
245
+ > - Generating a 540P video using the 1.3B model requires approximately 14.7GB peak VRAM, while the same resolution video using the 14B model demands around 43.4GB peak VRAM.
246
+
247
+
248
+ - **Prompt Enhancer**
249
+
250
+ The prompt enhancer is implemented based on <a href="https://huggingface.co/Qwen/Qwen2.5-32B-Instruct">Qwen2.5-32B-Instruct</a> and is enabled via the `--prompt_enhancer` parameter. It works best for short prompts; for long prompts it may generate an excessively lengthy prompt that leads to over-saturation in the generated video. Note that peak GPU memory exceeds 64GB when `--prompt_enhancer` is used. If you want to obtain the enhanced prompt separately, you can also run the prompt_enhancer script on its own for testing, as follows:
251
+
252
+ ```shell
253
+ cd skyreels_v2_infer/pipelines
254
+ python3 prompt_enhancer.py --prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface."
255
+ ```
256
+ > **Note**:
257
+ > - `--prompt_enhancer` cannot be combined with `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate an enhanced prompt, and only then enabling the `--use_usp` parameter.
258
+
259
+
260
+ **Advanced Configuration Options**
261
+
262
+ Below are the key parameters you can customize for video generation:
263
+
264
+ | Parameter | Recommended Value | Description |
265
+ |:----------------------:|:---------:|:-----------------------------------------:|
266
+ | --prompt | | Text description for generating your video |
267
+ | --image | | Path to input image for image-to-video generation |
268
+ | --resolution | 540P or 720P | Output video resolution (select based on model type) |
269
+ | --num_frames | 97 or 121 | Total frames to generate (**97 for 540P models**, **121 for 720P models**) |
270
+ | --inference_steps | 50 | Number of denoising steps |
271
+ | --fps | 24 | Frames per second in the output video |
272
+ | --shift | 8.0 or 5.0 | Flow matching scheduler parameter (**8.0 for T2V**, **5.0 for I2V**) |
273
+ | --guidance_scale | 6.0 or 5.0 | Controls text adherence strength (**6.0 for T2V**, **5.0 for I2V**) |
274
+ | --seed | | Fixed seed for reproducible results (omit for random generation) |
275
+ | --offload | True | Offloads model components to CPU to reduce VRAM usage (recommended) |
276
+ | --use_usp | True | Enables multi-GPU acceleration with xDiT USP |
277
+ | --outdir | ./video_out | Directory where generated videos will be saved |
278
+ | --prompt_enhancer | True | Expands the prompt into a more detailed description |
+ | --teacache | False | Enables TeaCache for faster inference |
+ | --teacache_thresh | 0.2 | TeaCache threshold; higher values give more speedup at the cost of quality |
+ | --use_ret_steps | False | Enables retention steps for TeaCache |
282
+
283
+ **Diffusion Forcing Additional Parameters**
284
+ | Parameter | Recommended Value | Description |
285
+ |:----------------------:|:---------:|:-----------------------------------------:|
286
+ | --ar_step | 0 | Controls asynchronous inference (0 for synchronous mode) |
287
+ | --base_num_frames | 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) |
288
+ | --overlap_history | 17 | Number of frames to overlap for smooth transitions in long videos |
289
+ | --addnoise_condition | 20 | Improves consistency in long video generation |
290
+ | --causal_block_size | 5 | Recommended when using asynchronous inference (--ar_step > 0) |
291
+
292
+ #### Multi-GPU inference using xDiT USP
293
+
294
+ We use [xDiT](https://github.com/xdit-project/xDiT) USP to accelerate inference. For example, to generate a video with 2 GPUs, you can use the following command:
295
+ - **Diffusion Forcing**
296
+ ```shell
297
+ model_id=Skywork/SkyReels-V2-DF-14B-540P
298
+ # diffusion forcing synchronous inference
299
+ torchrun --nproc_per_node=2 generate_video_df.py \
300
+ --model_id ${model_id} \
301
+ --resolution 540P \
302
+ --ar_step 0 \
303
+ --base_num_frames 97 \
304
+ --num_frames 257 \
305
+ --overlap_history 17 \
306
+ --prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
307
+ --addnoise_condition 20 \
308
+ --use_usp \
309
+ --offload \
310
+ --seed 42
311
+ ```
312
+ - **Text To Video & Image To Video**
313
+ ```shell
314
+ # run Text-to-Video Generation
315
+ model_id=Skywork/SkyReels-V2-T2V-14B-540P
316
+ torchrun --nproc_per_node=2 generate_video.py \
317
+ --model_id ${model_id} \
318
+ --resolution 540P \
319
+ --num_frames 97 \
320
+ --guidance_scale 6.0 \
321
+ --shift 8.0 \
322
+ --fps 24 \
323
+ --offload \
324
+ --prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
325
+ --use_usp \
326
+ --seed 42
327
+ ```
328
+ > **Note**:
329
+ > - When using an **image-to-video (I2V)** model, you must provide an input image via the `--image ${image_path}` parameter. `--guidance_scale 5.0` and `--shift 3.0` are recommended for I2V models.
330
+
331
+
332
+ ## Contents
333
+ - [Abstract](#abstract)
334
+ - [Methodology of SkyReels-V2](#methodology-of-skyreels-v2)
335
+ - [Key Contributions of SkyReels-V2](#key-contributions-of-skyreels-v2)
336
+ - [Video Captioner](#video-captioner)
337
+ - [Reinforcement Learning](#reinforcement-learning)
338
+ - [Diffusion Forcing](#diffusion-forcing)
339
+ - [High-Quality Supervised Fine-Tuning(SFT)](#high-quality-supervised-fine-tuning-sft)
340
+ - [Performance](#performance)
341
+ - [Acknowledgements](#acknowledgements)
342
+ - [Citation](#citation)
343
+ ---
344
+
345
+ ## Abstract
346
+ Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation.
347
+
348
+ To address these limitations, we introduce SkyReels-V2, the world's first infinite-length film generative model using a Diffusion Forcing framework. Our approach synergizes Multi-modal Large Language Models (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing techniques to achieve comprehensive optimization. Beyond its technical innovations, SkyReels-V2 enables multiple practical applications, including Story Generation, Image-to-Video Synthesis, Camera Director functionality, and multi-subject consistent video generation through our <a href="https://github.com/SkyworkAI/SkyReels-A2">Skyreels-A2</a> system.
349
+
350
+ ## Methodology of SkyReels-V2
351
+
352
+ The SkyReels-V2 methodology consists of several interconnected components. It starts with a comprehensive data processing pipeline that prepares various quality training data. At its core is the Video Captioner architecture, which provides detailed annotations for video content. The system employs a multi-task pretraining strategy to build fundamental video generation capabilities. Post-training optimization includes Reinforcement Learning to enhance motion quality, Diffusion Forcing Training for generating extended videos, and High-quality Supervised Fine-Tuning (SFT) stages for visual refinement. The model runs on optimized computational infrastructure for efficient training and inference. SkyReels-V2 supports multiple applications, including Story Generation, Image-to-Video Synthesis, Camera Director functionality, and Elements-to-Video Generation.
353
+
354
+ <p align="center">
355
+ <img src="assets/main_pipeline.jpg" alt="mainpipeline" width="100%">
356
+ </p>
357
+
358
+ ## Key Contributions of SkyReels-V2
359
+
360
+ #### Video Captioner
361
+
362
+ <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> serves as our video captioning model for data annotation. This model is trained on the captioning result from the base model <a href="https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct">Qwen2.5-VL-72B-Instruct</a> and the sub-expert captioners on a balanced video data. The balanced video data is a carefully curated dataset of approximately 2 million videos to ensure conceptual balance and annotation quality. Built upon the <a href="https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct">Qwen2.5-VL-7B-Instruct</a> foundation model, <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> is fine-tuned to enhance performance in domain-specific video captioning tasks. To compare the performance with the SOTA models, we conducted a manual assessment of accuracy across different captioning fields using a test set of 1,000 samples. The proposed <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> achieves the highest average accuracy among the baseline models, and show a dramatic result in the shot related fields
363
+
364
+ <p align="center">
365
+ <table align="center">
366
+ <thead>
367
+ <tr>
368
+ <th>model</th>
369
+ <th><a href="https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct">Qwen2.5-VL-7B-Ins.</a></th>
370
+ <th><a href="https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct">Qwen2.5-VL-72B-Ins.</a></th>
371
+ <th><a href="https://huggingface.co/omni-research/Tarsier2-Recap-7b">Tarsier2-Recap-7b</a></th>
372
+ <th><a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</th>
373
+ </tr>
374
+ </thead>
375
+ <tbody>
376
+ <tr>
377
+ <td>Avg accuracy</td>
378
+ <td>51.4%</td>
379
+ <td>58.7%</td>
380
+ <td>49.4%</td>
381
+ <td><strong>76.3%</strong></td>
382
+ </tr>
383
+ <tr>
384
+ <td>shot type</td>
385
+ <td>76.8%</td>
386
+ <td>82.5%</td>
387
+ <td>60.2%</td>
388
+ <td><strong>93.7%</strong></td>
389
+ </tr>
390
+ <tr>
391
+ <td>shot angle</td>
392
+ <td>60.0%</td>
393
+ <td>73.7%</td>
394
+ <td>52.4%</td>
395
+ <td><strong>89.8%</strong></td>
396
+ </tr>
397
+ <tr>
398
+ <td>shot position</td>
399
+ <td>28.4%</td>
400
+ <td>32.7%</td>
401
+ <td>23.6%</td>
402
+ <td><strong>83.1%</strong></td>
403
+ </tr>
404
+ <tr>
405
+ <td>camera motion</td>
406
+ <td>62.0%</td>
407
+ <td>61.2%</td>
408
+ <td>45.3%</td>
409
+ <td><strong>85.3%</strong></td>
410
+ </tr>
411
+ <tr>
412
+ <td>expression</td>
413
+ <td>43.6%</td>
414
+ <td>51.5%</td>
415
+ <td>54.3%</td>
416
+ <td><strong>68.8%</strong></td>
417
+ </tr>
418
+ <tr>
419
+ <td colspan="5" style="text-align: center; border-bottom: 1px solid #ddd; padding: 8px;"></td>
420
+ </tr>
421
+ <tr>
422
+ <td>TYPES_type</td>
423
+ <td>43.5%</td>
424
+ <td>49.7%</td>
425
+ <td>47.6%</td>
426
+ <td><strong>82.5%</strong></td>
427
+ </tr>
428
+ <tr>
429
+ <td>TYPES_sub_type</td>
430
+ <td>38.9%</td>
431
+ <td>44.9%</td>
432
+ <td>45.9%</td>
433
+ <td><strong>75.4%</strong></td>
434
+ </tr>
435
+ <tr>
436
+ <td>appearance</td>
437
+ <td>40.9%</td>
438
+ <td>52.0%</td>
439
+ <td>45.6%</td>
440
+ <td><strong>59.3%</strong></td>
441
+ </tr>
442
+ <tr>
443
+ <td>action</td>
444
+ <td>32.4%</td>
445
+ <td>52.0%</td>
446
+ <td><strong>69.8%</strong></td>
447
+ <td>68.8%</td>
448
+ </tr>
449
+ <tr>
450
+ <td>position</td>
451
+ <td>35.4%</td>
452
+ <td>48.6%</td>
453
+ <td>45.5%</td>
454
+ <td><strong>57.5%</strong></td>
455
+ </tr>
456
+ <tr>
457
+ <td>is_main_subject</td>
458
+ <td>58.5%</td>
459
+ <td>68.7%</td>
460
+ <td>69.7%</td>
461
+ <td><strong>80.9%</strong></td>
462
+ </tr>
463
+ <tr>
464
+ <td>environment</td>
465
+ <td>70.4%</td>
466
+ <td><strong>72.7%</strong></td>
467
+ <td>61.4%</td>
468
+ <td>70.5%</td>
469
+ </tr>
470
+ <tr>
471
+ <td>lighting</td>
472
+ <td>77.1%</td>
473
+ <td><strong>80.0%</strong></td>
474
+ <td>21.2%</td>
475
+ <td>76.5%</td>
476
+ </tr>
477
+ </tbody>
478
+ </table>
479
+ </p>
480
+
481
+ #### Reinforcement Learning
482
+ Inspired by previous successes with LLMs, we propose to enhance the generative model's performance through Reinforcement Learning. Specifically, we focus on motion quality because we find that the main drawbacks of our generative model are:
+
+ - the generative model does not handle large, deformable motions well.
+ - the generated videos may violate physical laws.
486
+
487
+ To avoid the degradation in other metrics, such as text alignment and video quality, we ensure the preference data pairs have comparable text alignment and video quality, while only the motion quality varies. This requirement poses greater challenges in obtaining preference annotations due to the inherently higher costs of human annotation. To address this challenge, we propose a semi-automatic pipeline that strategically combines automatically generated motion pairs and human annotation results. This hybrid approach not only enhances the data scale but also improves alignment with human preferences through curated quality control. Leveraging this enhanced dataset, we first train a specialized reward model to capture the generic motion quality differences between paired samples. This learned reward function subsequently guides the sample selection process for Direct Preference Optimization (DPO), enhancing the motion quality of the generative model.
488
+
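+ The sketch below illustrates the pair-selection step in this pipeline; the scoring functions are hypothetical stand-ins for the trained reward model and auxiliary quality metrics, and the thresholds are illustrative:
+
+ ```python
+ # Illustrative selection of DPO preference pairs: keep pairs whose text alignment and
+ # visual quality are comparable while the motion-quality reward differs clearly.
+ from typing import Callable, List, Tuple
+
+ def select_dpo_pairs(
+     candidates: List[Tuple[str, str]],               # (video_a, video_b) from the same prompt
+     motion_reward: Callable[[str], float],           # hypothetical learned reward model
+     text_alignment: Callable[[str], float],          # hypothetical alignment scorer
+     visual_quality: Callable[[str], float],          # hypothetical quality scorer
+     margin: float = 0.2,                             # required motion-quality gap
+     tol: float = 0.05,                               # max allowed gap in the other metrics
+ ):
+     pairs = []
+     for a, b in candidates:
+         if abs(text_alignment(a) - text_alignment(b)) > tol:
+             continue
+         if abs(visual_quality(a) - visual_quality(b)) > tol:
+             continue
+         gap = motion_reward(a) - motion_reward(b)
+         if abs(gap) >= margin:
+             chosen, rejected = (a, b) if gap > 0 else (b, a)
+             pairs.append({"chosen": chosen, "rejected": rejected})
+     return pairs
+ ```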
489
+ #### Diffusion Forcing
490
+
491
+ We introduce the Diffusion Forcing Transformer to unlock our model’s ability to generate long videos. Diffusion Forcing is a training and sampling strategy where each token is assigned an independent noise level. This allows tokens to be denoised according to arbitrary, per-token schedules. Conceptually, this approach functions as a form of partial masking: a token with zero noise is fully unmasked, while complete noise fully masks it. Diffusion Forcing trains the model to "unmask" any combination of variably noised tokens, using the cleaner tokens as conditional information to guide the recovery of noisy ones. Building on this, our Diffusion Forcing Transformer can extend video generation indefinitely based on the last frames of the previous segment. Note that the synchronous full sequence diffusion is a special case of Diffusion Forcing, where all tokens share the same noise level. This relationship allows us to fine-tune the Diffusion Forcing Transformer from a full-sequence diffusion model.
492
+
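+ To make the per-token noise idea concrete, here is a toy, self-contained sketch (the denoiser, shapes, and schedule are illustrative assumptions, not the SkyReels-V2 architecture):
+
+ ```python
+ # Toy Diffusion Forcing step: each frame latent gets an independent noise level, and the
+ # denoiser is trained to recover clean latents from arbitrarily mixed noise levels.
+ import torch
+ import torch.nn as nn
+
+ frames, dim = 8, 64                                  # toy sequence of frame latents
+ denoiser = nn.Sequential(nn.Linear(dim + 1, 128), nn.SiLU(), nn.Linear(128, dim))
+
+ x0 = torch.randn(frames, dim)                        # "clean" latents
+ t = torch.rand(frames, 1)                            # independent noise level per frame
+ xt = (1 - t) * x0 + t * torch.randn_like(x0)         # interpolate each frame toward noise
+
+ pred = denoiser(torch.cat([xt, t], dim=-1))          # model sees each token's own noise level
+ loss = ((pred - x0) ** 2).mean()                     # learn to "unmask" the noisier tokens
+ loss.backward()
+
+ # Full-sequence diffusion is the special case where all entries of `t` are equal, which is
+ # why a full-sequence model can be fine-tuned into a Diffusion Forcing model.
+ ```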
493
+ #### High-Quality Supervised Fine-Tuning (SFT)
494
+
495
+ We implement two sequential high-quality supervised fine-tuning (SFT) stages at 540p and 720p resolutions respectively, with the initial SFT phase conducted immediately after pretraining but prior to the reinforcement learning (RL) stage. This first-stage SFT serves as a conceptual equilibrium trainer, building upon the foundation model's pretraining outcomes, which used only 24fps video data, while strategically removing the FPS embedding components to streamline the architecture. Trained with high-quality, concept-balanced samples, this phase establishes optimized initialization parameters for subsequent training. Following this, we execute a secondary high-resolution SFT at 720p after completing the diffusion forcing stage, incorporating identical loss formulations and an even higher-quality, manually filtered, concept-balanced dataset. This final refinement phase focuses on increasing resolution so that overall video quality is further enhanced.
496
+
497
+ ## Performance
498
+
499
+ To comprehensively evaluate our proposed method, we constructed SkyReels-Bench for human assessment and leveraged the open-source <a href="https://github.com/Vchitect/VBench">V-Bench</a> for automated evaluation. This allows us to compare our model with state-of-the-art (SOTA) baselines, including both open-source and proprietary models.
500
+
501
+ #### Human Evaluation
502
+
503
+ For human evaluation, we design SkyReels-Bench with 1,020 text prompts, systematically assessing four dimensions: Instruction Adherence, Motion Quality, Consistency, and Visual Quality. This benchmark is designed to evaluate both text-to-video (T2V) and image-to-video (I2V) generation models, providing comprehensive assessment across different generation paradigms. To ensure fairness, all models were evaluated under default settings with consistent resolutions, and no post-generation filtering was applied.
504
+
505
+ - Text To Video Models
506
+
507
+ <p align="center">
508
+ <table align="center">
509
+ <thead>
510
+ <tr>
511
+ <th>Model Name</th>
512
+ <th>Average</th>
513
+ <th>Instruction Adherence</th>
514
+ <th>Consistency</th>
515
+ <th>Visual Quality</th>
516
+ <th>Motion Quality</th>
517
+ </tr>
518
+ </thead>
519
+ <tbody>
520
+ <tr>
521
+ <td><a href="https://runwayml.com/research/introducing-gen-3-alpha">Runway-Gen3 Alpha</a></td>
522
+ <td>2.53</td>
523
+ <td>2.19</td>
524
+ <td>2.57</td>
525
+ <td>3.23</td>
526
+ <td>2.11</td>
527
+ </tr>
528
+ <tr>
529
+ <td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
530
+ <td>2.82</td>
531
+ <td>2.64</td>
532
+ <td>2.81</td>
533
+ <td>3.20</td>
534
+ <td>2.61</td>
535
+ </tr>
536
+ <tr>
537
+ <td><a href="https://klingai.com">Kling-1.6 STD Mode</a></td>
538
+ <td>2.99</td>
539
+ <td>2.77</td>
540
+ <td>3.05</td>
541
+ <td>3.39</td>
542
+ <td><strong>2.76</strong></td>
543
+ </tr>
544
+ <tr>
545
+ <td><a href="https://hailuoai.video">Hailuo-01</a></td>
546
+ <td>3.0</td>
547
+ <td>2.8</td>
548
+ <td>3.08</td>
549
+ <td>3.29</td>
550
+ <td>2.74</td>
551
+ </tr>
552
+ <tr>
553
+ <td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
554
+ <td>3.12</td>
555
+ <td>2.91</td>
556
+ <td>3.31</td>
557
+ <td><strong>3.54</strong></td>
558
+ <td>2.71</td>
559
+ </tr>
560
+ <tr>
561
+ <td>SkyReels-V2</td>
562
+ <td><strong>3.14</strong></td>
563
+ <td><strong>3.15</strong></td>
564
+ <td><strong>3.35</strong></td>
565
+ <td>3.34</td>
566
+ <td>2.74</td>
567
+ </tr>
568
+ </tbody>
569
+ </table>
570
+ </p>
571
+
572
+ The evaluation demonstrates that our model achieves significant advancements in **instruction adherence (3.15)** compared to baseline methods, while maintaining competitive performance in **motion quality (2.74)** without sacrificing **consistency (3.35)**.
573
+
574
+ - Image To Video Models
575
+
576
+ <p align="center">
577
+ <table align="center">
578
+ <thead>
579
+ <tr>
580
+ <th>Model</th>
581
+ <th>Average</th>
582
+ <th>Instruction Adherence</th>
583
+ <th>Consistency</th>
584
+ <th>Visual Quality</th>
585
+ <th>Motion Quality</th>
586
+ </tr>
587
+ </thead>
588
+ <tbody>
589
+ <tr>
590
+ <td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
591
+ <td>2.84</td>
592
+ <td>2.97</td>
593
+ <td>2.95</td>
594
+ <td>2.87</td>
595
+ <td>2.56</td>
596
+ </tr>
597
+ <tr>
598
+ <td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
599
+ <td>2.85</td>
600
+ <td>3.10</td>
601
+ <td>2.81</td>
602
+ <td>3.00</td>
603
+ <td>2.48</td>
604
+ </tr>
605
+ <tr>
606
+ <td><a href="https://hailuoai.video">Hailuo-01</a></td>
607
+ <td>3.05</td>
608
+ <td>3.31</td>
609
+ <td>2.58</td>
610
+ <td>3.55</td>
611
+ <td>2.74</td>
612
+ </tr>
613
+ <tr>
614
+ <td><a href="https://klingai.com">Kling-1.6 Pro Mode</a></td>
615
+ <td>3.4</td>
616
+ <td>3.56</td>
617
+ <td>3.03</td>
618
+ <td>3.58</td>
619
+ <td>3.41</td>
620
+ </tr>
621
+ <tr>
622
+ <td><a href="https://runwayml.com/research/introducing-runway-gen-4">Runway-Gen4</a></td>
623
+ <td>3.39</td>
624
+ <td>3.75</td>
625
+ <td>3.2</td>
626
+ <td>3.4</td>
627
+ <td>3.37</td>
628
+ </tr>
629
+ <tr>
630
+ <td>SkyReels-V2-DF</td>
631
+ <td>3.24</td>
632
+ <td>3.64</td>
633
+ <td>3.21</td>
634
+ <td>3.18</td>
635
+ <td>2.93</td>
636
+ </tr>
637
+ <tr>
638
+ <td>SkyReels-V2-I2V</td>
639
+ <td>3.29</td>
640
+ <td>3.42</td>
641
+ <td>3.18</td>
642
+ <td>3.56</td>
643
+ <td>3.01</td>
644
+ </tr>
645
+ </tbody>
646
+ </table>
647
+ </p>
648
+
649
+ Our results demonstrate that both **SkyReels-V2-I2V (3.29)** and **SkyReels-V2-DF (3.24)** achieve state-of-the-art performance among open-source models, significantly outperforming HunyuanVideo-13B (2.84) and Wan2.1-14B (2.85) across all quality dimensions. With an average score of 3.29, SkyReels-V2-I2V demonstrates comparable performance to proprietary models Kling-1.6 (3.4) and Runway-Gen4 (3.39).
650
+
651
+
652
+ #### VBench
653
+ To objectively compare the SkyReels-V2 model against other leading open-source text-to-video models, we conduct comprehensive evaluations using the public benchmark <a href="https://github.com/Vchitect/VBench">V-Bench</a>. Our evaluation specifically uses the benchmark's longer-version prompts. For fair comparison with the baseline models, we strictly follow their recommended inference settings.
654
+
655
+ <p align="center">
656
+ <table align="center">
657
+ <thead>
658
+ <tr>
659
+ <th>Model</th>
660
+ <th>Total Score</th>
661
+ <th>Quality Score</th>
662
+ <th>Semantic Score</th>
663
+ </tr>
664
+ </thead>
665
+ <tbody>
666
+ <tr>
667
+ <td><a href="https://github.com/hpcaitech/Open-Sora">OpenSora 2.0</a></td>
668
+ <td>81.5 %</td>
669
+ <td>82.1 %</td>
670
+ <td>78.2 %</td>
671
+ </tr>
672
+ <tr>
673
+ <td><a href="https://github.com/THUDM/CogVideo">CogVideoX1.5-5B</a></td>
674
+ <td>80.3 %</td>
675
+ <td>80.9 %</td>
676
+ <td>77.9 %</td>
677
+ </tr>
678
+ <tr>
679
+ <td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
680
+ <td>82.7 %</td>
681
+ <td>84.4 %</td>
682
+ <td>76.2 %</td>
683
+ </tr>
684
+ <tr>
685
+ <td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
686
+ <td>83.7 %</td>
687
+ <td>84.2 %</td>
688
+ <td><strong>81.4 %</strong></td>
689
+ </tr>
690
+ <tr>
691
+ <td>SkyReels-V2</td>
692
+ <td><strong>83.9 %</strong></td>
693
+ <td><strong>84.7 %</strong></td>
694
+ <td>80.8 %</td>
695
+ </tr>
696
+ </tbody>
697
+ </table>
698
+ </p>
699
+
700
+ The VBench results demonstrate that SkyReels-V2 outperforms all compared models, including HunyuanVideo-13B and Wan2.1-14B, with the highest **total score (83.9%)** and **quality score (84.7%)**. In this evaluation, our semantic score is slightly lower than Wan2.1-14B's, yet we outperform Wan2.1-14B in human evaluations; we attribute the gap primarily to V-Bench's insufficient evaluation of shot-scenario semantic adherence.
701
+
702
+ ## Acknowledgements
703
+ We would like to thank the contributors of the <a href="https://github.com/Wan-Video/Wan2.1">Wan 2.1</a>, <a href="https://github.com/xdit-project/xDiT">xDiT</a>, and <a href="https://qwenlm.github.io/blog/qwen2.5/">Qwen 2.5</a> repositories for their open research and contributions.
704
+
705
+ ## Citation
706
+
707
+ ```bibtex
708
+ @misc{chen2025skyreelsv2infinitelengthfilmgenerative,
709
+ title={SkyReels-V2: Infinite-length Film Generative Model},
710
+ author={Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
711
+ year={2025},
712
+ eprint={2504.13074},
713
+ archivePrefix={arXiv},
714
+ primaryClass={cs.CV},
715
+ url={https://arxiv.org/abs/2504.13074},
716
+ }
717
+ ```
assets/logo2.png ADDED
assets/main_pipeline.jpg ADDED

Git LFS Details

  • SHA256: e8fd982dd51a3edd0a1ce451b391526c1a51ea94ae68f5ed79380173dbcce7fb
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
generate_video.py ADDED
@@ -0,0 +1,161 @@
1
+ import argparse
2
+ import gc
3
+ import os
4
+ import random
5
+ import time
6
+
7
+ import imageio
8
+ import torch
9
+ from diffusers.utils import load_image
10
+
11
+ from skyreels_v2_infer.modules import download_model
12
+ from skyreels_v2_infer.pipelines import Image2VideoPipeline
13
+ from skyreels_v2_infer.pipelines import PromptEnhancer
14
+ from skyreels_v2_infer.pipelines import resizecrop
15
+ from skyreels_v2_infer.pipelines import Text2VideoPipeline
16
+
17
+ MODEL_ID_CONFIG = {
18
+ "text2video": [
19
+ "Skywork/SkyReels-V2-T2V-14B-540P",
20
+ "Skywork/SkyReels-V2-T2V-14B-720P",
21
+ ],
22
+ "image2video": [
23
+ "Skywork/SkyReels-V2-I2V-1.3B-540P",
24
+ "Skywork/SkyReels-V2-I2V-14B-540P",
25
+ "Skywork/SkyReels-V2-I2V-14B-720P",
26
+ ],
27
+ }
28
+
29
+
30
+ if __name__ == "__main__":
31
+
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument("--outdir", type=str, default="video_out")
34
+ parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-T2V-14B-540P")
35
+ parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
36
+ parser.add_argument("--num_frames", type=int, default=97)
37
+ parser.add_argument("--image", type=str, default=None)
38
+ parser.add_argument("--guidance_scale", type=float, default=6.0)
39
+ parser.add_argument("--shift", type=float, default=8.0)
40
+ parser.add_argument("--inference_steps", type=int, default=30)
41
+ parser.add_argument("--use_usp", action="store_true")
42
+ parser.add_argument("--offload", action="store_true")
43
+ parser.add_argument("--fps", type=int, default=24)
44
+ parser.add_argument("--seed", type=int, default=None)
45
+ parser.add_argument(
46
+ "--prompt",
47
+ type=str,
48
+ default="A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface.",
49
+ )
50
+ parser.add_argument("--prompt_enhancer", action="store_true")
51
+ parser.add_argument("--teacache", action="store_true")
52
+ parser.add_argument(
53
+ "--teacache_thresh",
54
+ type=float,
55
+ default=0.2,
56
+ help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
57
+ parser.add_argument(
58
+ "--use_ret_steps",
59
+ action="store_true",
60
+ help="Using Retention Steps will result in faster generation speed and better generation quality.")
61
+ args = parser.parse_args()
62
+
63
+ args.model_id = download_model(args.model_id)
64
+ print("model_id:", args.model_id)
65
+
66
+ assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
67
+ if args.seed is None:
68
+ random.seed(time.time())
69
+ args.seed = int(random.randrange(4294967294))
70
+
71
+ if args.resolution == "540P":
72
+ height = 544
73
+ width = 960
74
+ elif args.resolution == "720P":
75
+ height = 720
76
+ width = 1280
77
+ else:
78
+ raise ValueError(f"Invalid resolution: {args.resolution}")
79
+
80
+ image = load_image(args.image).convert("RGB") if args.image else None
81
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
82
+ local_rank = 0
83
+ if args.use_usp:
84
+ assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
85
+ from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
86
+ import torch.distributed as dist
87
+
88
+ dist.init_process_group("nccl")
89
+ local_rank = dist.get_rank()
90
+ torch.cuda.set_device(dist.get_rank())
91
+ device = "cuda"
92
+
93
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
94
+
95
+ initialize_model_parallel(
96
+ sequence_parallel_degree=dist.get_world_size(),
97
+ ring_degree=1,
98
+ ulysses_degree=dist.get_world_size(),
99
+ )
100
+
101
+ prompt_input = args.prompt
102
+ if args.prompt_enhancer and args.image is None:
103
+ print(f"init prompt enhancer")
104
+ prompt_enhancer = PromptEnhancer()
105
+ prompt_input = prompt_enhancer(prompt_input)
106
+ print(f"enhanced prompt: {prompt_input}")
107
+ del prompt_enhancer
108
+ gc.collect()
109
+ torch.cuda.empty_cache()
110
+
111
+ if image is None:
112
+ assert "T2V" in args.model_id, f"check model_id:{args.model_id}"
113
+ print("init text2video pipeline")
114
+ pipe = Text2VideoPipeline(
115
+ model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
116
+ )
117
+ else:
118
+ assert "I2V" in args.model_id, f"check model_id:{args.model_id}"
119
+ print("init img2video pipeline")
120
+ pipe = Image2VideoPipeline(
121
+ model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
122
+ )
123
+ args.image = load_image(args.image)
124
+ image_width, image_height = args.image.size
125
+ if image_height > image_width:
126
+ height, width = width, height
127
+ args.image = resizecrop(args.image, height, width)
128
+
129
+ if args.teacache:
130
+ pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=args.inference_steps,
131
+ teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
132
+ ckpt_dir=args.model_id)
133
+
134
+
135
+ kwargs = {
136
+ "prompt": prompt_input,
137
+ "negative_prompt": negative_prompt,
138
+ "num_frames": args.num_frames,
139
+ "num_inference_steps": args.inference_steps,
140
+ "guidance_scale": args.guidance_scale,
141
+ "shift": args.shift,
142
+ "generator": torch.Generator(device="cuda").manual_seed(args.seed),
143
+ "height": height,
144
+ "width": width,
145
+ }
146
+
147
+ if image is not None:
148
+ kwargs["image"] = args.image.convert("RGB")
149
+
150
+ save_dir = os.path.join("result", args.outdir)
151
+ os.makedirs(save_dir, exist_ok=True)
152
+
153
+ with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
154
+ print(f"infer kwargs:{kwargs}")
155
+ video_frames = pipe(**kwargs)[0]
156
+
157
+ if local_rank == 0:
158
+ current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
159
+ video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
160
+ output_path = os.path.join(save_dir, video_out_file)
161
+ imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"])
generate_video_df.py ADDED
@@ -0,0 +1,186 @@
1
+ import argparse
2
+ import gc
3
+ import os
4
+ import random
5
+ import time
6
+
7
+ import imageio
8
+ import torch
9
+ from diffusers.utils import load_image
10
+
11
+ from skyreels_v2_infer import DiffusionForcingPipeline
12
+ from skyreels_v2_infer.modules import download_model
13
+ from skyreels_v2_infer.pipelines import PromptEnhancer
14
+ from skyreels_v2_infer.pipelines import resizecrop
15
+
16
+ if __name__ == "__main__":
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--outdir", type=str, default="diffusion_forcing")
20
+ parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
21
+ parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
22
+ parser.add_argument("--num_frames", type=int, default=97)
23
+ parser.add_argument("--image", type=str, default=None)
24
+ parser.add_argument("--ar_step", type=int, default=0)
25
+ parser.add_argument("--causal_attention", action="store_true")
26
+ parser.add_argument("--causal_block_size", type=int, default=1)
27
+ parser.add_argument("--base_num_frames", type=int, default=97)
28
+ parser.add_argument("--overlap_history", type=int, default=None)
29
+ parser.add_argument("--addnoise_condition", type=int, default=0)
30
+ parser.add_argument("--guidance_scale", type=float, default=6.0)
31
+ parser.add_argument("--shift", type=float, default=8.0)
32
+ parser.add_argument("--inference_steps", type=int, default=30)
33
+ parser.add_argument("--use_usp", action="store_true")
34
+ parser.add_argument("--offload", action="store_true")
35
+ parser.add_argument("--fps", type=int, default=24)
36
+ parser.add_argument("--seed", type=int, default=None)
37
+ parser.add_argument(
38
+ "--prompt",
39
+ type=str,
40
+ default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
41
+ )
42
+ parser.add_argument("--prompt_enhancer", action="store_true")
43
+ parser.add_argument("--teacache", action="store_true")
44
+ parser.add_argument(
45
+ "--teacache_thresh",
46
+ type=float,
47
+ default=0.2,
48
+ help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",
49
+ )
50
+ parser.add_argument(
51
+ "--use_ret_steps",
52
+ action="store_true",
53
+ help="Using Retention Steps will result in faster generation speed and better generation quality.",
54
+ )
55
+ args = parser.parse_args()
56
+
57
+ args.model_id = download_model(args.model_id)
58
+ print("model_id:", args.model_id)
59
+
60
+ assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
61
+ if args.seed is None:
62
+ random.seed(time.time())
63
+ args.seed = int(random.randrange(4294967294))
64
+
65
+ if args.resolution == "540P":
66
+ height = 544
67
+ width = 960
68
+ elif args.resolution == "720P":
69
+ height = 720
70
+ width = 1280
71
+ else:
72
+ raise ValueError(f"Invalid resolution: {args.resolution}")
73
+
74
+ num_frames = args.num_frames
75
+ fps = args.fps
76
+
77
+ if num_frames > args.base_num_frames:
78
+ assert (
79
+ args.overlap_history is not None
80
+ ), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'
81
+ if args.addnoise_condition > 60:
82
+ print(
83
+ f'You have set "addnoise_condition" as {args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.'
84
+ )
85
+
86
+ guidance_scale = args.guidance_scale
87
+ shift = args.shift
88
+ if args.image:
89
+ args.image = load_image(args.image)
90
+ image_width, image_height = args.image.size
91
+ if image_height > image_width:
92
+ height, width = width, height
93
+ args.image = resizecrop(args.image, height, width)
94
+ image = args.image.convert("RGB") if args.image else None
95
+ negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
96
+
97
+ save_dir = os.path.join("result", args.outdir)
98
+ os.makedirs(save_dir, exist_ok=True)
99
+ local_rank = 0
100
+ if args.use_usp:
101
+ assert (
102
+ not args.prompt_enhancer
103
+ ), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
104
+ from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
105
+ import torch.distributed as dist
106
+
107
+ dist.init_process_group("nccl")
108
+ local_rank = dist.get_rank()
109
+ torch.cuda.set_device(dist.get_rank())
110
+ device = "cuda"
111
+
112
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
113
+
114
+ initialize_model_parallel(
115
+ sequence_parallel_degree=dist.get_world_size(),
116
+ ring_degree=1,
117
+ ulysses_degree=dist.get_world_size(),
118
+ )
119
+
120
+ prompt_input = args.prompt
121
+ if args.prompt_enhancer and args.image is None:
122
+ print(f"init prompt enhancer")
123
+ prompt_enhancer = PromptEnhancer()
124
+ prompt_input = prompt_enhancer(prompt_input)
125
+ print(f"enhanced prompt: {prompt_input}")
126
+ del prompt_enhancer
127
+ gc.collect()
128
+ torch.cuda.empty_cache()
129
+
130
+ pipe = DiffusionForcingPipeline(
131
+ args.model_id,
132
+ dit_path=args.model_id,
133
+ device=torch.device("cuda"),
134
+ weight_dtype=torch.bfloat16,
135
+ use_usp=args.use_usp,
136
+ offload=args.offload,
137
+ )
138
+
139
+ if args.causal_attention:
140
+ pipe.transformer.set_ar_attention(args.causal_block_size)
141
+
142
+ if args.teacache:
143
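+ # TeaCache needs the total number of denoising steps: with asynchronous generation
+ # (ar_step > 0), every causal block of latent frames beyond the first adds ar_step
+ # extra steps on top of the base inference steps.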
+ if args.ar_step > 0:
144
+ num_steps = (
145
+ args.inference_steps
146
+ + (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
147
+ )
148
+ print("num_steps:", num_steps)
149
+ else:
150
+ num_steps = args.inference_steps
151
+ pipe.transformer.initialize_teacache(
152
+ enable_teacache=True,
153
+ num_steps=num_steps,
154
+ teacache_thresh=args.teacache_thresh,
155
+ use_ret_steps=args.use_ret_steps,
156
+ ckpt_dir=args.model_id,
157
+ )
158
+
159
+ print(f"prompt:{prompt_input}")
160
+ print(f"guidance_scale:{guidance_scale}")
161
+
162
+ with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
163
+ video_frames = pipe(
164
+ prompt=prompt_input,
165
+ negative_prompt=negative_prompt,
166
+ image=image,
167
+ height=height,
168
+ width=width,
169
+ num_frames=num_frames,
170
+ num_inference_steps=args.inference_steps,
171
+ shift=shift,
172
+ guidance_scale=guidance_scale,
173
+ generator=torch.Generator(device="cuda").manual_seed(args.seed),
174
+ overlap_history=args.overlap_history,
175
+ addnoise_condition=args.addnoise_condition,
176
+ base_num_frames=args.base_num_frames,
177
+ ar_step=args.ar_step,
178
+ causal_block_size=args.causal_block_size,
179
+ fps=fps,
180
+ )[0]
181
+
182
+ if local_rank == 0:
183
+ current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
184
+ video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
185
+ output_path = os.path.join(save_dir, video_out_file)
186
+ imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ opencv-python==4.10.0.84
4
+ diffusers>=0.31.0
5
+ transformers==4.49.0
6
+ tokenizers==0.21.1
7
+ accelerate==1.6.0
8
+ tqdm
9
+ imageio
10
+ easydict
11
+ ftfy
12
+ dashscope
13
+ imageio-ffmpeg
14
+ flash_attn
15
+ numpy>=1.23.5,<2
16
+ xfuser
skycaptioner_v1/README.md ADDED
@@ -0,0 +1,262 @@
1
+ # SkyCaptioner-V1: A Structural Video Captioning Model
2
+
3
+ <p align="center">
4
+ 📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/Skywork/SkyCaptioner-V1" target="_blank">Hugging Face</a> · 🤖 <a href="https://modelscope.cn/collections/SkyReels-V2-f665650130b144">ModelScope</a></a>
5
+ </p>
6
+
7
+ ---
8
+
9
+ Welcome to the SkyCaptioner-V1 repository! Here you'll find the model weights and inference code for our structural video captioner, which labels video data efficiently and comprehensively.
10
+
11
+ ## 🔥🔥🔥 News!!
12
+
13
+ * Apr 21, 2025: 👋 We release the [vllm](https://github.com/vllm-project/vllm) batch inference code for the SkyCaptioner-V1 model, together with the caption fusion inference code.
14
+ * Apr 21, 2025: 👋 We release the first shot-aware video captioning model [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1). For more details, please check our [paper](https://arxiv.org/pdf/2504.13074).
15
+
16
+ ## 📑 TODO List
17
+
18
+ - SkyCaptioner-V1
19
+
20
+ - [x] Checkpoints
21
+ - [x] Batch Inference Code
22
+ - [x] Caption Fusion Method
23
+ - [ ] Web Demo (Gradio)
24
+
25
+ ## 🌟 Overview
26
+
27
+ SkyCaptioner-V1 is a structural video captioning model designed to generate high-quality, structural descriptions for video data. It integrates specialized sub-expert models and multimodal large language models (MLLMs) with human annotations to address the limitations of general captioners in capturing professional film-related details. Key aspects include:
28
+
29
+ 1. **Structural Representation**: Combines general video descriptions (from MLLMs) with sub-expert captioners (e.g., for shot type, shot angle, shot position, and camera motion) and human annotations.
30
+ 2. ​​**Knowledge Distillation**​: Distills expertise from sub-expert captioners into a unified model.
31
+ 3. ​​**Application Flexibility**​: Generates dense captions for text-to-video (T2V) and concise prompts for image-to-video (I2V) tasks.
32
+
33
+ ## 🔑 Key Features
34
+
35
+ ### Structural Captioning Framework
36
+
37
+ Our Video Captioning model captures multi-dimensional details:
38
+
39
+ * ​​**Subjects**​: Appearance, action, expression, position, and hierarchical categorization.
40
+ * **Shot Metadata**: Shot type (e.g., close-up, long shot), shot angle, shot position, camera motion, environment, lighting, etc. (see the sketch below).
41
+
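+ As a rough illustration, the sketch below shows the kind of structural caption the model emits. The field names follow the inference prompt in `scripts/vllm_struct_caption.py`; the values are made up for demonstration and are not real model output.
+
+ ```python
+ import json
+
+ # Illustrative structural caption: real schema, hypothetical values.
+ structural_caption = {
+     "subjects": [
+         {
+             "TYPES": {"type": "Human", "sub_type": "Woman"},
+             "appearance": "Long black hair, leather jacket",
+             "action": "Rides a motorcycle along a desert highway",
+             "expression": "Focused",
+             "position": "Centered in the frame",
+             "is_main_subject": True,
+         }
+     ],
+     "shot_type": "medium_shot",
+     "shot_angle": "eye_level",
+     "shot_position": "side_view",
+     "camera_motion": "the camera tracks the subject",
+     "environment": "Desert highway at sunset",
+     "lighting": "Warm golden-hour light",
+ }
+
+ print(json.dumps(structural_caption, indent=2))
+ ```
+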
42
+ ### Sub-Expert Integration
43
+
44
+ * ​​**Shot Captioner**​: Classifies shot type, angle, and position with high precision.
45
+ * ​​**Expression Captioner**​: Analyzes facial expressions, emotion intensity, and temporal dynamics.
46
+ * **Camera Motion Captioner**: Tracks 6DoF camera movements and composite motion types.
47
+
48
+ ### Training Pipeline
49
+
50
+ * Trained on \~2M high-quality, concept-balanced videos curated from 10M raw samples.
51
+ * Fine-tuned on Qwen2.5-VL-7B-Instruct with a global batch size of 512 across 32 A800 GPUs.
52
+ * Optimized using AdamW (learning rate: 1e-5) for 2 epochs.
53
+
54
+ ### Dynamic Caption Fusion
55
+
56
+ * Adapts output length based on application (T2V/I2V).
57
+ * Employs an LLM to fuse the structural fields into a natural, fluent caption for downstream tasks (see the sketch below).
58
+
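+ The sketch below shows roughly how the fusion step is driven, mirroring `scripts/vllm_fusion_caption.py`. The model path is a placeholder, the system prompt is abridged, and the caption contents are illustrative.
+
+ ```python
+ import json
+ from transformers import AutoTokenizer
+
+ # Abridged stand-in for the T2V fusion system prompt defined in scripts/vllm_fusion_caption.py.
+ SYSTEM_PROMPT_T2V = (
+     "You are an expert in video captioning. Compose the structured caption below "
+     "into a natural, fluent English caption.\n\n## Structured Input\n{structured_input}"
+ )
+
+ # Placeholder path -- point this at your local Qwen2.5-32B-Instruct download.
+ tokenizer = AutoTokenizer.from_pretrained("/path/to/your_local_model_path2")
+
+ structural_caption = {
+     "num_subjects": 1,
+     "subjects": [{"name": "Woman", "action": "Rides a motorcycle along a desert highway."}],
+     "shot_type": "medium shot",
+ }
+
+ # Embed the structured caption into the system prompt and render it with the chat template;
+ # the resulting prompt is what the script submits to vLLM to obtain the fused caption.
+ prompt = tokenizer.apply_chat_template(
+     [{"role": "system", "content": SYSTEM_PROMPT_T2V.format(structured_input=json.dumps(structural_caption, indent=4))}],
+     tokenize=False,
+     add_generation_prompt=True,
+ )
+ print(prompt)
+ ```
+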
59
+ ## 📊 Benchmark Results
60
+
61
+ SkyCaptioner-V1 demonstrates significant improvements over existing models in key film-specific captioning tasks, particularly in ​**shot-language understanding** and ​​**domain-specific precision**​. The differences stem from its structural architecture and expert-guided training:
62
+
63
+ 1. ​​**Superior Shot-Language Understanding**​:
64
+ * Our captioner outperforms Qwen2.5-VL-72B by +11.2% in shot type, +16.1% in shot angle, and +50.4% in shot position accuracy, because SkyCaptioner-V1's specialized shot classifiers outperform generalist MLLMs that lack film-domain fine-tuning.
65
+ * ​+28.5% accuracy in camera motion vs. Tarsier2-recap-7B (88.8% vs. 41.5%):
66
+ Its 6DoF motion analysis and active learning pipeline address ambiguities in composite motions (e.g., tracking + panning) that challenge generic captioners.
67
+ 2. ​​**High domain-specific precision**​:
68
+ * ​​Expression accuracy​: ​68.8% vs. 54.3% (Tarsier2-recap-7B), leveraging temporal-aware S2D frameworks to capture dynamic facial changes.
69
+
70
+ <p align="center">
71
+ <table align="center">
72
+ <thead>
73
+ <tr>
74
+ <th>Metric</th>
75
+ <th>Qwen2.5-VL-7B-Ins.</th>
76
+ <th>Qwen2.5-VL-72B-Ins.</th>
77
+ <th>Tarsier2-recap-7B</th>
78
+ <th>SkyCaptioner-V1</th>
79
+ </tr>
80
+ </thead>
81
+ <tbody>
82
+ <tr>
83
+ <td>Avg accuracy</td>
84
+ <td>51.4%</td>
85
+ <td>58.7%</td>
86
+ <td>49.4%</td>
87
+ <td><strong>76.3%</strong></td>
88
+ </tr>
89
+ <tr>
90
+ <td>shot type</td>
91
+ <td>76.8%</td>
92
+ <td>82.5%</td>
93
+ <td>60.2%</td>
94
+ <td><strong>93.7%</strong></td>
95
+ </tr>
96
+ <tr>
97
+ <td>shot angle</td>
98
+ <td>60.0%</td>
99
+ <td>73.7%</td>
100
+ <td>52.4%</td>
101
+ <td><strong>89.8%</strong></td>
102
+ </tr>
103
+ <tr>
104
+ <td>shot position</td>
105
+ <td>28.4%</td>
106
+ <td>32.7%</td>
107
+ <td>23.6%</td>
108
+ <td><strong>83.1%</strong></td>
109
+ </tr>
110
+ <tr>
111
+ <td>camera motion</td>
112
+ <td>62.0%</td>
113
+ <td>61.2%</td>
114
+ <td>45.3%</td>
115
+ <td><strong>85.3%</strong></td>
116
+ </tr>
117
+ <tr>
118
+ <td>expression</td>
119
+ <td>43.6%</td>
120
+ <td>51.5%</td>
121
+ <td>54.3%</td>
122
+ <td><strong>68.8%</strong></td>
123
+ </tr>
124
+ <tr>
125
+ <td>TYPES_type</td>
126
+ <td>43.5%</td>
127
+ <td>49.7%</td>
128
+ <td>47.6%</td>
129
+ <td><strong>82.5%</strong></td>
130
+ </tr>
131
+ <tr>
132
+ <td>TYPES_sub_type</td>
133
+ <td>38.9%</td>
134
+ <td>44.9%</td>
135
+ <td>45.9%</td>
136
+ <td><strong>75.4%</strong></td>
137
+ </tr>
138
+ <tr>
139
+ <td>appearance</td>
140
+ <td>40.9%</td>
141
+ <td>52.0%</td>
142
+ <td>45.6%</td>
143
+ <td><strong>59.3%</strong></td>
144
+ </tr>
145
+ <tr>
146
+ <td>action</td>
147
+ <td>32.4%</td>
148
+ <td>52.0%</td>
149
+ <td><strong>69.8%</strong></td>
150
+ <td>68.8%</td>
151
+ </tr>
152
+ <tr>
153
+ <td>position</td>
154
+ <td>35.4%</td>
155
+ <td>48.6%</td>
156
+ <td>45.5%</td>
157
+ <td><strong>57.5%</strong></td>
158
+ </tr>
159
+ <tr>
160
+ <td>is_main_subject</td>
161
+ <td>58.5%</td>
162
+ <td>68.7%</td>
163
+ <td>69.7%</td>
164
+ <td><strong>80.9%</strong></td>
165
+ </tr>
166
+ <tr>
167
+ <td>environment</td>
168
+ <td>70.4%</td>
169
+ <td><strong>72.7%</strong></td>
170
+ <td>61.4%</td>
171
+ <td>70.5%</td>
172
+ </tr>
173
+ <tr>
174
+ <td>lighting</td>
175
+ <td>77.1%</td>
176
+ <td><strong>80.0%</strong></td>
177
+ <td>21.2%</td>
178
+ <td>76.5%</td>
179
+ </tr>
180
+ </tbody>
181
+ </table>
182
+ </p>
183
+
184
+ ## 📦 Model Downloads
185
+
186
+ Our SkyCaptioner-V1 model can be downloaded from [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1).
187
+ We use [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) as our caption fusion model to intelligently combine structured caption fields, producing either dense or sparse final captions depending on application requirements.
188
+
189
+ ```shell
190
+ # download SkyCaptioner-V1
191
+ huggingface-cli download Skywork/SkyCaptioner-V1 --local-dir /path/to/your_local_model_path
192
+ # download Qwen2.5-32B-Instruct
193
+ huggingface-cli download Qwen/Qwen2.5-32B-Instruct --local-dir /path/to/your_local_model_path2
194
+ ```
195
+
196
+ ## 🛠️ Running Guide
197
+
198
+ Begin by cloning the repository:
199
+
200
+ ```shell
201
+ git clone https://github.com/SkyworkAI/SkyReels-V2
202
+ cd skycaptioner_v1
203
+ ```
204
+
205
+ ### Installation Guide for Linux
206
+
207
+ We recommend Python 3.10 and CUDA version 12.2 for the manual installation.
208
+
209
+ ```shell
210
+ pip install -r requirements.txt
211
+ ```
212
+
213
+ ### Running Command
214
+
215
+ #### Get Structural Caption by SkyCaptioner-V1
216
+
217
+ ```shell
218
+ export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
219
+
220
+ python scripts/vllm_struct_caption.py \
221
+ --model_path ${SkyCaptioner_V1_Model_PATH} \
222
+ --input_csv "./examples/test.csv" \
223
+ --out_csv "./examples/test_result.csv" \
224
+ --tp 1 \
225
+ --bs 4
226
+ ```
227
+
228
+ #### T2V/I2V Caption Fusion by Qwen2.5-32B-Instruct Model
229
+
230
+ ```shell
231
+ export LLM_MODEL_PATH="/path/to/your_local_model_path2"
232
+
233
+ python scripts/vllm_fusion_caption.py \
234
+ --model_path ${LLM_MODEL_PATH} \
235
+ --input_csv "./examples/test_result.csv" \
236
+ --out_csv "./examples/test_result_caption.csv" \
237
+ --bs 4 \
238
+ --tp 1 \
239
+ --task t2v
240
+ ```
241
+ > **Note**:
242
+ > - To generate an I2V caption instead, change `--task t2v` to `--task i2v` in the command above; the fused captions can be inspected as shown below.
243
+
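+ The fused captions are written to a `<task>_fusion_caption` column of the output CSV (the column name follows the `--task` value). A minimal way to inspect them, assuming the T2V command above was used:
+
+ ```python
+ import pandas as pd
+
+ # Output of scripts/vllm_fusion_caption.py; use "i2v_fusion_caption" for --task i2v.
+ df = pd.read_csv("./examples/test_result_caption.csv")
+ for path, caption in zip(df["path"], df["t2v_fusion_caption"]):
+     print(path, "->", caption)
+ ```
+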
244
+ ## Acknowledgements
245
+
246
+ We would like to thank the contributors of the <a href="https://github.com/QwenLM/Qwen2.5-VL">Qwen2.5-VL</a>, <a href="https://github.com/bytedance/tarsier">tarsier2</a> and <a href="https://github.com/vllm-project/vllm">vllm</a> repositories for their open research and contributions.
247
+
248
+ ## Citation
249
+
250
+ ```bibtex
251
+ @misc{chen2025skyreelsv2infinitelengthfilmgenerative,
252
+ author = {Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
253
+ title = {SkyReels-V2: Infinite-Length Film Generative Model},
254
+ year = {2025},
255
+ eprint={2504.13074},
256
+ archivePrefix={arXiv},
257
+ primaryClass={cs.CV},
258
+ url={https://arxiv.org/abs/2504.13074}
259
+ }
260
+ ```
261
+
262
+
skycaptioner_v1/examples/data/1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16becc00811e427d9e3f5da5a977239907ea3a6d45b5482b9bcfea2abc3c6b7f
3
+ size 4821469
skycaptioner_v1/examples/data/2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:109b954b7b1e36addef840554af53241415492e71972c87b7bfb03f77bf0d68a
3
+ size 1030652
skycaptioner_v1/examples/data/3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fd544787b24265605cf16c954f495fe9d73680e0529f1438e47c02803a9c2bf
3
+ size 1661366
skycaptioner_v1/examples/data/4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5cc4038924aae963a54779d713f7f9680259ae81a4446fb2775cc6bb51d907c
3
+ size 2523332
skycaptioner_v1/examples/test.csv ADDED
@@ -0,0 +1,5 @@
1
+ path
2
+ ./examples/data/1.mp4
3
+ ./examples/data/2.mp4
4
+ ./examples/data/3.mp4
5
+ ./examples/data/4.mp4
skycaptioner_v1/examples/test_result.csv ADDED
@@ -0,0 +1,5 @@
1
+ path,structural_caption
2
+ ./examples/data/1.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Wearing winter sports gear, including a helmet and goggles."", ""action"": ""The video shows a snowy mountain landscape with a ski slope surrounded by dense trees and distant mountains. A skier is seen descending the slope, moving from the top left to the bottom right of the frame. The skier maintains a steady pace, navigating the curves of the slope. The background includes a ski lift with chairs moving along the cables, and the slope is marked with red and white lines indicating the path for skiers. The skier continues to descend, gradually getting closer to the bottom of the slope."", ""expression"": """", ""position"": ""Centered in the frame, moving downwards on the slope."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""White, covering the ground and trees."", ""action"": """", ""expression"": """", ""position"": ""Surrounding the skier, covering the entire visible area."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Plant"", ""sub_type"": ""Tree""}, ""appearance"": ""Tall, evergreen, covered in snow."", ""action"": """", ""expression"": """", ""position"": ""Scattered throughout the scene, both in the foreground and background."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""Snow-covered, with ski lifts and structures."", ""action"": """", ""expression"": """", ""position"": ""In the background, providing context to the location."", ""is_main_subject"": false}], ""shot_type"": ""long_shot"", ""shot_angle"": ""high_angle"", ""shot_position"": ""overlooking_view"", ""camera_motion"": ""the camera moves toward zooms in"", ""environment"": ""A snowy mountain landscape with a ski slope, trees, and a ski resort in the background."", ""lighting"": ""Bright daylight, casting shadows and highlighting the snow's texture.""}"
3
+ ./examples/data/2.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Human"", ""sub_type"": ""Woman""}, ""appearance"": ""Long, straight black hair, wearing a sparkling choker necklace with a diamond-like texture, light-colored top, subtle makeup with pink lipstick, stud earrings."", ""action"": ""A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her. The background outside the car is green, indicating a possible outdoor setting."", ""expression"": ""The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor."", ""position"": ""Seated in the foreground of the car, facing slightly to the right."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Human"", ""sub_type"": ""Man""}, ""appearance"": ""Short hair, wearing a dark-colored suit with a white shirt."", ""action"": """", ""expression"": """", ""position"": ""Seated in the background of the car, facing the woman."", ""is_main_subject"": false}], ""shot_type"": ""close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""side_view"", ""camera_motion"": """", ""environment"": ""Interior of a car with dark upholstery."", ""lighting"": ""Soft and natural lighting, suggesting daytime.""}"
4
+ ./examples/data/3.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Animal"", ""sub_type"": ""Insect""}, ""appearance"": ""The spider has a spherical, yellowish-green body with darker green stripes and spots. It has eight slender legs with visible joints and segments."", ""action"": ""A spider with a yellow and green body and black and white striped legs is hanging from its web in a natural setting with a blurred background of green and brown hues. The spider remains mostly still, with slight movements in its legs and body, indicating a gentle swaying motion."", ""expression"": """", ""position"": ""The spider is centrally positioned in the frame, hanging from a web."", ""is_main_subject"": true}], ""shot_type"": ""extreme_close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""front_view"", ""camera_motion"": """", ""environment"": ""The background consists of vertical, out-of-focus lines in shades of green and brown, suggesting a natural environment with vegetation."", ""lighting"": ""The lighting is soft and diffused, with no harsh shadows, indicating an overcast sky or a shaded area.""}"
5
+ ./examples/data/4.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Football""}, ""appearance"": ""Wearing a dark-colored jersey, black shorts, and bright blue soccer shoes with white soles."", ""action"": ""A man is on a grassy field with orange cones placed in a line. He is wearing a gray shirt, black shorts, and black socks with blue shoes. The man starts by standing still with his feet apart, then begins to move forward while keeping his eyes on the cones. He continues to run forward, maintaining his focus on the cones, and his feet move in a coordinated manner to navigate around them. The background shows a clear sky, some trees, and a few buildings in the distance."", ""expression"": """", ""position"": ""Centered in the frame, with the soccer ball positioned between the cones."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Bright orange, conical shape."", ""action"": """", ""expression"": """", ""position"": ""Placed on the grass, creating a path for the soccer player."", ""is_main_subject"": false}], ""shot_type"": ""full_shot"", ""shot_angle"": ""low_angle"", ""shot_position"": ""front_view"", ""camera_motion"": ""use a tracking shot, the camera moves toward zooms in"", ""environment"": ""Outdoor sports field with well-maintained grass, trees, and a clear blue sky."", ""lighting"": ""Bright and natural, suggesting a sunny day.""}"
skycaptioner_v1/infer_fusion_caption.sh ADDED
@@ -0,0 +1,9 @@
1
+ export LLM_MODEL_PATH="/path/to/your_local_model_path2"
2
+
3
+ python scripts/vllm_fusion_caption.py \
4
+ --model_path ${LLM_MODEL_PATH} \
5
+ --input_csv "./examples/test_result.csv" \
6
+ --out_csv "./examples/test_result_caption.csv" \
7
+ --bs 4 \
8
+ --tp 1 \
9
+ --task t2v
skycaptioner_v1/infer_struct_caption.sh ADDED
@@ -0,0 +1,8 @@
1
+ export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
2
+
3
+ python scripts/vllm_struct_caption.py \
4
+ --model_path ${SkyCaptioner_V1_Model_PATH} \
5
+ --input_csv "./examples/test.csv" \
6
+ --out_csv "./examples/test_result.csv" \
7
+ --tp 1 \
8
+ --bs 32
skycaptioner_v1/requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ decord==0.6.0
2
+ transformers>=4.49.0
3
+ vllm==0.8.4
skycaptioner_v1/scripts/utils.py ADDED
@@ -0,0 +1,19 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
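+ # Flatten the batched index/result lists, de-duplicate indices, write the results into
+ # the given column of the metadata frame, and return only the rows that were processed.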
5
+ flat_indices = []
6
+ for x in zip(indices_list):
7
+ flat_indices.extend(x)
8
+ flat_results = []
9
+ for x in zip(result_list):
10
+ flat_results.extend(x)
11
+
12
+ flat_indices = np.array(flat_indices)
13
+ flat_results = np.array(flat_results)
14
+
15
+ unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
16
+ meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
17
+
18
+ meta = meta.loc[unique_indices]
19
+ return meta
skycaptioner_v1/scripts/vllm_fusion_caption.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+ from pathlib import Path
3
+ import argparse
4
+ import glob
5
+ import time
6
+ import gc
7
+ from tqdm import tqdm
8
+ import torch
9
+ from transformers import AutoTokenizer
10
+ import pandas as pd
11
+ from vllm import LLM, SamplingParams
12
+ from torch.utils.data import DataLoader
13
+ import json
14
+ import random
15
+ from utils import result_writer
16
+
17
+ SYSTEM_PROMPT_I2V = """
18
+ You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
19
+
20
+ ## Structured Input
21
+ {structured_input}
22
+
23
+ ## Notes
24
+ 1. If there has an empty field, just ignore it and do not mention it in the output.
25
+ 2. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
26
+ 3. If the action field is not empty, eliminate the irrelevant information in the action field that is not related to the timing action(such as wearings, background and environment information) to make a pure action field.
27
+
28
+ ## Output Principles and Orders
29
+ 1. First, eliminate the static information in the action field that is not related to the timing action, such as background or environment information.
30
+ 2. Second, describe each subject with its pure action and expression if these fields exist.
31
+
32
+ ## Output
33
+ Please directly output the final composed caption without any additional information.
34
+ """
35
+
36
+ SYSTEM_PROMPT_T2V = """
37
+ You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
38
+
39
+ ## Structured Input
40
+ {structured_input}
41
+
42
+ ## Notes
43
+ 1. According to the action field information, change its name field to the subject pronoun in the action.
44
+ 2. If there has an empty field, just ignore it and do not mention it in the output.
45
+ 3. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
46
+
47
+ ## Output Principles and Orders
48
+ 1. First, declare the shot_type, then declare the shot_angle and the shot_position fields.
49
+ 2. Second, eliminate information in the action field that is not related to the timing action, such as background or environment information if action is not empty.
50
+ 3. Third, describe each subject with its pure action, appearance, expression, position if these fields exist.
51
+ 4. Finally, declare the environment and lighting if the environment and lighting fields are not empty.
52
+
53
+ ## Output
54
+ Please directly output the final composed caption without any additional information.
55
+ """
56
+
57
+ SHOT_TYPE_LIST = [
58
+ 'close-up shot',
59
+ 'extreme close-up shot',
60
+ 'medium shot',
61
+ 'long shot',
62
+ 'full shot',
63
+ ]
64
+
65
+
66
+ class StructuralCaptionDataset(torch.utils.data.Dataset):
67
+ def __init__(self, input_csv, model_path):
68
+ self.meta = pd.read_csv(input_csv)
69
+ self.task = args.task
70
+ self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
71
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
72
+
73
+
74
+ def __len__(self):
75
+ return len(self.meta)
76
+
77
+ def __getitem__(self, index):
78
+ row = self.meta.iloc[index]
79
+ real_index = self.meta.index[index]
80
+
81
+ struct_caption = json.loads(row["structural_caption"])
82
+
83
+ camera_movement = struct_caption.get('camera_motion', '')
84
+ if camera_movement != '':
85
+ camera_movement += '.'
86
+ camera_movement = camera_movement.capitalize()
87
+
88
+ fusion_by_llm = False
89
+ cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task)
90
+ if cleaned_struct_caption.get('num_subjects', 0) > 0:
91
+ new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False)
92
+ conversation = [
93
+ {
94
+ "role": "system",
95
+ "content": self.system_prompt.format(structured_input=new_struct_caption),
96
+ },
97
+ ]
98
+ text = self.tokenizer.apply_chat_template(
99
+ conversation,
100
+ tokenize=False,
101
+ add_generation_prompt=True
102
+ )
103
+ fusion_by_llm = True
104
+ else:
105
+ text = '-'
106
+ return real_index, fusion_by_llm, text, '-', camera_movement
107
+
108
+ def clean_struct_caption(self, struct_caption, task):
109
+ raw_subjects = struct_caption.get('subjects', [])
110
+ subjects = []
111
+ for subject in raw_subjects:
112
+ subject_type = subject.get("TYPES", {}).get('type', '')
113
+ subject_sub_type = subject.get("TYPES", {}).get('sub_type', '')
114
+ if subject_type not in ["Human", "Animal"]:
115
+ subject['expression'] = ''
116
+ if subject_type == 'Human' and subject_sub_type == 'Accessory':
117
+ subject['expression'] = ''
118
+ if subject_sub_type != '':
119
+ subject['name'] = subject_sub_type
120
+ if 'TYPES' in subject:
121
+ del subject['TYPES']
122
+ if 'is_main_subject' in subject:
123
+ del subject['is_main_subject']
124
+ subjects.append(subject)
125
+
126
+ to_del_subject_ids = []
127
+ for idx, subject in enumerate(subjects):
128
+ action = subject.get('action', '').strip()
129
+ subject['action'] = action
130
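+ # Randomly drop appearance/position for roughly 10% of subjects so the fused captions vary in detail.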
+ if random.random() > 0.9 and 'appearance' in subject:
131
+ del subject['appearance']
132
+ if random.random() > 0.9 and 'position' in subject:
133
+ del subject['position']
134
+ if task == 'i2v':
135
+ # just keep name and action, expression in subjects
136
+ dropped_keys = ['appearance', 'position']
137
+ for key in dropped_keys:
138
+ if key in subject:
139
+ del subject[key]
140
+ if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''):
141
+ to_del_subject_ids.append(idx)
142
+
143
+ # delete the subjects according to the to_del_subject_ids
144
+ for idx in sorted(to_del_subject_ids, reverse=True):
145
+ del subjects[idx]
146
+
147
+
148
+ shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
149
+ if shot_type not in SHOT_TYPE_LIST:
150
+ struct_caption['shot_type'] = ''
151
+
152
+ new_struct_caption = {
153
+ 'num_subjects': len(subjects),
154
+ 'subjects': subjects,
155
+ 'shot_type': struct_caption.get('shot_type', ''),
156
+ 'shot_angle': struct_caption.get('shot_angle', ''),
157
+ 'shot_position': struct_caption.get('shot_position', ''),
158
+ 'environment': struct_caption.get('environment', ''),
159
+ 'lighting': struct_caption.get('lighting', ''),
160
+ }
161
+
162
+ if task == 't2v' and random.random() > 0.9:
163
+ del new_struct_caption['lighting']
164
+
165
+ if task == 'i2v':
166
+ drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position']
167
+ for drop_key in drop_keys:
168
+ del new_struct_caption[drop_key]
169
+ return new_struct_caption
170
+
171
+ def custom_collate_fn(batch):
172
+ real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch)
173
+ return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements)
174
+
175
+
176
+ if __name__ == "__main__":
177
+ parser = argparse.ArgumentParser(description="Caption Fusion by LLM")
178
+ parser.add_argument("--input_csv", default="./examples/test_result.csv")
179
+ parser.add_argument("--out_csv", default="./examples/test_result_caption.csv")
180
+ parser.add_argument("--bs", type=int, default=4)
181
+ parser.add_argument("--tp", type=int, default=1)
182
+ parser.add_argument("--model_path", required=True, type=str, help="LLM model path")
183
+ parser.add_argument("--task", default='t2v', help="t2v or i2v")
184
+
185
+ args = parser.parse_args()
186
+
187
+ sampling_params = SamplingParams(
188
+ temperature=0.1,
189
+ max_tokens=512,
190
+ stop=['\n\n']
191
+ )
192
+ # model_path = "/maindata/data/shared/public/Common-Models/Qwen2.5-32B-Instruct/"
193
+
194
+
195
+ llm = LLM(
196
+ model=args.model_path,
197
+ gpu_memory_utilization=0.9,
198
+ max_model_len=4096,
199
+ tensor_parallel_size = args.tp
200
+ )
201
+
202
+
203
+ dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path)
204
+
205
+ dataloader = DataLoader(
206
+ dataset,
207
+ batch_size=args.bs,
208
+ num_workers=8,
209
+ collate_fn=custom_collate_fn,
210
+ shuffle=False,
211
+ drop_last=False,
212
+ )
213
+
214
+ indices_list = []
215
+ result_list = []
216
+ for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader):
217
+ llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], []
218
+ for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements):
219
+ if fusion_by_llm:
220
+ llm_indices.append(idx)
221
+ llm_texts.append(text)
222
+ llm_original_texts.append(original_text)
223
+ llm_camera_movements.append(camera_movement)
224
+ else:
225
+ indices_list.append(idx)
226
+ caption = original_text + " " + camera_movement
227
+ result_list.append(caption)
228
+ if len(llm_texts) > 0:
229
+ try:
230
+ outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False)
231
+ results = []
232
+ for output in outputs:
233
+ result = output.outputs[0].text.strip()
234
+ results.append(result)
235
+ indices_list.extend(llm_indices)
236
+ except Exception as e:
237
+ print(f"Error at {llm_indices}: {str(e)}")
238
+ indices_list.extend(llm_indices)
239
+ results = llm_original_texts
240
+
241
+ for result, camera_movement in zip(results, llm_camera_movements):
242
+ # concat camera movement to fusion_caption
243
+ llm_caption = result + " " + camera_movement
244
+ result_list.append(llm_caption)
245
+ torch.cuda.empty_cache()
246
+ gc.collect()
247
+ gathered_list = [indices_list, result_list]
248
+ meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"])
249
+ meta_new.to_csv(args.out_csv, index=False)
250
+
skycaptioner_v1/scripts/vllm_struct_caption.py ADDED
@@ -0,0 +1,153 @@
1
+
2
+ import torch
3
+ import decord
4
+ import argparse
5
+
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ from tqdm import tqdm
10
+ from vllm import LLM, SamplingParams
11
+ from transformers import AutoTokenizer, AutoProcessor
12
+
13
+ from torch.utils.data import DataLoader
14
+
15
+ SYSTEM_PROMPT = "I need you to generate a structured and detailed caption for the provided video. The structured output and the requirements for each field are as shown in the following JSON content: {\"subjects\": [{\"appearance\": \"Main subject appearance description\", \"action\": \"Main subject action\", \"expression\": \"Main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Subject position in the video (Can be relative position to other objects or spatial description)\", \"TYPES\": {\"type\": \"Main category (e.g., Human)\", \"sub_type\": \"Sub-category (e.g., Man)\"}, \"is_main_subject\": true}, {\"appearance\": \"Non-main subject appearance description\", \"action\": \"Non-main subject action\", \"expression\": \"Non-main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Position of non-main subject 1\", \"TYPES\": {\"type\": \"Main category (e.g., Vehicles)\", \"sub_type\": \"Sub-category (e.g., Ship)\"}, \"is_main_subject\": false}], \"shot_type\": \"Shot type(Options: long_shot/full_shot/medium_shot/close_up/extreme_close_up/other)\", \"shot_angle\": \"Camera angle(Options: eye_level/high_angle/low_angle/other)\", \"shot_position\": \"Camera position(Options: front_view/back_view/side_view/over_the_shoulder/overhead_view/point_of_view/aerial_view/overlooking_view/other)\", \"camera_motion\": \"Camera movement description\", \"environment\": \"Video background/environment description\", \"lighting\": \"Lighting information in the video\"}"
16
+
17
+
18
+ class VideoTextDataset(torch.utils.data.Dataset):
19
+ def __init__(self, csv_path, model_path):
20
+ self.meta = pd.read_csv(csv_path)
21
+ self._path = 'path'
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
23
+ self.processor = AutoProcessor.from_pretrained(model_path)
24
+
25
+ def __getitem__(self, index):
26
+ row = self.meta.iloc[index]
27
+ path = row[self._path]
28
+ real_index = self.meta.index[index]
29
+ vr = decord.VideoReader(path, ctx=decord.cpu(0), width=360, height=420)
30
+ start = 0
31
+ end = len(vr)
32
+ # avg_fps = vr.get_avg_fps()
33
+ index = self.get_index(end-start, 16, st=start)
34
+ frames = vr.get_batch(index).asnumpy() # n h w c
35
+ video_inputs = [torch.from_numpy(frames).permute(0, 3, 1, 2)]
36
+ conversation = {
37
+ "role": "user",
38
+ "content": [
39
+ {
40
+ "type": "video",
41
+ "video": row['path'],
42
+ "max_pixels": 360 * 420, # 460800
43
+ "fps": 2.0,
44
+ },
45
+ {
46
+ "type": "text",
47
+ "text": SYSTEM_PROMPT
48
+ },
49
+ ],
50
+ }
51
+
52
+ # build the chat-formatted user input
53
+ user_input = self.processor.apply_chat_template(
54
+ [conversation],
55
+ tokenize=False,
56
+ add_generation_prompt=True
57
+ )
58
+ results = dict()
59
+ inputs = {
60
+ 'prompt': user_input,
61
+ 'multi_modal_data': {'video': video_inputs}
62
+ }
63
+ results["index"] = real_index
64
+ results['input'] = inputs
65
+ return results
66
+
67
+ def __len__(self):
68
+ return len(self.meta)
69
+
70
+ def get_index(self, video_size, num_frames, st=0):
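+ # Uniformly sample num_frames frame indices across the clip, offset by the starting frame st.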
71
+ seg_size = max(0., float(video_size - 1) / num_frames)
72
+ max_frame = int(video_size) - 1
73
+ seq = []
74
+ # take the first frame of each segment, clamped to the last valid frame
75
+ for i in range(num_frames):
76
+ start = int(np.round(seg_size * i))
77
+ # end = int(np.round(seg_size * (i + 1)))
78
+ idx = min(start, max_frame)
79
+ seq.append(idx+st)
80
+ return seq
81
+
82
+ def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
83
+ flat_indices = []
84
+ for x in zip(indices_list):
85
+ flat_indices.extend(x)
86
+ flat_results = []
87
+ for x in zip(result_list):
88
+ flat_results.extend(x)
89
+
90
+ flat_indices = np.array(flat_indices)
91
+ flat_results = np.array(flat_results)
92
+
93
+ unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
94
+ meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
95
+
96
+ meta = meta.loc[unique_indices]
97
+ return meta
98
+
99
+
100
+ def worker_init_fn(worker_id):
101
+ # Set different seed for each worker
102
+ worker_seed = torch.initial_seed() % 2**32
103
+ np.random.seed(worker_seed)
104
+ # Prevent deadlocks by setting timeout
105
+ torch.set_num_threads(1)
106
+
107
+ def main():
108
+ parser = argparse.ArgumentParser(description="SkyCaptioner-V1 vllm batch inference")
109
+ parser.add_argument("--input_csv", default="./examples/test.csv")
110
+ parser.add_argument("--out_csv", default="./examples/test_result.csv")
111
+ parser.add_argument("--bs", type=int, default=4)
112
+ parser.add_argument("--tp", type=int, default=1)
113
+ parser.add_argument("--model_path", required=True, type=str, help="skycaptioner-v1 model path")
114
+ args = parser.parse_args()
115
+
116
+ dataset = VideoTextDataset(csv_path=args.input_csv, model_path=args.model_path)
117
+ dataloader = DataLoader(
118
+ dataset,
119
+ batch_size=args.bs,
120
+ num_workers=4,
121
+ worker_init_fn=worker_init_fn,
122
+ persistent_workers=True,
123
+ timeout=180,
124
+ )
125
+
126
+ sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)
127
+
128
+ llm = LLM(model=args.model_path,
129
+ gpu_memory_utilization=0.6,
130
+ max_model_len=31920,
131
+ tensor_parallel_size=args.tp)
132
+
133
+ indices_list = []
134
+ caption_save = []
135
+ for video_batch in tqdm(dataloader):
136
+ indices = video_batch["index"]
137
+ inputs = video_batch["input"]
138
+ batch_user_inputs = []
139
+ for prompt, video in zip(inputs['prompt'], inputs['multi_modal_data']['video'][0]):
140
+ usi={'prompt':prompt, 'multi_modal_data':{'video':video}}
141
+ batch_user_inputs.append(usi)
142
+ outputs = llm.generate(batch_user_inputs, sampling_params, use_tqdm=False)
143
+ struct_outputs = [output.outputs[0].text for output in outputs]
144
+
145
+ indices_list.extend(indices.tolist())
146
+ caption_save.extend(struct_outputs)
147
+
148
+ meta_new = result_writer(indices_list, caption_save, dataset.meta, column=["structural_caption"])
149
+ meta_new.to_csv(args.out_csv, index=False)
150
+ print(f'Saved structural_caption to {args.out_csv}')
151
+
152
+ if __name__ == '__main__':
153
+ main()
skyreels_v2_infer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .pipelines import DiffusionForcingPipeline
skyreels_v2_infer/distributed/__init__.py ADDED
File without changes
skyreels_v2_infer/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,286 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.amp as amp
4
+ from torch.backends.cuda import sdp_kernel
5
+ from xfuser.core.distributed import get_sequence_parallel_rank
6
+ from xfuser.core.distributed import get_sequence_parallel_world_size
7
+ from xfuser.core.distributed import get_sp_group
8
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
9
+
10
+ from ..modules.transformer import sinusoidal_embedding_1d
11
+
12
+
13
+ def pad_freqs(original_tensor, target_len):
14
+ seq_len, s1, s2 = original_tensor.shape
15
+ pad_size = target_len - seq_len
16
+ padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
17
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
18
+ return padded_tensor
19
+
20
+
21
+ @amp.autocast("cuda", enabled=False)
22
+ def rope_apply(x, grid_sizes, freqs):
23
+ """
24
+ x: [B, L, N, C].
25
+ grid_sizes: [B, 3].
26
+ freqs: [M, C // 2].
27
+ """
28
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
29
+ # split freqs
30
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
31
+
32
+ # loop over samples
33
+ output = []
34
+ grid = [grid_sizes.tolist()] * x.size(0)
35
+ for i, (f, h, w) in enumerate(grid):
36
+ seq_len = f * h * w
37
+
38
+ # precompute multipliers
39
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
40
+ freqs_i = torch.cat(
41
+ [
42
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
43
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
44
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
45
+ ],
46
+ dim=-1,
47
+ ).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_sequence_parallel_world_size()
51
+ sp_rank = get_sequence_parallel_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
55
+ x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
56
+ x_i = torch.cat([x_i, x[i, s:]])
57
+
58
+ # append to collection
59
+ output.append(x_i)
60
+ return torch.stack(output).float()
61
+
62
+
63
+ def broadcast_should_calc(should_calc: bool) -> bool:
64
+ import torch.distributed as dist
65
+
66
+ device = torch.cuda.current_device()
67
+ int_should_calc = 1 if should_calc else 0
68
+ tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
69
+ dist.broadcast(tensor, src=0)
70
+ should_calc = tensor.item() == 1
71
+ return should_calc
72
+
73
+
74
+ def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
75
+ """
76
+ x: A list of videos each with shape [C, T, H, W].
77
+ t: [B].
78
+ context: A list of text embeddings each with shape [L, C].
79
+ """
80
+ if self.model_type == "i2v":
81
+ assert clip_fea is not None and y is not None
82
+ # params
83
+ device = self.patch_embedding.weight.device
84
+ if self.freqs.device != device:
85
+ self.freqs = self.freqs.to(device)
86
+
87
+ if y is not None:
88
+ x = torch.cat([x, y], dim=1)
89
+
90
+ # embeddings
91
+ x = self.patch_embedding(x)
92
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
93
+ x = x.flatten(2).transpose(1, 2)
94
+
95
+ if self.flag_causal_attention:
96
+ frame_num = grid_sizes[0]
97
+ height = grid_sizes[1]
98
+ width = grid_sizes[2]
99
+ block_num = frame_num // self.num_frame_per_block
100
+ range_tensor = torch.arange(block_num).view(-1, 1)
101
+ range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
102
+ casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
103
+ casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
104
+ casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
105
+ casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
106
+ self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
107
+
108
+ # time embeddings
109
+ with amp.autocast("cuda", dtype=torch.float32):
110
+ if t.dim() == 2:
111
+ b, f = t.shape
112
+ _flag_df = True
113
+ else:
114
+ _flag_df = False
115
+ e = self.time_embedding(
116
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
117
+ ) # b, dim
118
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
119
+
120
+ if self.inject_sample_info:
121
+ fps = torch.tensor(fps, dtype=torch.long, device=device)
122
+
123
+ fps_emb = self.fps_embedding(fps).float()
124
+ if _flag_df:
125
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
126
+ else:
127
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
128
+
129
+ if _flag_df:
130
+ e = e.view(b, f, 1, 1, self.dim)
131
+ e0 = e0.view(b, f, 1, 1, 6, self.dim)
132
+ e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
133
+ e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
134
+ e0 = e0.transpose(1, 2).contiguous()
135
+
136
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
137
+
138
+ # context
139
+ context = self.text_embedding(context)
140
+
141
+ if clip_fea is not None:
142
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
143
+ context = torch.concat([context_clip, context], dim=1)
144
+
145
+ # arguments
146
+ if e0.ndim == 4:
147
+ e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
148
+ kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
149
+
150
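+ # TeaCache: track how much the modulated timestep embedding changes between steps. If the
+ # accumulated relative L1 change stays below teacache_thresh, the transformer blocks are
+ # skipped and the cached residual from the previous step is reused instead.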
+ if self.enable_teacache:
151
+ modulated_inp = e0 if self.use_ref_steps else e
152
+ # teacache
153
+ if self.cnt % 2 == 0: # even -> conditional pass
154
+ self.is_even = True
155
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
156
+ should_calc_even = True
157
+ self.accumulated_rel_l1_distance_even = 0
158
+ else:
159
+ rescale_func = np.poly1d(self.coefficients)
160
+ self.accumulated_rel_l1_distance_even += rescale_func(
161
+ ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
162
+ .cpu()
163
+ .item()
164
+ )
165
+ if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
166
+ should_calc_even = False
167
+ else:
168
+ should_calc_even = True
169
+ self.accumulated_rel_l1_distance_even = 0
170
+ self.previous_e0_even = modulated_inp.clone()
171
+ else: # odd -> unconditional pass
172
+ self.is_even = False
173
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
174
+ should_calc_odd = True
175
+ self.accumulated_rel_l1_distance_odd = 0
176
+ else:
177
+ rescale_func = np.poly1d(self.coefficients)
178
+ self.accumulated_rel_l1_distance_odd += rescale_func(
179
+ ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
180
+ .cpu()
181
+ .item()
182
+ )
183
+ if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
184
+ should_calc_odd = False
185
+ else:
186
+ should_calc_odd = True
187
+ self.accumulated_rel_l1_distance_odd = 0
188
+ self.previous_e0_odd = modulated_inp.clone()
189
+
190
+ x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
191
+ if self.enable_teacache:
192
+ if self.is_even:
193
+ should_calc_even = broadcast_should_calc(should_calc_even)
194
+ if not should_calc_even:
195
+ x += self.previous_residual_even
196
+ else:
197
+ ori_x = x.clone()
198
+ for block in self.blocks:
199
+ x = block(x, **kwargs)
200
+ ori_x.mul_(-1)
201
+ ori_x.add_(x)
202
+ self.previous_residual_even = ori_x
203
+ else:
204
+ should_calc_odd = broadcast_should_calc(should_calc_odd)
205
+ if not should_calc_odd:
206
+ x += self.previous_residual_odd
207
+ else:
208
+ ori_x = x.clone()
209
+ for block in self.blocks:
210
+ x = block(x, **kwargs)
211
+ ori_x.mul_(-1)
212
+ ori_x.add_(x)
213
+ self.previous_residual_odd = ori_x
214
+ self.cnt += 1
215
+ if self.cnt >= self.num_steps:
216
+ self.cnt = 0
217
+ else:
218
+ # Context Parallel
219
+ for block in self.blocks:
220
+ x = block(x, **kwargs)
221
+
222
+ # head
223
+ if e.ndim == 3:
224
+ e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
225
+ x = self.head(x, e)
226
+ # Context Parallel
227
+ x = get_sp_group().all_gather(x, dim=1)
228
+ # unpatchify
229
+ x = self.unpatchify(x, grid_sizes)
230
+ return x.float()
231
+
232
+
233
+ def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
234
+
235
+ r"""
236
+ Args:
237
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
238
+ seq_lens(Tensor): Shape [B]
239
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
240
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
241
+ """
242
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
243
+ half_dtypes = (torch.float16, torch.bfloat16)
244
+
245
+ def half(x):
246
+ return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
247
+
248
+ # query, key, value function
249
+ def qkv_fn(x):
250
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
251
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
252
+ v = self.v(x).view(b, s, n, d)
253
+ return q, k, v
254
+
255
+ x = x.to(self.q.weight.dtype)
256
+ q, k, v = qkv_fn(x)
257
+
258
+ if not self._flag_ar_attention:
259
+ q = rope_apply(q, grid_sizes, freqs)
260
+ k = rope_apply(k, grid_sizes, freqs)
261
+ else:
262
+
263
+ q = rope_apply(q, grid_sizes, freqs)
264
+ k = rope_apply(k, grid_sizes, freqs)
265
+ q = q.to(torch.bfloat16)
266
+ k = k.to(torch.bfloat16)
267
+ v = v.to(torch.bfloat16)
268
+ # x = torch.nn.functional.scaled_dot_product_attention(
269
+ # q.transpose(1, 2),
270
+ # k.transpose(1, 2),
271
+ # v.transpose(1, 2),
272
+ # ).transpose(1, 2).contiguous()
273
+ with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
274
+ x = (
275
+ torch.nn.functional.scaled_dot_product_attention(
276
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
277
+ )
278
+ .transpose(1, 2)
279
+ .contiguous()
280
+ )
281
+ x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
282
+
283
+ # output
284
+ x = x.flatten(2)
285
+ x = self.o(x)
286
+ return x
skyreels_v2_infer/modules/__init__.py ADDED
@@ -0,0 +1,69 @@
1
+ import gc
2
+ import os
3
+
4
+ import torch
5
+ from safetensors.torch import load_file
6
+
7
+ from .clip import CLIPModel
8
+ from .t5 import T5EncoderModel
9
+ from .transformer import WanModel
10
+ from .vae import WanVAE
11
+
12
+
13
+ def download_model(model_id):
14
+ if not os.path.exists(model_id):
15
+ from huggingface_hub import snapshot_download
16
+
17
+ model_id = snapshot_download(repo_id=model_id)
18
+ return model_id
19
+
20
+
21
+ def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
22
+ vae = WanVAE(model_path).to(device).to(weight_dtype)
23
+ vae.vae.requires_grad_(False)
24
+ vae.vae.eval()
25
+ gc.collect()
26
+ torch.cuda.empty_cache()
27
+ return vae
28
+
29
+
30
+ def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
31
+ config_path = os.path.join(model_path, "config.json")
32
+ transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
33
+
34
+ for file in os.listdir(model_path):
35
+ if file.endswith(".safetensors"):
36
+ file_path = os.path.join(model_path, file)
37
+ state_dict = load_file(file_path)
38
+ transformer.load_state_dict(state_dict, strict=False)
39
+ del state_dict
40
+ gc.collect()
41
+ torch.cuda.empty_cache()
42
+
43
+ transformer.requires_grad_(False)
44
+ transformer.eval()
45
+ gc.collect()
46
+ torch.cuda.empty_cache()
47
+ return transformer
48
+
49
+
50
+ def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
51
+ t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
52
+ tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
53
+ text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
54
+ text_encoder.requires_grad_(False)
55
+ text_encoder.eval()
56
+ gc.collect()
57
+ torch.cuda.empty_cache()
58
+ return text_encoder
59
+
60
+
61
+ def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
62
+ checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
63
+ tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
64
+ image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
65
+ image_enc.requires_grad_(False)
66
+ image_enc.eval()
67
+ gc.collect()
68
+ torch.cuda.empty_cache()
69
+ return image_enc
skyreels_v2_infer/modules/attention.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+
7
+ FLASH_ATTN_3_AVAILABLE = True
8
+ except ModuleNotFoundError:
9
+ FLASH_ATTN_3_AVAILABLE = False
10
+
11
+ try:
12
+ import flash_attn
13
+
14
+ FLASH_ATTN_2_AVAILABLE = True
15
+ except ModuleNotFoundError:
16
+ FLASH_ATTN_2_AVAILABLE = False
17
+
18
+ import warnings
19
+
20
+ __all__ = [
21
+ "flash_attention",
22
+ "attention",
23
+ ]
24
+
25
+
26
+ def flash_attention(
27
+ q,
28
+ k,
29
+ v,
30
+ q_lens=None,
31
+ k_lens=None,
32
+ dropout_p=0.0,
33
+ softmax_scale=None,
34
+ q_scale=None,
35
+ causal=False,
36
+ window_size=(-1, -1),
37
+ deterministic=False,
38
+ dtype=torch.bfloat16,
39
+ version=None,
40
+ ):
41
+ """
42
+ q: [B, Lq, Nq, C1].
43
+ k: [B, Lk, Nk, C1].
44
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
45
+ q_lens: [B].
46
+ k_lens: [B].
47
+ dropout_p: float. Dropout probability.
48
+ softmax_scale: float. The scaling of QK^T before applying softmax.
49
+ causal: bool. Whether to apply causal attention mask.
50
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
51
+ deterministic: bool. If True, slightly slower and uses more memory.
52
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
53
+ """
54
+ half_dtypes = (torch.float16, torch.bfloat16)
55
+ assert dtype in half_dtypes
56
+ assert q.device.type == "cuda" and q.size(-1) <= 256
57
+
58
+ # params
59
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
60
+
61
+ def half(x):
62
+ return x if x.dtype in half_dtypes else x.to(dtype)
63
+
64
+ # preprocess query
65
+
66
+ q = half(q.flatten(0, 1))
67
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
68
+
69
+ # preprocess key, value
70
+
71
+ k = half(k.flatten(0, 1))
72
+ v = half(v.flatten(0, 1))
73
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
74
+
75
+ q = q.to(v.dtype)
76
+ k = k.to(v.dtype)
77
+
78
+ if q_scale is not None:
79
+ q = q * q_scale
80
+
81
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
82
+ warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.")
83
+
84
+ torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}")
85
+ # apply attention
86
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
87
+ # Note: dropout_p, window_size are not supported in FA3 now.
88
+ x = flash_attn_interface.flash_attn_varlen_func(
89
+ q=q,
90
+ k=k,
91
+ v=v,
92
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
93
+ .cumsum(0, dtype=torch.int32)
94
+ .to(q.device, non_blocking=True),
95
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
96
+ .cumsum(0, dtype=torch.int32)
97
+ .to(q.device, non_blocking=True),
98
+ seqused_q=None,
99
+ seqused_k=None,
100
+ max_seqlen_q=lq,
101
+ max_seqlen_k=lk,
102
+ softmax_scale=softmax_scale,
103
+ causal=causal,
104
+ deterministic=deterministic,
105
+ )[0].unflatten(0, (b, lq))
106
+ else:
107
+ assert FLASH_ATTN_2_AVAILABLE
108
+ x = flash_attn.flash_attn_varlen_func(
109
+ q=q,
110
+ k=k,
111
+ v=v,
112
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
113
+ .cumsum(0, dtype=torch.int32)
114
+ .to(q.device, non_blocking=True),
115
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
116
+ .cumsum(0, dtype=torch.int32)
117
+ .to(q.device, non_blocking=True),
118
+ max_seqlen_q=lq,
119
+ max_seqlen_k=lk,
120
+ dropout_p=dropout_p,
121
+ softmax_scale=softmax_scale,
122
+ causal=causal,
123
+ window_size=window_size,
124
+ deterministic=deterministic,
125
+ ).unflatten(0, (b, lq))
126
+ torch.cuda.nvtx.range_pop()
127
+
128
+ # output
129
+ return x
130
+
131
+
132
+ def attention(
133
+ q,
134
+ k,
135
+ v,
136
+ q_lens=None,
137
+ k_lens=None,
138
+ dropout_p=0.0,
139
+ softmax_scale=None,
140
+ q_scale=None,
141
+ causal=False,
142
+ window_size=(-1, -1),
143
+ deterministic=False,
144
+ dtype=torch.bfloat16,
145
+ fa_version=None,
146
+ ):
147
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
148
+ return flash_attention(
149
+ q=q,
150
+ k=k,
151
+ v=v,
152
+ q_lens=q_lens,
153
+ k_lens=k_lens,
154
+ dropout_p=dropout_p,
155
+ softmax_scale=softmax_scale,
156
+ q_scale=q_scale,
157
+ causal=causal,
158
+ window_size=window_size,
159
+ deterministic=deterministic,
160
+ dtype=dtype,
161
+ version=fa_version,
162
+ )
163
+ else:
164
+ if q_lens is not None or k_lens is not None:
165
+ warnings.warn(
166
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
167
+ )
168
+ attn_mask = None
169
+
170
+ q = q.transpose(1, 2).to(dtype)
171
+ k = k.transpose(1, 2).to(dtype)
172
+ v = v.transpose(1, 2).to(dtype)
173
+
174
+ out = torch.nn.functional.scaled_dot_product_attention(
175
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
176
+ )
177
+
178
+ out = out.transpose(1, 2).contiguous()
179
+ return out
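
A minimal usage sketch of the attention() wrapper above (illustrative, not part of the diff). It assumes the repo's requirements are installed and that flash-attn is absent, so the scaled_dot_product_attention fallback runs; shapes follow the docstring, [B, L, num_heads, head_dim].

import torch
from skyreels_v2_infer.modules.attention import attention

# q/k/v: [batch, seq_len, num_heads, head_dim]
q = torch.randn(2, 128, 8, 64)
k = torch.randn(2, 128, 8, 64)
v = torch.randn(2, 128, 8, 64)

# float32 keeps the CPU fallback path simple; the flash paths require CUDA.
out = attention(q, k, v, causal=False, dtype=torch.float32)
print(out.shape)  # torch.Size([2, 128, 8, 64])
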
skyreels_v2_infer/modules/clip.py ADDED
@@ -0,0 +1,525 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+ from diffusers.models import ModelMixin
11
+
12
+ from .attention import flash_attention
13
+ from .tokenizers import HuggingfaceTokenizer
14
+ from .xlm_roberta import XLMRoberta
15
+
16
+ __all__ = [
17
+ "XLMRobertaCLIP",
18
+ "clip_xlm_roberta_vit_h_14",
19
+ "CLIPModel",
20
+ ]
21
+
22
+
23
+ def pos_interpolate(pos, seq_len):
24
+ if pos.size(1) == seq_len:
25
+ return pos
26
+ else:
27
+ src_grid = int(math.sqrt(pos.size(1)))
28
+ tar_grid = int(math.sqrt(seq_len))
29
+ n = pos.size(1) - src_grid * src_grid
30
+ return torch.cat(
31
+ [
32
+ pos[:, :n],
33
+ F.interpolate(
34
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
35
+ size=(tar_grid, tar_grid),
36
+ mode="bicubic",
37
+ align_corners=False,
38
+ )
39
+ .flatten(2)
40
+ .transpose(1, 2),
41
+ ],
42
+ dim=1,
43
+ )
44
+
45
+
46
+ class QuickGELU(nn.Module):
47
+ def forward(self, x):
48
+ return x * torch.sigmoid(1.702 * x)
49
+
50
+
51
+ class LayerNorm(nn.LayerNorm):
52
+ def forward(self, x):
53
+ return super().forward(x.float()).type_as(x)
54
+
55
+
56
+ class SelfAttention(nn.Module):
57
+ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
58
+ assert dim % num_heads == 0
59
+ super().__init__()
60
+ self.dim = dim
61
+ self.num_heads = num_heads
62
+ self.head_dim = dim // num_heads
63
+ self.causal = causal
64
+ self.attn_dropout = attn_dropout
65
+ self.proj_dropout = proj_dropout
66
+
67
+ # layers
68
+ self.to_qkv = nn.Linear(dim, dim * 3)
69
+ self.proj = nn.Linear(dim, dim)
70
+
71
+ def forward(self, x):
72
+ """
73
+ x: [B, L, C].
74
+ """
75
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
76
+
77
+ # compute query, key, value
78
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
79
+
80
+ # compute attention
81
+ p = self.attn_dropout if self.training else 0.0
82
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
83
+ x = x.reshape(b, s, c)
84
+
85
+ # output
86
+ x = self.proj(x)
87
+ x = F.dropout(x, self.proj_dropout, self.training)
88
+ return x
89
+
90
+
91
+ class SwiGLU(nn.Module):
92
+ def __init__(self, dim, mid_dim):
93
+ super().__init__()
94
+ self.dim = dim
95
+ self.mid_dim = mid_dim
96
+
97
+ # layers
98
+ self.fc1 = nn.Linear(dim, mid_dim)
99
+ self.fc2 = nn.Linear(dim, mid_dim)
100
+ self.fc3 = nn.Linear(mid_dim, dim)
101
+
102
+ def forward(self, x):
103
+ x = F.silu(self.fc1(x)) * self.fc2(x)
104
+ x = self.fc3(x)
105
+ return x
106
+
107
+
108
+ class AttentionBlock(nn.Module):
109
+ def __init__(
110
+ self,
111
+ dim,
112
+ mlp_ratio,
113
+ num_heads,
114
+ post_norm=False,
115
+ causal=False,
116
+ activation="quick_gelu",
117
+ attn_dropout=0.0,
118
+ proj_dropout=0.0,
119
+ norm_eps=1e-5,
120
+ ):
121
+ assert activation in ["quick_gelu", "gelu", "swi_glu"]
122
+ super().__init__()
123
+ self.dim = dim
124
+ self.mlp_ratio = mlp_ratio
125
+ self.num_heads = num_heads
126
+ self.post_norm = post_norm
127
+ self.causal = causal
128
+ self.norm_eps = norm_eps
129
+
130
+ # layers
131
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
132
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
133
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
134
+ if activation == "swi_glu":
135
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
136
+ else:
137
+ self.mlp = nn.Sequential(
138
+ nn.Linear(dim, int(dim * mlp_ratio)),
139
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
140
+ nn.Linear(int(dim * mlp_ratio), dim),
141
+ nn.Dropout(proj_dropout),
142
+ )
143
+
144
+ def forward(self, x):
145
+ if self.post_norm:
146
+ x = x + self.norm1(self.attn(x))
147
+ x = x + self.norm2(self.mlp(x))
148
+ else:
149
+ x = x + self.attn(self.norm1(x))
150
+ x = x + self.mlp(self.norm2(x))
151
+ return x
152
+
153
+
154
+ class AttentionPool(nn.Module):
155
+ def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
156
+ assert dim % num_heads == 0
157
+ super().__init__()
158
+ self.dim = dim
159
+ self.mlp_ratio = mlp_ratio
160
+ self.num_heads = num_heads
161
+ self.head_dim = dim // num_heads
162
+ self.proj_dropout = proj_dropout
163
+ self.norm_eps = norm_eps
164
+
165
+ # layers
166
+ gain = 1.0 / math.sqrt(dim)
167
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
168
+ self.to_q = nn.Linear(dim, dim)
169
+ self.to_kv = nn.Linear(dim, dim * 2)
170
+ self.proj = nn.Linear(dim, dim)
171
+ self.norm = LayerNorm(dim, eps=norm_eps)
172
+ self.mlp = nn.Sequential(
173
+ nn.Linear(dim, int(dim * mlp_ratio)),
174
+ QuickGELU() if activation == "quick_gelu" else nn.GELU(),
175
+ nn.Linear(int(dim * mlp_ratio), dim),
176
+ nn.Dropout(proj_dropout),
177
+ )
178
+
179
+ def forward(self, x):
180
+ """
181
+ x: [B, L, C].
182
+ """
183
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
184
+
185
+ # compute query, key, value
186
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
187
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
188
+
189
+ # compute attention
190
+ x = flash_attention(q, k, v, version=2)
191
+ x = x.reshape(b, 1, c)
192
+
193
+ # output
194
+ x = self.proj(x)
195
+ x = F.dropout(x, self.proj_dropout, self.training)
196
+
197
+ # mlp
198
+ x = x + self.mlp(self.norm(x))
199
+ return x[:, 0]
200
+
201
+
202
+ class VisionTransformer(nn.Module):
203
+ def __init__(
204
+ self,
205
+ image_size=224,
206
+ patch_size=16,
207
+ dim=768,
208
+ mlp_ratio=4,
209
+ out_dim=512,
210
+ num_heads=12,
211
+ num_layers=12,
212
+ pool_type="token",
213
+ pre_norm=True,
214
+ post_norm=False,
215
+ activation="quick_gelu",
216
+ attn_dropout=0.0,
217
+ proj_dropout=0.0,
218
+ embedding_dropout=0.0,
219
+ norm_eps=1e-5,
220
+ ):
221
+ if image_size % patch_size != 0:
222
+ print("[WARNING] image_size is not divisible by patch_size", flush=True)
223
+ assert pool_type in ("token", "token_fc", "attn_pool")
224
+ out_dim = out_dim or dim
225
+ super().__init__()
226
+ self.image_size = image_size
227
+ self.patch_size = patch_size
228
+ self.num_patches = (image_size // patch_size) ** 2
229
+ self.dim = dim
230
+ self.mlp_ratio = mlp_ratio
231
+ self.out_dim = out_dim
232
+ self.num_heads = num_heads
233
+ self.num_layers = num_layers
234
+ self.pool_type = pool_type
235
+ self.post_norm = post_norm
236
+ self.norm_eps = norm_eps
237
+
238
+ # embeddings
239
+ gain = 1.0 / math.sqrt(dim)
240
+ self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
241
+ if pool_type in ("token", "token_fc"):
242
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
243
+ self.pos_embedding = nn.Parameter(
244
+ gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
245
+ )
246
+ self.dropout = nn.Dropout(embedding_dropout)
247
+
248
+ # transformer
249
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
250
+ self.transformer = nn.Sequential(
251
+ *[
252
+ AttentionBlock(
253
+ dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps
254
+ )
255
+ for _ in range(num_layers)
256
+ ]
257
+ )
258
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
259
+
260
+ # head
261
+ if pool_type == "token":
262
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
263
+ elif pool_type == "token_fc":
264
+ self.head = nn.Linear(dim, out_dim)
265
+ elif pool_type == "attn_pool":
266
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
267
+
268
+ def forward(self, x, interpolation=False, use_31_block=False):
269
+ b = x.size(0)
270
+
271
+ # embeddings
272
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
273
+ if self.pool_type in ("token", "token_fc"):
274
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
275
+ if interpolation:
276
+ e = pos_interpolate(self.pos_embedding, x.size(1))
277
+ else:
278
+ e = self.pos_embedding
279
+ x = self.dropout(x + e)
280
+ if self.pre_norm is not None:
281
+ x = self.pre_norm(x)
282
+
283
+ # transformer
284
+ if use_31_block:
285
+ x = self.transformer[:-1](x)
286
+ return x
287
+ else:
288
+ x = self.transformer(x)
289
+ return x
290
+
291
+
292
+ class XLMRobertaWithHead(XLMRoberta):
293
+ def __init__(self, **kwargs):
294
+ self.out_dim = kwargs.pop("out_dim")
295
+ super().__init__(**kwargs)
296
+
297
+ # head
298
+ mid_dim = (self.dim + self.out_dim) // 2
299
+ self.head = nn.Sequential(
300
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)
301
+ )
302
+
303
+ def forward(self, ids):
304
+ # xlm-roberta
305
+ x = super().forward(ids)
306
+
307
+ # average pooling
308
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
309
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
310
+
311
+ # head
312
+ x = self.head(x)
313
+ return x
314
+
315
+
316
+ class XLMRobertaCLIP(nn.Module):
317
+ def __init__(
318
+ self,
319
+ embed_dim=1024,
320
+ image_size=224,
321
+ patch_size=14,
322
+ vision_dim=1280,
323
+ vision_mlp_ratio=4,
324
+ vision_heads=16,
325
+ vision_layers=32,
326
+ vision_pool="token",
327
+ vision_pre_norm=True,
328
+ vision_post_norm=False,
329
+ activation="gelu",
330
+ vocab_size=250002,
331
+ max_text_len=514,
332
+ type_size=1,
333
+ pad_id=1,
334
+ text_dim=1024,
335
+ text_heads=16,
336
+ text_layers=24,
337
+ text_post_norm=True,
338
+ text_dropout=0.1,
339
+ attn_dropout=0.0,
340
+ proj_dropout=0.0,
341
+ embedding_dropout=0.0,
342
+ norm_eps=1e-5,
343
+ ):
344
+ super().__init__()
345
+ self.embed_dim = embed_dim
346
+ self.image_size = image_size
347
+ self.patch_size = patch_size
348
+ self.vision_dim = vision_dim
349
+ self.vision_mlp_ratio = vision_mlp_ratio
350
+ self.vision_heads = vision_heads
351
+ self.vision_layers = vision_layers
352
+ self.vision_pre_norm = vision_pre_norm
353
+ self.vision_post_norm = vision_post_norm
354
+ self.activation = activation
355
+ self.vocab_size = vocab_size
356
+ self.max_text_len = max_text_len
357
+ self.type_size = type_size
358
+ self.pad_id = pad_id
359
+ self.text_dim = text_dim
360
+ self.text_heads = text_heads
361
+ self.text_layers = text_layers
362
+ self.text_post_norm = text_post_norm
363
+ self.norm_eps = norm_eps
364
+
365
+ # models
366
+ self.visual = VisionTransformer(
367
+ image_size=image_size,
368
+ patch_size=patch_size,
369
+ dim=vision_dim,
370
+ mlp_ratio=vision_mlp_ratio,
371
+ out_dim=embed_dim,
372
+ num_heads=vision_heads,
373
+ num_layers=vision_layers,
374
+ pool_type=vision_pool,
375
+ pre_norm=vision_pre_norm,
376
+ post_norm=vision_post_norm,
377
+ activation=activation,
378
+ attn_dropout=attn_dropout,
379
+ proj_dropout=proj_dropout,
380
+ embedding_dropout=embedding_dropout,
381
+ norm_eps=norm_eps,
382
+ )
383
+ self.textual = XLMRobertaWithHead(
384
+ vocab_size=vocab_size,
385
+ max_seq_len=max_text_len,
386
+ type_size=type_size,
387
+ pad_id=pad_id,
388
+ dim=text_dim,
389
+ out_dim=embed_dim,
390
+ num_heads=text_heads,
391
+ num_layers=text_layers,
392
+ post_norm=text_post_norm,
393
+ dropout=text_dropout,
394
+ )
395
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
396
+
397
+ def forward(self, imgs, txt_ids):
398
+ """
399
+ imgs: [B, 3, H, W] of torch.float32.
400
+ - mean: [0.48145466, 0.4578275, 0.40821073]
401
+ - std: [0.26862954, 0.26130258, 0.27577711]
402
+ txt_ids: [B, L] of torch.long.
403
+ Encoded by data.CLIPTokenizer.
404
+ """
405
+ xi = self.visual(imgs)
406
+ xt = self.textual(txt_ids)
407
+ return xi, xt
408
+
409
+ def param_groups(self):
410
+ groups = [
411
+ {
412
+ "params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
413
+ "weight_decay": 0.0,
414
+ },
415
+ {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
416
+ ]
417
+ return groups
418
+
419
+
420
+ def _clip(
421
+ pretrained=False,
422
+ pretrained_name=None,
423
+ model_cls=XLMRobertaCLIP,
424
+ return_transforms=False,
425
+ return_tokenizer=False,
426
+ tokenizer_padding="eos",
427
+ dtype=torch.float32,
428
+ device="cpu",
429
+ **kwargs,
430
+ ):
431
+ # init a model on device
432
+ with torch.device(device):
433
+ model = model_cls(**kwargs)
434
+
435
+ # set device
436
+ model = model.to(dtype=dtype, device=device)
437
+ output = (model,)
438
+
439
+ # init transforms
440
+ if return_transforms:
441
+ # mean and std
442
+ if "siglip" in pretrained_name.lower():
443
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
444
+ else:
445
+ mean = [0.48145466, 0.4578275, 0.40821073]
446
+ std = [0.26862954, 0.26130258, 0.27577711]
447
+
448
+ # transforms
449
+ transforms = T.Compose(
450
+ [
451
+ T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
452
+ T.ToTensor(),
453
+ T.Normalize(mean=mean, std=std),
454
+ ]
455
+ )
456
+ output += (transforms,)
457
+ return output[0] if len(output) == 1 else output
458
+
459
+
460
+ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
461
+ cfg = dict(
462
+ embed_dim=1024,
463
+ image_size=224,
464
+ patch_size=14,
465
+ vision_dim=1280,
466
+ vision_mlp_ratio=4,
467
+ vision_heads=16,
468
+ vision_layers=32,
469
+ vision_pool="token",
470
+ activation="gelu",
471
+ vocab_size=250002,
472
+ max_text_len=514,
473
+ type_size=1,
474
+ pad_id=1,
475
+ text_dim=1024,
476
+ text_heads=16,
477
+ text_layers=24,
478
+ text_post_norm=True,
479
+ text_dropout=0.1,
480
+ attn_dropout=0.0,
481
+ proj_dropout=0.0,
482
+ embedding_dropout=0.0,
483
+ )
484
+ cfg.update(**kwargs)
485
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
486
+
487
+
488
+ class CLIPModel(ModelMixin):
489
+ def __init__(self, checkpoint_path, tokenizer_path):
490
+ self.checkpoint_path = checkpoint_path
491
+ self.tokenizer_path = tokenizer_path
492
+
493
+ super().__init__()
494
+ # init model
495
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
496
+ pretrained=False, return_transforms=True, return_tokenizer=False
497
+ )
498
+ self.model = self.model.eval().requires_grad_(False)
499
+ logging.info(f"loading {checkpoint_path}")
500
+ self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
501
+
502
+ # init tokenizer
503
+ self.tokenizer = HuggingfaceTokenizer(
504
+ name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace"
505
+ )
506
+
507
+ def encode_video(self, video):
508
+ # preprocess
509
+ b, c, t, h, w = video.shape
510
+ video = video.transpose(1, 2)
511
+ video = video.reshape(b * t, c, h, w)
512
+ size = (self.model.image_size,) * 2
513
+ video = F.interpolate(
514
+ video,
515
+ size=size,
516
+ mode="bicubic",
517
+ align_corners=False)
518
+
519
+ video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
520
+
521
+ # forward
522
+ with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
523
+ out = self.model.visual(video, use_31_block=True)
524
+
525
+ return out
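
A small CPU-runnable sketch of the pos_interpolate() helper above (illustrative, not part of the diff): it resizes a learned positional table from a 16x16 patch grid plus one class token to a 32x32 grid, assuming the repo's requirements are installed.

import torch
from skyreels_v2_infer.modules.clip import pos_interpolate

# [1, cls + 16*16 patches, dim] -> [1, cls + 32*32 patches, dim]
pos = torch.randn(1, 1 + 16 * 16, 1280)
resized = pos_interpolate(pos, 1 + 32 * 32)
print(resized.shape)  # torch.Size([1, 1025, 1280])
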
skyreels_v2_infer/modules/t5.py ADDED
@@ -0,0 +1,454 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.models import ModelMixin
10
+
11
+ from .tokenizers import HuggingfaceTokenizer
12
+
13
+ __all__ = [
14
+ "T5Model",
15
+ "T5Encoder",
16
+ "T5Decoder",
17
+ "T5EncoderModel",
18
+ ]
19
+
20
+
21
+ def fp16_clamp(x):
22
+ if x.dtype == torch.float16 and torch.isinf(x).any():
23
+ clamp = torch.finfo(x.dtype).max - 1000
24
+ x = torch.clamp(x, min=-clamp, max=clamp)
25
+ return x
26
+
27
+
28
+ def init_weights(m):
29
+ if isinstance(m, T5LayerNorm):
30
+ nn.init.ones_(m.weight)
31
+ elif isinstance(m, T5Model):
32
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
33
+ elif isinstance(m, T5FeedForward):
34
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
36
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
37
+ elif isinstance(m, T5Attention):
38
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
39
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
41
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
42
+ elif isinstance(m, T5RelativeEmbedding):
43
+ nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+ def forward(self, x):
48
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
49
+
50
+
51
+ class T5LayerNorm(nn.Module):
52
+ def __init__(self, dim, eps=1e-6):
53
+ super(T5LayerNorm, self).__init__()
54
+ self.dim = dim
55
+ self.eps = eps
56
+ self.weight = nn.Parameter(torch.ones(dim))
57
+
58
+ def forward(self, x):
59
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
60
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
61
+ x = x.type_as(self.weight)
62
+ return self.weight * x
63
+
64
+
65
+ class T5Attention(nn.Module):
66
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
67
+ assert dim_attn % num_heads == 0
68
+ super(T5Attention, self).__init__()
69
+ self.dim = dim
70
+ self.dim_attn = dim_attn
71
+ self.num_heads = num_heads
72
+ self.head_dim = dim_attn // num_heads
73
+
74
+ # layers
75
+ self.q = nn.Linear(dim, dim_attn, bias=False)
76
+ self.k = nn.Linear(dim, dim_attn, bias=False)
77
+ self.v = nn.Linear(dim, dim_attn, bias=False)
78
+ self.o = nn.Linear(dim_attn, dim, bias=False)
79
+ self.dropout = nn.Dropout(dropout)
80
+
81
+ def forward(self, x, context=None, mask=None, pos_bias=None):
82
+ """
83
+ x: [B, L1, C].
84
+ context: [B, L2, C] or None.
85
+ mask: [B, L2] or [B, L1, L2] or None.
86
+ """
87
+ # check inputs
88
+ context = x if context is None else context
89
+ b, n, c = x.size(0), self.num_heads, self.head_dim
90
+
91
+ # compute query, key, value
92
+ q = self.q(x).view(b, -1, n, c)
93
+ k = self.k(context).view(b, -1, n, c)
94
+ v = self.v(context).view(b, -1, n, c)
95
+
96
+ # attention bias
97
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
98
+ if pos_bias is not None:
99
+ attn_bias += pos_bias
100
+ if mask is not None:
101
+ assert mask.ndim in [2, 3]
102
+ mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
103
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
104
+
105
+ # compute attention (T5 does not use scaling)
106
+ attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
107
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
108
+ x = torch.einsum("bnij,bjnc->binc", attn, v)
109
+
110
+ # output
111
+ x = x.reshape(b, -1, n * c)
112
+ x = self.o(x)
113
+ x = self.dropout(x)
114
+ return x
115
+
116
+
117
+ class T5FeedForward(nn.Module):
118
+ def __init__(self, dim, dim_ffn, dropout=0.1):
119
+ super(T5FeedForward, self).__init__()
120
+ self.dim = dim
121
+ self.dim_ffn = dim_ffn
122
+
123
+ # layers
124
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
125
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
126
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
127
+ self.dropout = nn.Dropout(dropout)
128
+
129
+ def forward(self, x):
130
+ x = self.fc1(x) * self.gate(x)
131
+ x = self.dropout(x)
132
+ x = self.fc2(x)
133
+ x = self.dropout(x)
134
+ return x
135
+
136
+
137
+ class T5SelfAttention(nn.Module):
138
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
139
+ super(T5SelfAttention, self).__init__()
140
+ self.dim = dim
141
+ self.dim_attn = dim_attn
142
+ self.dim_ffn = dim_ffn
143
+ self.num_heads = num_heads
144
+ self.num_buckets = num_buckets
145
+ self.shared_pos = shared_pos
146
+
147
+ # layers
148
+ self.norm1 = T5LayerNorm(dim)
149
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
150
+ self.norm2 = T5LayerNorm(dim)
151
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
152
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
153
+
154
+ def forward(self, x, mask=None, pos_bias=None):
155
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
156
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
157
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
158
+ return x
159
+
160
+
161
+ class T5CrossAttention(nn.Module):
162
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
163
+ super(T5CrossAttention, self).__init__()
164
+ self.dim = dim
165
+ self.dim_attn = dim_attn
166
+ self.dim_ffn = dim_ffn
167
+ self.num_heads = num_heads
168
+ self.num_buckets = num_buckets
169
+ self.shared_pos = shared_pos
170
+
171
+ # layers
172
+ self.norm1 = T5LayerNorm(dim)
173
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
174
+ self.norm2 = T5LayerNorm(dim)
175
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
176
+ self.norm3 = T5LayerNorm(dim)
177
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
178
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
179
+
180
+ def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
181
+ e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
182
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
183
+ x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
184
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
185
+ return x
186
+
187
+
188
+ class T5RelativeEmbedding(nn.Module):
189
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
190
+ super(T5RelativeEmbedding, self).__init__()
191
+ self.num_buckets = num_buckets
192
+ self.num_heads = num_heads
193
+ self.bidirectional = bidirectional
194
+ self.max_dist = max_dist
195
+
196
+ # layers
197
+ self.embedding = nn.Embedding(num_buckets, num_heads)
198
+
199
+ def forward(self, lq, lk):
200
+ device = self.embedding.weight.device
201
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
202
+ # torch.arange(lq).unsqueeze(1).to(device)
203
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
204
+ rel_pos = self._relative_position_bucket(rel_pos)
205
+ rel_pos_embeds = self.embedding(rel_pos)
206
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
207
+ return rel_pos_embeds.contiguous()
208
+
209
+ def _relative_position_bucket(self, rel_pos):
210
+ # preprocess
211
+ if self.bidirectional:
212
+ num_buckets = self.num_buckets // 2
213
+ rel_buckets = (rel_pos > 0).long() * num_buckets
214
+ rel_pos = torch.abs(rel_pos)
215
+ else:
216
+ num_buckets = self.num_buckets
217
+ rel_buckets = 0
218
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
219
+
220
+ # embeddings for small and large positions
221
+ max_exact = num_buckets // 2
222
+ rel_pos_large = (
223
+ max_exact
224
+ + (
225
+ torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
226
+ ).long()
227
+ )
228
+ rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
229
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
230
+ return rel_buckets
231
+
232
+
233
+ class T5Encoder(nn.Module):
234
+ def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
235
+ super(T5Encoder, self).__init__()
236
+ self.dim = dim
237
+ self.dim_attn = dim_attn
238
+ self.dim_ffn = dim_ffn
239
+ self.num_heads = num_heads
240
+ self.num_layers = num_layers
241
+ self.num_buckets = num_buckets
242
+ self.shared_pos = shared_pos
243
+
244
+ # layers
245
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
246
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
247
+ self.dropout = nn.Dropout(dropout)
248
+ self.blocks = nn.ModuleList(
249
+ [
250
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
251
+ for _ in range(num_layers)
252
+ ]
253
+ )
254
+ self.norm = T5LayerNorm(dim)
255
+
256
+ # initialize weights
257
+ self.apply(init_weights)
258
+
259
+ def forward(self, ids, mask=None):
260
+ x = self.token_embedding(ids)
261
+ x = self.dropout(x)
262
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
263
+ for block in self.blocks:
264
+ x = block(x, mask, pos_bias=e)
265
+ x = self.norm(x)
266
+ x = self.dropout(x)
267
+ return x
268
+
269
+
270
+ class T5Decoder(nn.Module):
271
+ def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
272
+ super(T5Decoder, self).__init__()
273
+ self.dim = dim
274
+ self.dim_attn = dim_attn
275
+ self.dim_ffn = dim_ffn
276
+ self.num_heads = num_heads
277
+ self.num_layers = num_layers
278
+ self.num_buckets = num_buckets
279
+ self.shared_pos = shared_pos
280
+
281
+ # layers
282
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
283
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
284
+ self.dropout = nn.Dropout(dropout)
285
+ self.blocks = nn.ModuleList(
286
+ [
287
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
288
+ for _ in range(num_layers)
289
+ ]
290
+ )
291
+ self.norm = T5LayerNorm(dim)
292
+
293
+ # initialize weights
294
+ self.apply(init_weights)
295
+
296
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
297
+ b, s = ids.size()
298
+
299
+ # causal mask
300
+ if mask is None:
301
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
302
+ elif mask.ndim == 2:
303
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
304
+
305
+ # layers
306
+ x = self.token_embedding(ids)
307
+ x = self.dropout(x)
308
+ e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
309
+ for block in self.blocks:
310
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
311
+ x = self.norm(x)
312
+ x = self.dropout(x)
313
+ return x
314
+
315
+
316
+ class T5Model(nn.Module):
317
+ def __init__(
318
+ self,
319
+ vocab_size,
320
+ dim,
321
+ dim_attn,
322
+ dim_ffn,
323
+ num_heads,
324
+ encoder_layers,
325
+ decoder_layers,
326
+ num_buckets,
327
+ shared_pos=True,
328
+ dropout=0.1,
329
+ ):
330
+ super(T5Model, self).__init__()
331
+ self.vocab_size = vocab_size
332
+ self.dim = dim
333
+ self.dim_attn = dim_attn
334
+ self.dim_ffn = dim_ffn
335
+ self.num_heads = num_heads
336
+ self.encoder_layers = encoder_layers
337
+ self.decoder_layers = decoder_layers
338
+ self.num_buckets = num_buckets
339
+
340
+ # layers
341
+ self.token_embedding = nn.Embedding(vocab_size, dim)
342
+ self.encoder = T5Encoder(
343
+ self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
344
+ )
345
+ self.decoder = T5Decoder(
346
+ self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
347
+ )
348
+ self.head = nn.Linear(dim, vocab_size, bias=False)
349
+
350
+ # initialize weights
351
+ self.apply(init_weights)
352
+
353
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
354
+ x = self.encoder(encoder_ids, encoder_mask)
355
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
356
+ x = self.head(x)
357
+ return x
358
+
359
+
360
+ def _t5(
361
+ name,
362
+ encoder_only=False,
363
+ decoder_only=False,
364
+ return_tokenizer=False,
365
+ tokenizer_kwargs={},
366
+ dtype=torch.float32,
367
+ device="cpu",
368
+ **kwargs,
369
+ ):
370
+ # sanity check
371
+ assert not (encoder_only and decoder_only)
372
+
373
+ # params
374
+ if encoder_only:
375
+ model_cls = T5Encoder
376
+ kwargs["vocab"] = kwargs.pop("vocab_size")
377
+ kwargs["num_layers"] = kwargs.pop("encoder_layers")
378
+ _ = kwargs.pop("decoder_layers")
379
+ elif decoder_only:
380
+ model_cls = T5Decoder
381
+ kwargs["vocab"] = kwargs.pop("vocab_size")
382
+ kwargs["num_layers"] = kwargs.pop("decoder_layers")
383
+ _ = kwargs.pop("encoder_layers")
384
+ else:
385
+ model_cls = T5Model
386
+
387
+ # init model
388
+ with torch.device(device):
389
+ model = model_cls(**kwargs)
390
+
391
+ # set device
392
+ model = model.to(dtype=dtype, device=device)
393
+
394
+ # init tokenizer
395
+ if return_tokenizer:
396
+ from .tokenizers import HuggingfaceTokenizer
397
+
398
+ tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
399
+ return model, tokenizer
400
+ else:
401
+ return model
402
+
403
+
404
+ def umt5_xxl(**kwargs):
405
+ cfg = dict(
406
+ vocab_size=256384,
407
+ dim=4096,
408
+ dim_attn=4096,
409
+ dim_ffn=10240,
410
+ num_heads=64,
411
+ encoder_layers=24,
412
+ decoder_layers=24,
413
+ num_buckets=32,
414
+ shared_pos=False,
415
+ dropout=0.1,
416
+ )
417
+ cfg.update(**kwargs)
418
+ return _t5("umt5-xxl", **cfg)
419
+
420
+
421
+ class T5EncoderModel(ModelMixin):
422
+ def __init__(
423
+ self,
424
+ checkpoint_path=None,
425
+ tokenizer_path=None,
426
+ text_len=512,
427
+ shard_fn=None,
428
+ ):
429
+ self.text_len = text_len
430
+ self.checkpoint_path = checkpoint_path
431
+ self.tokenizer_path = tokenizer_path
432
+
433
+ super().__init__()
434
+ # init model
435
+ model = umt5_xxl(encoder_only=True, return_tokenizer=False)
436
+ logging.info(f"loading {checkpoint_path}")
437
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
438
+ self.model = model
439
+ if shard_fn is not None:
440
+ self.model = shard_fn(self.model, sync_module_states=False)
441
+ else:
442
+ self.model.eval().requires_grad_(False)
443
+ # init tokenizer
444
+ self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
445
+
446
+ def encode(self, texts):
447
+ ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
448
+ ids = ids.to(self.device)
449
+ mask = mask.to(self.device)
450
+ # seq_lens = mask.gt(0).sum(dim=1).long()
451
+ context = self.model(ids, mask)
452
+ context = context * mask.unsqueeze(-1).cuda()
453
+
454
+ return context
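
A CPU-runnable sketch of the relative position bias produced by T5RelativeEmbedding above (illustrative, not part of the diff); the bucket and head counts mirror the umt5_xxl() config.

from skyreels_v2_infer.modules.t5 import T5RelativeEmbedding

rel = T5RelativeEmbedding(num_buckets=32, num_heads=64, bidirectional=True)
bias = rel(16, 16)  # bias for a 16-token self-attention window
print(bias.shape)   # torch.Size([1, 64, 16, 16])
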
skyreels_v2_infer/modules/tokenizers.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ["HuggingfaceTokenizer"]
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r"\s+", " ", text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace("_", " ")
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans("", "", string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string)
30
+ )
31
+ else:
32
+ text = text.translate(str.maketrans("", "", string.punctuation))
33
+ text = text.lower()
34
+ text = re.sub(r"\s+", " ", text)
35
+ return text.strip()
36
+
37
+
38
+ class HuggingfaceTokenizer:
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, "whitespace", "lower", "canonicalize")
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop("return_mask", False)
51
+
52
+ # arguments
53
+ _kwargs = {"return_tensors": "pt"}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
56
+ _kwargs.update(**kwargs)
57
+
58
+ # tokenization
59
+ if isinstance(sequence, str):
60
+ sequence = [sequence]
61
+ if self.clean:
62
+ sequence = [self._clean(u) for u in sequence]
63
+ ids = self.tokenizer(sequence, **_kwargs)
64
+
65
+ # output
66
+ if return_mask:
67
+ return ids.input_ids, ids.attention_mask
68
+ else:
69
+ return ids.input_ids
70
+
71
+ def _clean(self, text):
72
+ if self.clean == "whitespace":
73
+ text = whitespace_clean(basic_clean(text))
74
+ elif self.clean == "lower":
75
+ text = whitespace_clean(basic_clean(text)).lower()
76
+ elif self.clean == "canonicalize":
77
+ text = canonicalize(basic_clean(text))
78
+ return text
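
A quick sketch of the pure text-cleaning helpers above (illustrative, not part of the diff); no tokenizer download is needed for these functions.

from skyreels_v2_infer.modules.tokenizers import basic_clean, whitespace_clean, canonicalize

raw = "  A   cinematic  drone shot &amp; slow pan_over the coastline!  "
print(whitespace_clean(basic_clean(raw)))
# -> "A cinematic drone shot & slow pan_over the coastline!"
print(canonicalize(raw))
# -> "a cinematic drone shot amp slow pan over the coastline"
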
skyreels_v2_infer/modules/transformer.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.amp as amp
6
+ import torch.nn as nn
7
+ from diffusers.configuration_utils import ConfigMixin
8
+ from diffusers.configuration_utils import register_to_config
9
+ from diffusers.loaders import PeftAdapterMixin
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from torch.backends.cuda import sdp_kernel
12
+ from torch.nn.attention.flex_attention import BlockMask
13
+ from torch.nn.attention.flex_attention import create_block_mask
14
+ from torch.nn.attention.flex_attention import flex_attention
15
+
16
+ from .attention import flash_attention
17
+
18
+
19
+ flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
20
+
21
+ DISABLE_COMPILE = False  # TODO: optionally read this from an environment variable
22
+
23
+ __all__ = ["WanModel"]
24
+
25
+
26
+ def sinusoidal_embedding_1d(dim, position):
27
+ # preprocess
28
+ assert dim % 2 == 0
29
+ half = dim // 2
30
+ position = position.type(torch.float64)
31
+
32
+ # calculation
33
+ sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
34
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
35
+ return x
36
+
37
+
38
+ @amp.autocast("cuda", enabled=False)
39
+ def rope_params(max_seq_len, dim, theta=10000):
40
+ assert dim % 2 == 0
41
+ freqs = torch.outer(
42
+ torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
43
+ )
44
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
45
+ return freqs
46
+
47
+
48
+ @amp.autocast("cuda", enabled=False)
49
+ def rope_apply(x, grid_sizes, freqs):
50
+ n, c = x.size(2), x.size(3) // 2
51
+ bs = x.size(0)
52
+
53
+ # split freqs
54
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
55
+
56
+ # loop over samples
57
+ f, h, w = grid_sizes.tolist()
58
+ seq_len = f * h * w
59
+
60
+ # precompute multipliers
61
+
62
+ x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2))
63
+ freqs_i = torch.cat(
64
+ [
65
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
66
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
67
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
68
+ ],
69
+ dim=-1,
70
+ ).reshape(seq_len, 1, -1)
71
+
72
+ # apply rotary embedding
73
+ x = torch.view_as_real(x * freqs_i).flatten(3)
74
+
75
+ return x
76
+
77
+
78
+ @torch.compile(dynamic=True, disable=DISABLE_COMPILE)
79
+ def fast_rms_norm(x, weight, eps):
80
+ x = x.float()
81
+ x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
82
+ x = x.type_as(x) * weight
83
+ return x
84
+
85
+
86
+ class WanRMSNorm(nn.Module):
87
+ def __init__(self, dim, eps=1e-5):
88
+ super().__init__()
89
+ self.dim = dim
90
+ self.eps = eps
91
+ self.weight = nn.Parameter(torch.ones(dim))
92
+
93
+ def forward(self, x):
94
+ r"""
95
+ Args:
96
+ x(Tensor): Shape [B, L, C]
97
+ """
98
+ return fast_rms_norm(x, self.weight, self.eps)
99
+
100
+ def _norm(self, x):
101
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
102
+
103
+
104
+ class WanLayerNorm(nn.LayerNorm):
105
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
106
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
107
+
108
+ def forward(self, x):
109
+ r"""
110
+ Args:
111
+ x(Tensor): Shape [B, L, C]
112
+ """
113
+ return super().forward(x)
114
+
115
+
116
+ class WanSelfAttention(nn.Module):
117
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
118
+ assert dim % num_heads == 0
119
+ super().__init__()
120
+ self.dim = dim
121
+ self.num_heads = num_heads
122
+ self.head_dim = dim // num_heads
123
+ self.window_size = window_size
124
+ self.qk_norm = qk_norm
125
+ self.eps = eps
126
+
127
+ # layers
128
+ self.q = nn.Linear(dim, dim)
129
+ self.k = nn.Linear(dim, dim)
130
+ self.v = nn.Linear(dim, dim)
131
+ self.o = nn.Linear(dim, dim)
132
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
133
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
134
+
135
+ self._flag_ar_attention = False
136
+
137
+ def set_ar_attention(self):
138
+ self._flag_ar_attention = True
139
+
140
+ def forward(self, x, grid_sizes, freqs, block_mask):
141
+ r"""
142
+ Args:
143
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
144
+ block_mask(BlockMask): attention mask applied when autoregressive attention is enabled via set_ar_attention()
145
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
146
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
147
+ """
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+
150
+ # query, key, value function
151
+ def qkv_fn(x):
152
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
153
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
154
+ v = self.v(x).view(b, s, n, d)
155
+ return q, k, v
156
+
157
+ x = x.to(self.q.weight.dtype)
158
+ q, k, v = qkv_fn(x)
159
+
160
+ if not self._flag_ar_attention:
161
+ q = rope_apply(q, grid_sizes, freqs)
162
+ k = rope_apply(k, grid_sizes, freqs)
163
+ x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
164
+ else:
165
+ q = rope_apply(q, grid_sizes, freqs)
166
+ k = rope_apply(k, grid_sizes, freqs)
167
+ q = q.to(torch.bfloat16)
168
+ k = k.to(torch.bfloat16)
169
+ v = v.to(torch.bfloat16)
170
+
171
+ with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
172
+ x = (
173
+ torch.nn.functional.scaled_dot_product_attention(
174
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
175
+ )
176
+ .transpose(1, 2)
177
+ .contiguous()
178
+ )
179
+
180
+ # output
181
+ x = x.flatten(2)
182
+ x = self.o(x)
183
+ return x
184
+
185
+
186
+ class WanT2VCrossAttention(WanSelfAttention):
187
+ def forward(self, x, context):
188
+ r"""
189
+ Args:
190
+ x(Tensor): Shape [B, L1, C]
191
+ context(Tensor): Shape [B, L2, C]
192
+ context_lens(Tensor): Shape [B]
193
+ """
194
+ b, n, d = x.size(0), self.num_heads, self.head_dim
195
+
196
+ # compute query, key, value
197
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
198
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
199
+ v = self.v(context).view(b, -1, n, d)
200
+
201
+ # compute attention
202
+ x = flash_attention(q, k, v)
203
+
204
+ # output
205
+ x = x.flatten(2)
206
+ x = self.o(x)
207
+ return x
208
+
209
+
210
+ class WanI2VCrossAttention(WanSelfAttention):
211
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
212
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
213
+
214
+ self.k_img = nn.Linear(dim, dim)
215
+ self.v_img = nn.Linear(dim, dim)
216
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
217
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
218
+
219
+ def forward(self, x, context):
220
+ r"""
221
+ Args:
222
+ x(Tensor): Shape [B, L1, C]
223
+ context(Tensor): Shape [B, L2, C]
224
+ context_lens(Tensor): Shape [B]
225
+ """
226
+ context_img = context[:, :257]
227
+ context = context[:, 257:]
228
+ b, n, d = x.size(0), self.num_heads, self.head_dim
229
+
230
+ # compute query, key, value
231
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
232
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
233
+ v = self.v(context).view(b, -1, n, d)
234
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
235
+ v_img = self.v_img(context_img).view(b, -1, n, d)
236
+ img_x = flash_attention(q, k_img, v_img)
237
+ # compute attention
238
+ x = flash_attention(q, k, v)
239
+
240
+ # output
241
+ x = x.flatten(2)
242
+ img_x = img_x.flatten(2)
243
+ x = x + img_x
244
+ x = self.o(x)
245
+ return x
246
+
247
+
248
+ WAN_CROSSATTENTION_CLASSES = {
249
+ "t2v_cross_attn": WanT2VCrossAttention,
250
+ "i2v_cross_attn": WanI2VCrossAttention,
251
+ }
252
+
253
+
254
+ def mul_add(x, y, z):
255
+ return x.float() + y.float() * z.float()
256
+
257
+
258
+ def mul_add_add(x, y, z):
259
+ return x.float() * (1 + y) + z
260
+
261
+
262
+ mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE)
263
+ mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE)
264
+
265
+
266
+ class WanAttentionBlock(nn.Module):
267
+ def __init__(
268
+ self,
269
+ cross_attn_type,
270
+ dim,
271
+ ffn_dim,
272
+ num_heads,
273
+ window_size=(-1, -1),
274
+ qk_norm=True,
275
+ cross_attn_norm=False,
276
+ eps=1e-6,
277
+ ):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.ffn_dim = ffn_dim
281
+ self.num_heads = num_heads
282
+ self.window_size = window_size
283
+ self.qk_norm = qk_norm
284
+ self.cross_attn_norm = cross_attn_norm
285
+ self.eps = eps
286
+
287
+ # layers
288
+ self.norm1 = WanLayerNorm(dim, eps)
289
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
290
+ self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
291
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps)
292
+ self.norm2 = WanLayerNorm(dim, eps)
293
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
294
+
295
+ # modulation
296
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
297
+
298
+ def set_ar_attention(self):
299
+ self.self_attn.set_ar_attention()
300
+
301
+ def forward(
302
+ self,
303
+ x,
304
+ e,
305
+ grid_sizes,
306
+ freqs,
307
+ context,
308
+ block_mask,
309
+ ):
310
+ r"""
311
+ Args:
312
+ x(Tensor): Shape [B, L, C]
313
+ e(Tensor): Shape [B, 6, C]
314
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
315
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
316
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
317
+ """
318
+ if e.dim() == 3:
319
+ modulation = self.modulation # 1, 6, dim
320
+ with amp.autocast("cuda", dtype=torch.float32):
321
+ e = (modulation + e).chunk(6, dim=1)
322
+ elif e.dim() == 4:
323
+ modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
324
+ with amp.autocast("cuda", dtype=torch.float32):
325
+ e = (modulation + e).chunk(6, dim=1)
326
+ e = [ei.squeeze(1) for ei in e]
327
+
328
+ # self-attention
329
+ out = mul_add_add_compile(self.norm1(x), e[1], e[0])
330
+ y = self.self_attn(out, grid_sizes, freqs, block_mask)
331
+ with amp.autocast("cuda", dtype=torch.float32):
332
+ x = mul_add_compile(x, y, e[2])
333
+
334
+ # cross-attention & ffn function
335
+ def cross_attn_ffn(x, context, e):
336
+ dtype = context.dtype
337
+ x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
338
+ y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype))
339
+ with amp.autocast("cuda", dtype=torch.float32):
340
+ x = mul_add_compile(x, y, e[5])
341
+ return x
342
+
343
+ x = cross_attn_ffn(x, context, e)
344
+ return x.to(torch.bfloat16)
345
+
346
+
347
+ class Head(nn.Module):
348
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
349
+ super().__init__()
350
+ self.dim = dim
351
+ self.out_dim = out_dim
352
+ self.patch_size = patch_size
353
+ self.eps = eps
354
+
355
+ # layers
356
+ out_dim = math.prod(patch_size) * out_dim
357
+ self.norm = WanLayerNorm(dim, eps)
358
+ self.head = nn.Linear(dim, out_dim)
359
+
360
+ # modulation
361
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
362
+
363
+ def forward(self, x, e):
364
+ r"""
365
+ Args:
366
+ x(Tensor): Shape [B, L1, C]
367
+ e(Tensor): Shape [B, C]
368
+ """
369
+ with amp.autocast("cuda", dtype=torch.float32):
370
+ if e.dim() == 2:
371
+ modulation = self.modulation # 1, 2, dim
372
+ e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
373
+
374
+ elif e.dim() == 3:
375
+ modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim
376
+ e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
377
+ e = [ei.squeeze(1) for ei in e]
378
+ x = self.head(self.norm(x) * (1 + e[1]) + e[0])
379
+ return x
380
+
381
+
382
+ class MLPProj(torch.nn.Module):
383
+ def __init__(self, in_dim, out_dim):
384
+ super().__init__()
385
+
386
+ self.proj = torch.nn.Sequential(
387
+ torch.nn.LayerNorm(in_dim),
388
+ torch.nn.Linear(in_dim, in_dim),
389
+ torch.nn.GELU(),
390
+ torch.nn.Linear(in_dim, out_dim),
391
+ torch.nn.LayerNorm(out_dim),
392
+ )
393
+
394
+ def forward(self, image_embeds):
395
+ clip_extra_context_tokens = self.proj(image_embeds)
396
+ return clip_extra_context_tokens
397
+
398
+
399
+ class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
400
+ r"""
401
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
402
+ """
403
+
404
+ ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
405
+ _no_split_modules = ["WanAttentionBlock"]
406
+
407
+ _supports_gradient_checkpointing = True
408
+
409
+ @register_to_config
410
+ def __init__(
411
+ self,
412
+ model_type="t2v",
413
+ patch_size=(1, 2, 2),
414
+ text_len=512,
415
+ in_dim=16,
416
+ dim=2048,
417
+ ffn_dim=8192,
418
+ freq_dim=256,
419
+ text_dim=4096,
420
+ out_dim=16,
421
+ num_heads=16,
422
+ num_layers=32,
423
+ window_size=(-1, -1),
424
+ qk_norm=True,
425
+ cross_attn_norm=True,
426
+ inject_sample_info=False,
427
+ eps=1e-6,
428
+ ):
429
+ r"""
430
+ Initialize the diffusion model backbone.
431
+
432
+ Args:
433
+ model_type (`str`, *optional*, defaults to 't2v'):
434
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
435
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
436
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
437
+ text_len (`int`, *optional*, defaults to 512):
438
+ Fixed length for text embeddings
439
+ in_dim (`int`, *optional*, defaults to 16):
440
+ Input video channels (C_in)
441
+ dim (`int`, *optional*, defaults to 2048):
442
+ Hidden dimension of the transformer
443
+ ffn_dim (`int`, *optional*, defaults to 8192):
444
+ Intermediate dimension in feed-forward network
445
+ freq_dim (`int`, *optional*, defaults to 256):
446
+ Dimension for sinusoidal time embeddings
447
+ text_dim (`int`, *optional*, defaults to 4096):
448
+ Input dimension for text embeddings
449
+ out_dim (`int`, *optional*, defaults to 16):
450
+ Output video channels (C_out)
451
+ num_heads (`int`, *optional*, defaults to 16):
452
+ Number of attention heads
453
+ num_layers (`int`, *optional*, defaults to 32):
454
+ Number of transformer blocks
455
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
456
+ Window size for local attention (-1 indicates global attention)
457
+ qk_norm (`bool`, *optional*, defaults to True):
458
+ Enable query/key normalization
459
+ cross_attn_norm (`bool`, *optional*, defaults to False):
460
+ Enable cross-attention normalization
461
+ eps (`float`, *optional*, defaults to 1e-6):
462
+ Epsilon value for normalization layers
463
+ """
464
+
465
+ super().__init__()
466
+
467
+ assert model_type in ["t2v", "i2v"]
468
+ self.model_type = model_type
469
+
470
+ self.patch_size = patch_size
471
+ self.text_len = text_len
472
+ self.in_dim = in_dim
473
+ self.dim = dim
474
+ self.ffn_dim = ffn_dim
475
+ self.freq_dim = freq_dim
476
+ self.text_dim = text_dim
477
+ self.out_dim = out_dim
478
+ self.num_heads = num_heads
479
+ self.num_layers = num_layers
480
+ self.window_size = window_size
481
+ self.qk_norm = qk_norm
482
+ self.cross_attn_norm = cross_attn_norm
483
+ self.eps = eps
484
+ self.num_frame_per_block = 1
485
+ self.flag_causal_attention = False
486
+ self.block_mask = None
487
+ self.enable_teacache = False
488
+
489
+ # embeddings
490
+ self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
491
+ self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
492
+
493
+ self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
494
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
495
+
496
+ if inject_sample_info:
497
+ self.fps_embedding = nn.Embedding(2, dim)
498
+ self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
499
+
500
+ # blocks
501
+ cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
502
+ self.blocks = nn.ModuleList(
503
+ [
504
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
505
+ for _ in range(num_layers)
506
+ ]
507
+ )
508
+
509
+ # head
510
+ self.head = Head(dim, out_dim, patch_size, eps)
511
+
512
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
513
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
514
+ d = dim // num_heads
515
+ self.freqs = torch.cat(
516
+ [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))],
517
+ dim=1,
518
+ )
519
+
520
+ if model_type == "i2v":
521
+ self.img_emb = MLPProj(1280, dim)
522
+
523
+ self.gradient_checkpointing = False
524
+
525
+ self.cpu_offloading = False
526
+
527
+ self.inject_sample_info = inject_sample_info
528
+ # initialize weights
529
+ self.init_weights()
530
+
531
+ def _set_gradient_checkpointing(self, module, value=False):
532
+ self.gradient_checkpointing = value
533
+
534
+ def zero_init_i2v_cross_attn(self):
535
+ print("zero init i2v cross attn")
536
+ for i in range(self.num_layers):
537
+ self.blocks[i].cross_attn.v_img.weight.data.zero_()
538
+ self.blocks[i].cross_attn.v_img.bias.data.zero_()
539
+
540
+ @staticmethod
541
+ def _prepare_blockwise_causal_attn_mask(
542
+ device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1
543
+ ) -> BlockMask:
544
+ """
545
+ We divide the token sequence into the following format:
546
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
547
+ We use flex_attention's create_block_mask to construct the block-wise causal attention mask.
548
+ """
549
+ total_length = num_frames * frame_seqlen
550
+
551
+ # we do right padding to get to a multiple of 128
552
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
553
+
554
+ ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
555
+
556
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
557
+ frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device)
558
+
559
+ for tmp in frame_indices:
560
+ ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block
561
+
562
+ def attention_mask(b, h, q_idx, kv_idx):
563
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
564
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
565
+
566
+ block_mask = create_block_mask(
567
+ attention_mask,
568
+ B=None,
569
+ H=None,
570
+ Q_LEN=total_length + padded_length,
571
+ KV_LEN=total_length + padded_length,
572
+ _compile=False,
573
+ device=device,
574
+ )
575
+
576
+ return block_mask
577
+
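The mask_mod above is easiest to see on a toy sequence. The following minimal sketch (not part of the committed file; the sizes are made up) builds the same block-wise causal pattern as a dense boolean matrix, which is the pattern create_block_mask compiles for FlexAttention:

import torch

# Toy sizes for illustration: 3 latent frames, 2 tokens per frame, 1 frame per causal block.
num_frames, frame_seqlen, num_frame_per_block = 3, 2, 1
total = num_frames * frame_seqlen

# Same "ends" construction as above, minus the right padding to a multiple of 128.
ends = torch.zeros(total, dtype=torch.long)
for start in range(0, total, frame_seqlen * num_frame_per_block):
    ends[start : start + frame_seqlen * num_frame_per_block] = start + frame_seqlen * num_frame_per_block

q_idx = torch.arange(total).unsqueeze(1)   # query positions as a column
kv_idx = torch.arange(total).unsqueeze(0)  # key/value positions as a row
mask = (kv_idx < ends.unsqueeze(1)) | (q_idx == kv_idx)
print(mask.int())  # each token attends to everything up to the end of its own frame block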
578
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
579
+ self.enable_teacache = enable_teacache
580
+ print('using teacache')
581
+ self.cnt = 0
582
+ self.num_steps = num_steps
583
+ self.teacache_thresh = teacache_thresh
584
+ self.accumulated_rel_l1_distance_even = 0
585
+ self.accumulated_rel_l1_distance_odd = 0
586
+ self.previous_e0_even = None
587
+ self.previous_e0_odd = None
588
+ self.previous_residual_even = None
589
+ self.previous_residual_odd = None
590
+ self.use_ref_steps = use_ret_steps
591
+ if "I2V" in ckpt_dir:
592
+ if use_ret_steps:
593
+ if '540P' in ckpt_dir:
594
+ self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
595
+ if '720P' in ckpt_dir:
596
+ self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
597
+ self.ret_steps = 5*2
598
+ self.cutoff_steps = num_steps*2
599
+ else:
600
+ if '540P' in ckpt_dir:
601
+ self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
602
+ if '720P' in ckpt_dir:
603
+ self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
604
+ self.ret_steps = 1*2
605
+ self.cutoff_steps = num_steps*2 - 2
606
+ else:
607
+ if use_ret_steps:
608
+ if '1.3B' in ckpt_dir:
609
+ self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
610
+ if '14B' in ckpt_dir:
611
+ self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
612
+ self.ret_steps = 5*2
613
+ self.cutoff_steps = num_steps*2
614
+ else:
615
+ if '1.3B' in ckpt_dir:
616
+ self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
617
+ if '14B' in ckpt_dir:
618
+ self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
619
+ self.ret_steps = 1*2
620
+ self.cutoff_steps = num_steps*2 - 2
621
+
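At inference time the TeaCache bookkeeping initialized above reduces to a small skip rule inside forward(): the relative L1 change of the modulation signal is rescaled by the fitted polynomial, accumulated, and the transformer blocks are skipped while the accumulated value stays below teacache_thresh. A standalone sketch of that rule (made-up coefficients and tensor shapes, illustrative only):

import numpy as np
import torch

coefficients = [0.0, 0.0, 0.0, 1.0, 0.0]     # made-up; the real values are fitted per checkpoint
rescale = np.poly1d(coefficients)
teacache_thresh, accumulated = 0.15, 0.0

prev_e0 = torch.randn(1, 6, 1536)            # modulation input from the previous step (toy shape)
curr_e0 = prev_e0 + 0.01 * torch.randn_like(prev_e0)

rel_l1 = ((curr_e0 - prev_e0).abs().mean() / prev_e0.abs().mean()).item()
accumulated += rescale(rel_l1)

if accumulated < teacache_thresh:
    should_calc = False                       # reuse the cached residual: x = x + previous_residual
else:
    should_calc, accumulated = True, 0.0      # run all blocks and refresh the cached residual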
622
+ def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
623
+ r"""
624
+ Forward pass through the diffusion model
625
+
626
+ Args:
627
+ x (List[Tensor]):
628
+ List of input video tensors, each with shape [C_in, F, H, W]
629
+ t (Tensor):
630
+ Diffusion timesteps tensor of shape [B]
631
+ context (List[Tensor]):
632
+ List of text embeddings each with shape [L, C]
633
+ fps (`int` or List[int], *optional*):
634
+ Frame-rate conditioning label(s), used only when `inject_sample_info` is enabled
635
+ clip_fea (Tensor, *optional*):
636
+ CLIP image features for image-to-video mode
637
+ y (List[Tensor], *optional*):
638
+ Conditional video inputs for image-to-video mode, same shape as x
639
+
640
+ Returns:
641
+ List[Tensor]:
642
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
643
+ """
644
+ if self.model_type == "i2v":
645
+ assert clip_fea is not None and y is not None
646
+ # params
647
+ device = self.patch_embedding.weight.device
648
+ if self.freqs.device != device:
649
+ self.freqs = self.freqs.to(device)
650
+
651
+ if y is not None:
652
+ x = torch.cat([x, y], dim=1)
653
+
654
+ # embeddings
655
+ x = self.patch_embedding(x)
656
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
657
+ x = x.flatten(2).transpose(1, 2)
658
+
659
+ if self.flag_causal_attention:
660
+ frame_num = grid_sizes[0]
661
+ height = grid_sizes[1]
662
+ width = grid_sizes[2]
663
+ block_num = frame_num // self.num_frame_per_block
664
+ range_tensor = torch.arange(block_num).view(-1, 1)
665
+ range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
666
+ casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
667
+ casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
668
+ casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
669
+ casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
670
+ self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
671
+
672
+ # time embeddings
673
+ with amp.autocast("cuda", dtype=torch.float32):
674
+ if t.dim() == 2:
675
+ b, f = t.shape
676
+ _flag_df = True
677
+ else:
678
+ _flag_df = False
679
+
680
+ e = self.time_embedding(
681
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
682
+ ) # b, dim
683
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
684
+
685
+ if self.inject_sample_info:
686
+ fps = torch.tensor(fps, dtype=torch.long, device=device)
687
+
688
+ fps_emb = self.fps_embedding(fps).float()
689
+ if _flag_df:
690
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
691
+ else:
692
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
693
+
694
+ if _flag_df:
695
+ e = e.view(b, f, 1, 1, self.dim)
696
+ e0 = e0.view(b, f, 1, 1, 6, self.dim)
697
+ e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
698
+ e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
699
+ e0 = e0.transpose(1, 2).contiguous()
700
+
701
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
702
+
703
+ # context
704
+ context = self.text_embedding(context)
705
+
706
+ if clip_fea is not None:
707
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
708
+ context = torch.concat([context_clip, context], dim=1)
709
+
710
+ # arguments
711
+ kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
712
+ if self.enable_teacache:
713
+ modulated_inp = e0 if self.use_ref_steps else e
714
+ # teacache
715
+ if self.cnt % 2 == 0:  # even -> conditional pass
716
+ self.is_even = True
717
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
718
+ should_calc_even = True
719
+ self.accumulated_rel_l1_distance_even = 0
720
+ else:
721
+ rescale_func = np.poly1d(self.coefficients)
722
+ self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
723
+ if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
724
+ should_calc_even = False
725
+ else:
726
+ should_calc_even = True
727
+ self.accumulated_rel_l1_distance_even = 0
728
+ self.previous_e0_even = modulated_inp.clone()
729
+
730
+ else:  # odd -> unconditional pass
731
+ self.is_even = False
732
+ if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
733
+ should_calc_odd = True
734
+ self.accumulated_rel_l1_distance_odd = 0
735
+ else:
736
+ rescale_func = np.poly1d(self.coefficients)
737
+ self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
738
+ if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
739
+ should_calc_odd = False
740
+ else:
741
+ should_calc_odd = True
742
+ self.accumulated_rel_l1_distance_odd = 0
743
+ self.previous_e0_odd = modulated_inp.clone()
744
+
745
+ if self.enable_teacache:
746
+ if self.is_even:
747
+ if not should_calc_even:
748
+ x += self.previous_residual_even
749
+ else:
750
+ ori_x = x.clone()
751
+ for block in self.blocks:
752
+ x = block(x, **kwargs)
753
+ self.previous_residual_even = x - ori_x
754
+ else:
755
+ if not should_calc_odd:
756
+ x += self.previous_residual_odd
757
+ else:
758
+ ori_x = x.clone()
759
+ for block in self.blocks:
760
+ x = block(x, **kwargs)
761
+ self.previous_residual_odd = x - ori_x
762
+
763
+ self.cnt += 1
764
+ if self.cnt >= self.num_steps:
765
+ self.cnt = 0
766
+ else:
767
+ for block in self.blocks:
768
+ x = block(x, **kwargs)
769
+
770
+ x = self.head(x, e)
771
+
772
+ # unpatchify
773
+ x = self.unpatchify(x, grid_sizes)
774
+
775
+ return x.float()
776
+
777
+ def unpatchify(self, x, grid_sizes):
778
+ r"""
779
+ Reconstruct video tensors from patch embeddings.
780
+
781
+ Args:
782
+ x (List[Tensor]):
783
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
784
+ grid_sizes (Tensor):
785
+ Original spatial-temporal grid dimensions before patching,
786
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
787
+
788
+ Returns:
789
+ List[Tensor]:
790
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
791
+ """
792
+
793
+ c = self.out_dim
794
+ bs = x.shape[0]
795
+ x = x.view(bs, *grid_sizes, *self.patch_size, c)
796
+ x = torch.einsum("bfhwpqrc->bcfphqwr", x)
797
+ x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
798
+
799
+ return x
800
+
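The einsum in unpatchify is pure shape bookkeeping; the toy check below (illustrative sizes only, not part of the committed file) mirrors it and confirms the result is [B, C_out, F*p_t, H*p_h, W*p_w]:

import torch

bs, c = 1, 16                      # toy batch size and output channels
grid_sizes = (3, 4, 4)             # (F, H, W) in patches
patch_size = (1, 2, 2)             # (p_t, p_h, p_w)

num_tokens = grid_sizes[0] * grid_sizes[1] * grid_sizes[2]
x = torch.randn(bs, num_tokens, c * patch_size[0] * patch_size[1] * patch_size[2])

x = x.view(bs, *grid_sizes, *patch_size, c)
x = torch.einsum("bfhwpqrc->bcfphqwr", x)
x = x.reshape(bs, c, *[g * p for g, p in zip(grid_sizes, patch_size)])
print(x.shape)                     # torch.Size([1, 16, 3, 8, 8])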
801
+ def set_ar_attention(self, causal_block_size):
802
+ self.num_frame_per_block = causal_block_size
803
+ self.flag_causal_attention = True
804
+ for block in self.blocks:
805
+ block.set_ar_attention()
806
+
807
+ def init_weights(self):
808
+ r"""
809
+ Initialize model parameters using Xavier initialization.
810
+ """
811
+
812
+ # basic init
813
+ for m in self.modules():
814
+ if isinstance(m, nn.Linear):
815
+ nn.init.xavier_uniform_(m.weight)
816
+ if m.bias is not None:
817
+ nn.init.zeros_(m.bias)
818
+
819
+ # init embeddings
820
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
821
+ for m in self.text_embedding.modules():
822
+ if isinstance(m, nn.Linear):
823
+ nn.init.normal_(m.weight, std=0.02)
824
+ for m in self.time_embedding.modules():
825
+ if isinstance(m, nn.Linear):
826
+ nn.init.normal_(m.weight, std=0.02)
827
+
828
+ if self.inject_sample_info:
829
+ nn.init.normal_(self.fps_embedding.weight, std=0.02)
830
+
831
+ for m in self.fps_projection.modules():
832
+ if isinstance(m, nn.Linear):
833
+ nn.init.normal_(m.weight, std=0.02)
834
+
835
+ nn.init.zeros_(self.fps_projection[-1].weight)
836
+ nn.init.zeros_(self.fps_projection[-1].bias)
837
+
838
+ # init output layer
839
+ nn.init.zeros_(self.head.head.weight)
skyreels_v2_infer/modules/vae.py ADDED
@@ -0,0 +1,639 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
+ __all__ = [
11
+ "WanVAE",
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
25
+ self.padding = (0, 0, 0)
26
+
27
+ def forward(self, x, cache_x=None):
28
+ padding = list(self._padding)
29
+ if cache_x is not None and self._padding[4] > 0:
30
+ cache_x = cache_x.to(x.device)
31
+ x = torch.cat([cache_x, x], dim=2)
32
+ padding[4] -= cache_x.shape[2]
33
+ x = F.pad(x, padding)
34
+
35
+ return super().forward(x)
36
+
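CausalConv3d pads only on the left of the time axis, so an output frame never depends on future input frames. A quick sanity check, assuming the class defined above is in scope (toy tensor sizes, illustrative only):

import torch

conv = CausalConv3d(1, 1, kernel_size=3, padding=1)

x = torch.randn(1, 1, 8, 4, 4)          # (B, C, T, H, W)
y_full = conv(x)
x_future_zeroed = x.clone()
x_future_zeroed[:, :, 5:] = 0.0          # perturb only frames t >= 5
y_perturbed = conv(x_future_zeroed)

# Frames t <= 4 are unchanged: each output frame sees only past and present inputs.
print(torch.allclose(y_full[:, :, :5], y_perturbed[:, :, :5]))   # True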
37
+
38
+ class RMS_norm(nn.Module):
39
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
40
+ super().__init__()
41
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
42
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
43
+
44
+ self.channel_first = channel_first
45
+ self.scale = dim**0.5
46
+ self.gamma = nn.Parameter(torch.ones(shape))
47
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
48
+
49
+ def forward(self, x):
50
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
51
+
52
+
53
+ class Upsample(nn.Upsample):
54
+ def forward(self, x):
55
+ """
56
+ Fix bfloat16 support for nearest neighbor interpolation.
57
+ """
58
+ return super().forward(x.float()).type_as(x)
59
+
60
+
61
+ class Resample(nn.Module):
62
+ def __init__(self, dim, mode):
63
+ assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.mode = mode
67
+
68
+ # layers
69
+ if mode == "upsample2d":
70
+ self.resample = nn.Sequential(
71
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
72
+ )
73
+ elif mode == "upsample3d":
74
+ self.resample = nn.Sequential(
75
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
76
+ )
77
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
78
+
79
+ elif mode == "downsample2d":
80
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
81
+ elif mode == "downsample3d":
82
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
83
+ self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
84
+
85
+ else:
86
+ self.resample = nn.Identity()
87
+
88
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
89
+ b, c, t, h, w = x.size()
90
+ if self.mode == "upsample3d":
91
+ if feat_cache is not None:
92
+ idx = feat_idx[0]
93
+ if feat_cache[idx] is None:
94
+ feat_cache[idx] = "Rep"
95
+ feat_idx[0] += 1
96
+ else:
97
+
98
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
99
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
100
+ # cache the last frame of each of the last two chunks
101
+ cache_x = torch.cat(
102
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
103
+ )
104
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
105
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
106
+ if feat_cache[idx] == "Rep":
107
+ x = self.time_conv(x)
108
+ else:
109
+ x = self.time_conv(x, feat_cache[idx])
110
+ feat_cache[idx] = cache_x
111
+ feat_idx[0] += 1
112
+
113
+ x = x.reshape(b, 2, c, t, h, w)
114
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
115
+ x = x.reshape(b, c, t * 2, h, w)
116
+ t = x.shape[2]
117
+ x = rearrange(x, "b c t h w -> (b t) c h w")
118
+ x = self.resample(x)
119
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
120
+
121
+ if self.mode == "downsample3d":
122
+ if feat_cache is not None:
123
+ idx = feat_idx[0]
124
+ if feat_cache[idx] is None:
125
+ feat_cache[idx] = x.clone()
126
+ feat_idx[0] += 1
127
+ else:
128
+
129
+ cache_x = x[:, :, -1:, :, :].clone()
130
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
131
+ # # cache last frame of last two chunk
132
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
133
+
134
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
135
+ feat_cache[idx] = cache_x
136
+ feat_idx[0] += 1
137
+ return x
138
+
139
+ def init_weight(self, conv):
140
+ conv_weight = conv.weight
141
+ nn.init.zeros_(conv_weight)
142
+ c1, c2, t, h, w = conv_weight.size()
143
+ one_matrix = torch.eye(c1, c2)
144
+ init_matrix = one_matrix
145
+ nn.init.zeros_(conv_weight)
146
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
147
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
148
+ conv.weight.data.copy_(conv_weight)
149
+ nn.init.zeros_(conv.bias.data)
150
+
151
+ def init_weight2(self, conv):
152
+ conv_weight = conv.weight.data
153
+ nn.init.zeros_(conv_weight)
154
+ c1, c2, t, h, w = conv_weight.size()
155
+ init_matrix = torch.eye(c1 // 2, c2)
156
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
157
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
158
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
159
+ conv.weight.data.copy_(conv_weight)
160
+ nn.init.zeros_(conv.bias.data)
161
+
162
+
163
+ class ResidualBlock(nn.Module):
164
+ def __init__(self, in_dim, out_dim, dropout=0.0):
165
+ super().__init__()
166
+ self.in_dim = in_dim
167
+ self.out_dim = out_dim
168
+
169
+ # layers
170
+ self.residual = nn.Sequential(
171
+ RMS_norm(in_dim, images=False),
172
+ nn.SiLU(),
173
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
174
+ RMS_norm(out_dim, images=False),
175
+ nn.SiLU(),
176
+ nn.Dropout(dropout),
177
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
178
+ )
179
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
180
+
181
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
182
+ h = self.shortcut(x)
183
+ for layer in self.residual:
184
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
185
+ idx = feat_idx[0]
186
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
187
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
188
+ # cache the last frame of each of the last two chunks
189
+ cache_x = torch.cat(
190
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
191
+ )
192
+ x = layer(x, feat_cache[idx])
193
+ feat_cache[idx] = cache_x
194
+ feat_idx[0] += 1
195
+ else:
196
+ x = layer(x)
197
+ return x + h
198
+
199
+
200
+ class AttentionBlock(nn.Module):
201
+ """
202
+ Causal self-attention with a single head.
203
+ """
204
+
205
+ def __init__(self, dim):
206
+ super().__init__()
207
+ self.dim = dim
208
+
209
+ # layers
210
+ self.norm = RMS_norm(dim)
211
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
212
+ self.proj = nn.Conv2d(dim, dim, 1)
213
+
214
+ # zero out the last layer params
215
+ nn.init.zeros_(self.proj.weight)
216
+
217
+ def forward(self, x):
218
+ identity = x
219
+ b, c, t, h, w = x.size()
220
+ x = rearrange(x, "b c t h w -> (b t) c h w")
221
+ x = self.norm(x)
222
+ # compute query, key, value
223
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
224
+
225
+ # apply attention
226
+ x = F.scaled_dot_product_attention(
227
+ q,
228
+ k,
229
+ v,
230
+ )
231
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
232
+
233
+ # output
234
+ x = self.proj(x)
235
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
236
+ return x + identity
237
+
238
+
239
+ class Encoder3d(nn.Module):
240
+ def __init__(
241
+ self,
242
+ dim=128,
243
+ z_dim=4,
244
+ dim_mult=[1, 2, 4, 4],
245
+ num_res_blocks=2,
246
+ attn_scales=[],
247
+ temperal_downsample=[True, True, False],
248
+ dropout=0.0,
249
+ ):
250
+ super().__init__()
251
+ self.dim = dim
252
+ self.z_dim = z_dim
253
+ self.dim_mult = dim_mult
254
+ self.num_res_blocks = num_res_blocks
255
+ self.attn_scales = attn_scales
256
+ self.temperal_downsample = temperal_downsample
257
+
258
+ # dimensions
259
+ dims = [dim * u for u in [1] + dim_mult]
260
+ scale = 1.0
261
+
262
+ # init block
263
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
264
+
265
+ # downsample blocks
266
+ downsamples = []
267
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
268
+ # residual (+attention) blocks
269
+ for _ in range(num_res_blocks):
270
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
271
+ if scale in attn_scales:
272
+ downsamples.append(AttentionBlock(out_dim))
273
+ in_dim = out_dim
274
+
275
+ # downsample block
276
+ if i != len(dim_mult) - 1:
277
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
278
+ downsamples.append(Resample(out_dim, mode=mode))
279
+ scale /= 2.0
280
+ self.downsamples = nn.Sequential(*downsamples)
281
+
282
+ # middle blocks
283
+ self.middle = nn.Sequential(
284
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
285
+ )
286
+
287
+ # output blocks
288
+ self.head = nn.Sequential(
289
+ RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)
290
+ )
291
+
292
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
293
+ if feat_cache is not None:
294
+ idx = feat_idx[0]
295
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
296
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
297
+ # cache the last frame of each of the last two chunks
298
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
299
+ x = self.conv1(x, feat_cache[idx])
300
+ feat_cache[idx] = cache_x
301
+ feat_idx[0] += 1
302
+ else:
303
+ x = self.conv1(x)
304
+
305
+ ## downsamples
306
+ for layer in self.downsamples:
307
+ if feat_cache is not None:
308
+ x = layer(x, feat_cache, feat_idx)
309
+ else:
310
+ x = layer(x)
311
+
312
+ ## middle
313
+ for layer in self.middle:
314
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
315
+ x = layer(x, feat_cache, feat_idx)
316
+ else:
317
+ x = layer(x)
318
+
319
+ ## head
320
+ for layer in self.head:
321
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
322
+ idx = feat_idx[0]
323
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
324
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
325
+ # cache the last frame of each of the last two chunks
326
+ cache_x = torch.cat(
327
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
328
+ )
329
+ x = layer(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = layer(x)
334
+ return x
335
+
336
+
337
+ class Decoder3d(nn.Module):
338
+ def __init__(
339
+ self,
340
+ dim=128,
341
+ z_dim=4,
342
+ dim_mult=[1, 2, 4, 4],
343
+ num_res_blocks=2,
344
+ attn_scales=[],
345
+ temperal_upsample=[False, True, True],
346
+ dropout=0.0,
347
+ ):
348
+ super().__init__()
349
+ self.dim = dim
350
+ self.z_dim = z_dim
351
+ self.dim_mult = dim_mult
352
+ self.num_res_blocks = num_res_blocks
353
+ self.attn_scales = attn_scales
354
+ self.temperal_upsample = temperal_upsample
355
+
356
+ # dimensions
357
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
358
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
359
+
360
+ # init block
361
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
362
+
363
+ # middle blocks
364
+ self.middle = nn.Sequential(
365
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
366
+ )
367
+
368
+ # upsample blocks
369
+ upsamples = []
370
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
371
+ # residual (+attention) blocks
372
+ if i == 1 or i == 2 or i == 3:
373
+ in_dim = in_dim // 2
374
+ for _ in range(num_res_blocks + 1):
375
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
376
+ if scale in attn_scales:
377
+ upsamples.append(AttentionBlock(out_dim))
378
+ in_dim = out_dim
379
+
380
+ # upsample block
381
+ if i != len(dim_mult) - 1:
382
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
383
+ upsamples.append(Resample(out_dim, mode=mode))
384
+ scale *= 2.0
385
+ self.upsamples = nn.Sequential(*upsamples)
386
+
387
+ # output blocks
388
+ self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
389
+
390
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
391
+ ## conv1
392
+ if feat_cache is not None:
393
+ idx = feat_idx[0]
394
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
395
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
396
+ # cache the last frame of each of the last two chunks
397
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
398
+ x = self.conv1(x, feat_cache[idx])
399
+ feat_cache[idx] = cache_x
400
+ feat_idx[0] += 1
401
+ else:
402
+ x = self.conv1(x)
403
+
404
+ ## middle
405
+ for layer in self.middle:
406
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
407
+ x = layer(x, feat_cache, feat_idx)
408
+ else:
409
+ x = layer(x)
410
+
411
+ ## upsamples
412
+ for layer in self.upsamples:
413
+ if feat_cache is not None:
414
+ x = layer(x, feat_cache, feat_idx)
415
+ else:
416
+ x = layer(x)
417
+
418
+ ## head
419
+ for layer in self.head:
420
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
421
+ idx = feat_idx[0]
422
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
423
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
424
+ # cache the last frame of each of the last two chunks
425
+ cache_x = torch.cat(
426
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
427
+ )
428
+ x = layer(x, feat_cache[idx])
429
+ feat_cache[idx] = cache_x
430
+ feat_idx[0] += 1
431
+ else:
432
+ x = layer(x)
433
+ return x
434
+
435
+
436
+ def count_conv3d(model):
437
+ count = 0
438
+ for m in model.modules():
439
+ if isinstance(m, CausalConv3d):
440
+ count += 1
441
+ return count
442
+
443
+
444
+ class WanVAE_(nn.Module):
445
+ def __init__(
446
+ self,
447
+ dim=128,
448
+ z_dim=4,
449
+ dim_mult=[1, 2, 4, 4],
450
+ num_res_blocks=2,
451
+ attn_scales=[],
452
+ temperal_downsample=[True, True, False],
453
+ dropout=0.0,
454
+ ):
455
+ super().__init__()
456
+ self.dim = dim
457
+ self.z_dim = z_dim
458
+ self.dim_mult = dim_mult
459
+ self.num_res_blocks = num_res_blocks
460
+ self.attn_scales = attn_scales
461
+ self.temperal_downsample = temperal_downsample
462
+ self.temperal_upsample = temperal_downsample[::-1]
463
+
464
+ # modules
465
+ self.encoder = Encoder3d(
466
+ dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
467
+ )
468
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
469
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
470
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
471
+
472
+ def forward(self, x):
473
+ mu, log_var = self.encode(x)
474
+ z = self.reparameterize(mu, log_var)
475
+ x_recon = self.decode(z)
476
+ return x_recon, mu, log_var
477
+
478
+ def encode(self, x, scale):
479
+ self.clear_cache()
480
+ ## cache
481
+ t = x.shape[2]
482
+ iter_ = 1 + (t - 1) // 4
483
+ ## split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ...
484
+ for i in range(iter_):
485
+ self._enc_conv_idx = [0]
486
+ if i == 0:
487
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
488
+ else:
489
+ out_ = self.encoder(
490
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
491
+ feat_cache=self._enc_feat_map,
492
+ feat_idx=self._enc_conv_idx,
493
+ )
494
+ out = torch.cat([out, out_], 2)
495
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
496
+ if isinstance(scale[0], torch.Tensor):
497
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
498
+ else:
499
+ mu = (mu - scale[0]) * scale[1]
500
+ self.clear_cache()
501
+ return mu
502
+
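The chunking arithmetic in encode() is compact but easy to misread: the first encoder call sees 1 frame and every later call sees 4, matching the VAE's 4x temporal stride. A toy illustration of the chunk boundaries (the frame count is made up):

t = 17                                  # number of input frames, toy value
iter_ = 1 + (t - 1) // 4                # -> 5 encoder calls
chunks = [(0, 1)] + [(1 + 4 * (i - 1), 1 + 4 * i) for i in range(1, iter_)]
print(chunks)                           # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]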
503
+ def decode(self, z, scale):
504
+ self.clear_cache()
505
+ # z: [b,c,t,h,w]
506
+ if isinstance(scale[0], torch.Tensor):
507
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
508
+ else:
509
+ z = z / scale[1] + scale[0]
510
+ iter_ = z.shape[2]
511
+ x = self.conv2(z)
512
+ for i in range(iter_):
513
+ self._conv_idx = [0]
514
+ if i == 0:
515
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
516
+ else:
517
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
518
+ out = torch.cat([out, out_], 2)
519
+ self.clear_cache()
520
+ return out
521
+
522
+ def reparameterize(self, mu, log_var):
523
+ std = torch.exp(0.5 * log_var)
524
+ eps = torch.randn_like(std)
525
+ return eps * std + mu
526
+
527
+ def sample(self, imgs, deterministic=False):
528
+ mu, log_var = self.encode(imgs)
529
+ if deterministic:
530
+ return mu
531
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
532
+ return mu + std * torch.randn_like(std)
533
+
534
+ def clear_cache(self):
535
+ self._conv_num = count_conv3d(self.decoder)
536
+ self._conv_idx = [0]
537
+ self._feat_map = [None] * self._conv_num
538
+ # cache encode
539
+ self._enc_conv_num = count_conv3d(self.encoder)
540
+ self._enc_conv_idx = [0]
541
+ self._enc_feat_map = [None] * self._enc_conv_num
542
+
543
+
544
+ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
545
+ """
546
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
547
+ """
548
+ # params
549
+ cfg = dict(
550
+ dim=96,
551
+ z_dim=z_dim,
552
+ dim_mult=[1, 2, 4, 4],
553
+ num_res_blocks=2,
554
+ attn_scales=[],
555
+ temperal_downsample=[False, True, True],
556
+ dropout=0.0,
557
+ )
558
+ cfg.update(**kwargs)
559
+
560
+ # init model
561
+ with torch.device("meta"):
562
+ model = WanVAE_(**cfg)
563
+
564
+ # load checkpoint
565
+ logging.info(f"loading {pretrained_path}")
566
+ model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
567
+
568
+ return model
569
+
570
+
571
+ class WanVAE:
572
+ def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16):
573
+
574
+ mean = [
575
+ -0.7571,
576
+ -0.7089,
577
+ -0.9113,
578
+ 0.1075,
579
+ -0.1745,
580
+ 0.9653,
581
+ -0.1517,
582
+ 1.5508,
583
+ 0.4134,
584
+ -0.0715,
585
+ 0.5517,
586
+ -0.3632,
587
+ -0.1922,
588
+ -0.9497,
589
+ 0.2503,
590
+ -0.2921,
591
+ ]
592
+ std = [
593
+ 2.8184,
594
+ 1.4541,
595
+ 2.3275,
596
+ 2.6558,
597
+ 1.2196,
598
+ 1.7708,
599
+ 2.6052,
600
+ 2.0743,
601
+ 3.2687,
602
+ 2.1526,
603
+ 2.8652,
604
+ 1.5579,
605
+ 1.6382,
606
+ 1.1253,
607
+ 2.8251,
608
+ 1.9160,
609
+ ]
610
+ self.vae_stride = (4, 8, 8)
611
+ self.mean = torch.tensor(mean)
612
+ self.std = torch.tensor(std)
613
+ self.scale = [self.mean, 1.0 / self.std]
614
+
615
+ # init model
616
+ self.vae = (
617
+ _video_vae(
618
+ pretrained_path=vae_pth,
619
+ z_dim=z_dim,
620
+ )
621
+ .eval()
622
+ .requires_grad_(False)
623
+ )
624
+
625
+ def encode(self, video):
626
+ """
627
+ videos: A list of videos each with shape [C, T, H, W].
628
+ """
629
+ return self.vae.encode(video, self.scale).float()
630
+
631
+ def to(self, *args, **kwargs):
632
+ self.mean = self.mean.to(*args, **kwargs)
633
+ self.std = self.std.to(*args, **kwargs)
634
+ self.scale = [self.mean, 1.0 / self.std]
635
+ self.vae = self.vae.to(*args, **kwargs)
636
+ return self
637
+
638
+ def decode(self, z):
639
+ return self.vae.decode(z, self.scale).float().clamp_(-1, 1)
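WanVAE keeps scale = [mean, 1.0 / std], so encode() whitens the latents per channel and decode() undoes it before running the decoder. A toy round trip over that normalization (illustrative shapes, not part of the committed file):

import torch

z_dim = 16
mean = torch.randn(z_dim)
std = torch.rand(z_dim) + 0.5
scale = [mean, 1.0 / std]                       # same convention as WanVAE.scale

mu = torch.randn(2, z_dim, 3, 8, 8)             # raw latents (B, C, T, H, W), toy sizes
normed = (mu - scale[0].view(1, z_dim, 1, 1, 1)) * scale[1].view(1, z_dim, 1, 1, 1)
restored = normed / scale[1].view(1, z_dim, 1, 1, 1) + scale[0].view(1, z_dim, 1, 1, 1)
print(torch.allclose(restored, mu, atol=1e-6))  # True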
skyreels_v2_infer/modules/xlm_roberta.py ADDED
@@ -0,0 +1,165 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ["XLMRoberta", "xlm_roberta_large"]
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
12
+ assert dim % num_heads == 0
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.num_heads = num_heads
16
+ self.head_dim = dim // num_heads
17
+ self.eps = eps
18
+
19
+ # layers
20
+ self.q = nn.Linear(dim, dim)
21
+ self.k = nn.Linear(dim, dim)
22
+ self.v = nn.Linear(dim, dim)
23
+ self.o = nn.Linear(dim, dim)
24
+ self.dropout = nn.Dropout(dropout)
25
+
26
+ def forward(self, x, mask):
27
+ """
28
+ x: [B, L, C].
29
+ """
30
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
31
+
32
+ # compute query, key, value
33
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
34
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+
37
+ # compute attention
38
+ p = self.dropout.p if self.training else 0.0
39
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
40
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
41
+
42
+ # output
43
+ x = self.o(x)
44
+ x = self.dropout(x)
45
+ return x
46
+
47
+
48
+ class AttentionBlock(nn.Module):
49
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
50
+ super().__init__()
51
+ self.dim = dim
52
+ self.num_heads = num_heads
53
+ self.post_norm = post_norm
54
+ self.eps = eps
55
+
56
+ # layers
57
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
58
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
59
+ self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
60
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
61
+
62
+ def forward(self, x, mask):
63
+ if self.post_norm:
64
+ x = self.norm1(x + self.attn(x, mask))
65
+ x = self.norm2(x + self.ffn(x))
66
+ else:
67
+ x = x + self.attn(self.norm1(x), mask)
68
+ x = x + self.ffn(self.norm2(x))
69
+ return x
70
+
71
+
72
+ class XLMRoberta(nn.Module):
73
+ """
74
+ XLMRobertaModel with no pooler and no LM head.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_size=250002,
80
+ max_seq_len=514,
81
+ type_size=1,
82
+ pad_id=1,
83
+ dim=1024,
84
+ num_heads=16,
85
+ num_layers=24,
86
+ post_norm=True,
87
+ dropout=0.1,
88
+ eps=1e-5,
89
+ ):
90
+ super().__init__()
91
+ self.vocab_size = vocab_size
92
+ self.max_seq_len = max_seq_len
93
+ self.type_size = type_size
94
+ self.pad_id = pad_id
95
+ self.dim = dim
96
+ self.num_heads = num_heads
97
+ self.num_layers = num_layers
98
+ self.post_norm = post_norm
99
+ self.eps = eps
100
+
101
+ # embeddings
102
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
103
+ self.type_embedding = nn.Embedding(type_size, dim)
104
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
105
+ self.dropout = nn.Dropout(dropout)
106
+
107
+ # blocks
108
+ self.blocks = nn.ModuleList(
109
+ [AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)]
110
+ )
111
+
112
+ # norm layer
113
+ self.norm = nn.LayerNorm(dim, eps=eps)
114
+
115
+ def forward(self, ids):
116
+ """
117
+ ids: [B, L] of torch.LongTensor.
118
+ """
119
+ b, s = ids.shape
120
+ mask = ids.ne(self.pad_id).long()
121
+
122
+ # embeddings
123
+ x = (
124
+ self.token_embedding(ids)
125
+ + self.type_embedding(torch.zeros_like(ids))
126
+ + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
127
+ )
128
+ if self.post_norm:
129
+ x = self.norm(x)
130
+ x = self.dropout(x)
131
+
132
+ # blocks
133
+ mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
134
+ for block in self.blocks:
135
+ x = block(x, mask)
136
+
137
+ # output
138
+ if not self.post_norm:
139
+ x = self.norm(x)
140
+ return x
141
+
142
+
143
+ def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
144
+ """
145
+ XLMRobertaLarge adapted from Huggingface.
146
+ """
147
+ # params
148
+ cfg = dict(
149
+ vocab_size=250002,
150
+ max_seq_len=514,
151
+ type_size=1,
152
+ pad_id=1,
153
+ dim=1024,
154
+ num_heads=16,
155
+ num_layers=24,
156
+ post_norm=True,
157
+ dropout=0.1,
158
+ eps=1e-5,
159
+ )
160
+ cfg.update(**kwargs)
161
+
162
+ # init a model on device
163
+ with torch.device(device):
164
+ model = XLMRoberta(**cfg)
165
+ return model
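The position-id expression in XLMRoberta.forward follows the XLM-R convention: padding tokens get index pad_id and real tokens are numbered from pad_id + 1 onward, regardless of where padding sits. A standalone illustration with made-up token ids:

import torch

pad_id = 1
ids = torch.tensor([[0, 250, 37, 2, 1, 1]])     # one sequence with two trailing pads (toy ids)
mask = ids.ne(pad_id).long()
pos = pad_id + torch.cumsum(mask, dim=1) * mask
print(pos)                                      # tensor([[2, 3, 4, 5, 1, 1]])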
skyreels_v2_infer/pipelines/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .diffusion_forcing_pipeline import DiffusionForcingPipeline
2
+ from .image2video_pipeline import Image2VideoPipeline
3
+ from .image2video_pipeline import resizecrop
4
+ from .prompt_enhancer import PromptEnhancer
5
+ from .text2video_pipeline import Text2VideoPipeline
skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py ADDED
@@ -0,0 +1,427 @@
1
+ import math
2
+ import os
3
+ from typing import List
4
+ from typing import Optional
5
+ from typing import Tuple
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.video_processor import VideoProcessor
13
+ from tqdm import tqdm
14
+
15
+ from ..modules import get_text_encoder
16
+ from ..modules import get_transformer
17
+ from ..modules import get_vae
18
+ from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
19
+
20
+
21
+ class DiffusionForcingPipeline:
22
+ """
23
+ A pipeline for diffusion-based video generation tasks.
24
+
25
+ This pipeline supports two main tasks:
26
+ - Image-to-Video (i2v): Generates a video sequence from a source image
27
+ - Text-to-Video (t2v): Generates a video sequence from a text description
28
+
29
+ The pipeline integrates multiple components including:
30
+ - A transformer model for diffusion
31
+ - A VAE for encoding/decoding
32
+ - A text encoder for processing text prompts
33
+ - An image encoder for processing image inputs (i2v mode only)
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ model_path: str,
39
+ dit_path: str,
40
+ device: str = "cuda",
41
+ weight_dtype=torch.bfloat16,
42
+ use_usp=False,
43
+ offload=False,
44
+ ):
45
+ """
46
+ Initialize the diffusion forcing pipeline class
47
+
48
+ Args:
49
+ model_path (str): Path to the model
50
+ dit_path (str): Path to the DiT model directory, containing the model configuration file (config.json) and weight files (*.safetensors)
51
+ device (str): Device to run on, defaults to 'cuda'
52
+ weight_dtype: Weight data type, defaults to torch.bfloat16
53
+ """
54
+ load_device = "cpu" if offload else device
55
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype)
56
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
57
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
58
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
59
+ self.video_processor = VideoProcessor(vae_scale_factor=16)
60
+ self.device = device
61
+ self.offload = offload
62
+
63
+ if use_usp:
64
+ from xfuser.core.distributed import get_sequence_parallel_world_size
65
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
66
+ import types
67
+
68
+ for block in self.transformer.blocks:
69
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
70
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
71
+ self.sp_size = get_sequence_parallel_world_size()
72
+
73
+ self.scheduler = FlowUniPCMultistepScheduler()
74
+
75
+ @property
76
+ def do_classifier_free_guidance(self) -> bool:
77
+ return self._guidance_scale > 1
78
+
79
+ def encode_image(
80
+ self, image: PipelineImageInput, height: int, width: int, num_frames: int
81
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
82
+
83
+ # prefix_video
84
+ prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
85
+ prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
86
+ if prefix_video.dtype == torch.uint8:
87
+ prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
88
+ prefix_video = prefix_video.to(self.device)
89
+ prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
90
+ causal_block_size = self.transformer.num_frame_per_block
91
+ if prefix_video[0].shape[1] % causal_block_size != 0:
92
+ truncate_len = prefix_video[0].shape[1] % causal_block_size
93
+ print("the length of prefix video is truncated for the casual block size alignment.")
94
+ prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
95
+ predix_video_latent_length = prefix_video[0].shape[1]
96
+ return prefix_video, predix_video_latent_length
97
+
98
+ def prepare_latents(
99
+ self,
100
+ shape: Tuple[int],
101
+ dtype: Optional[torch.dtype] = None,
102
+ device: Optional[torch.device] = None,
103
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
104
+ ) -> torch.Tensor:
105
+ return randn_tensor(shape, generator, device=device, dtype=dtype)
106
+
107
+ def generate_timestep_matrix(
108
+ self,
109
+ num_frames,
110
+ step_template,
111
+ base_num_frames,
112
+ ar_step=5,
113
+ num_pre_ready=0,
114
+ casual_block_size=1,
115
+ shrink_interval_with_mask=False,
116
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
117
+ step_matrix, step_index = [], []
118
+ update_mask, valid_interval = [], []
119
+ num_iterations = len(step_template) + 1
120
+ num_frames_block = num_frames // casual_block_size
121
+ base_num_frames_block = base_num_frames // casual_block_size
122
+ if base_num_frames_block < num_frames_block:
123
+ infer_step_num = len(step_template)
124
+ gen_block = base_num_frames_block
125
+ min_ar_step = infer_step_num / gen_block
126
+ assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
127
+ # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
128
+ step_template = torch.cat(
129
+ [
130
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
131
+ step_template.long(),
132
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
133
+ ]
134
+ )  # pad the template so that the per-row step counters, which start at 1, index it correctly
135
+ pre_row = torch.zeros(num_frames_block, dtype=torch.long)
136
+ if num_pre_ready > 0:
137
+ pre_row[: num_pre_ready // casual_block_size] = num_iterations
138
+
139
+ while torch.all(pre_row >= (num_iterations - 1)) == False:
140
+ new_row = torch.zeros(num_frames_block, dtype=torch.long)
141
+ for i in range(num_frames_block):
142
+ if i == 0 or pre_row[i - 1] >= (
143
+ num_iterations - 1
144
+ ):  # the first frame block, or the previous one is already fully denoised
145
+ new_row[i] = pre_row[i] + 1
146
+ else:
147
+ new_row[i] = new_row[i - 1] - ar_step
148
+ new_row = new_row.clamp(0, num_iterations)
149
+
150
+ update_mask.append(
151
+ (new_row != pre_row) & (new_row != num_iterations)
152
+ ) # False: no need to update, True: need to update
153
+ step_index.append(new_row)
154
+ step_matrix.append(step_template[new_row])
155
+ pre_row = new_row
156
+
157
+ # for long video we split into several sequences, base_num_frames is set to the model max length (for training)
158
+ terminal_flag = base_num_frames_block
159
+ if shrink_interval_with_mask:
160
+ idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
161
+ update_mask = update_mask[0]
162
+ update_mask_idx = idx_sequence[update_mask]
163
+ last_update_idx = update_mask_idx[-1].item()
164
+ terminal_flag = last_update_idx + 1
165
+ # for i in range(0, len(update_mask)):
166
+ for curr_mask in update_mask:
167
+ if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
168
+ terminal_flag += 1
169
+ valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
170
+
171
+ step_update_mask = torch.stack(update_mask, dim=0)
172
+ step_index = torch.stack(step_index, dim=0)
173
+ step_matrix = torch.stack(step_matrix, dim=0)
174
+
175
+ if casual_block_size > 1:
176
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
177
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
178
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
179
+ valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
180
+
181
+ return step_matrix, step_index, step_update_mask, valid_interval
182
+
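The row-update loop in generate_timestep_matrix implements the asynchronous diffusion-forcing schedule: each frame block lags the previous one by ar_step denoising steps until the earlier block has finished. The self-contained toy run below (made-up sizes: 4 frame blocks, 6 denoising steps, ar_step=2) makes the staggering visible:

import torch

num_frames_block, num_inference_steps, ar_step = 4, 6, 2
num_iterations = num_inference_steps + 1        # len(step_template) + 1 in the real code

pre_row = torch.zeros(num_frames_block, dtype=torch.long)
rows = []
while not torch.all(pre_row >= (num_iterations - 1)):
    new_row = torch.zeros(num_frames_block, dtype=torch.long)
    for i in range(num_frames_block):
        if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
            new_row[i] = pre_row[i] + 1            # this block takes one more denoising step
        else:
            new_row[i] = new_row[i - 1] - ar_step  # later blocks trail by ar_step steps
    new_row = new_row.clamp(0, num_iterations)
    rows.append(new_row.tolist())
    pre_row = new_row

for r in rows:
    print(r)   # [1, 0, 0, 0], [2, 0, 0, 0], [3, 1, 0, 0], ... blocks finish left to right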
183
+ @torch.no_grad()
184
+ def __call__(
185
+ self,
186
+ prompt: Union[str, List[str]],
187
+ negative_prompt: Union[str, List[str]] = "",
188
+ image: PipelineImageInput = None,
189
+ height: int = 480,
190
+ width: int = 832,
191
+ num_frames: int = 97,
192
+ num_inference_steps: int = 50,
193
+ shift: float = 1.0,
194
+ guidance_scale: float = 5.0,
195
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
196
+ overlap_history: int = None,
197
+ addnoise_condition: int = 0,
198
+ base_num_frames: int = 97,
199
+ ar_step: int = 5,
200
+ causal_block_size: int = None,
201
+ fps: int = 24,
202
+ ):
203
+ latent_height = height // 8
204
+ latent_width = width // 8
205
+ latent_length = (num_frames - 1) // 4 + 1
206
+
207
+ self._guidance_scale = guidance_scale
208
+
209
+ i2v_extra_kwrags = {}
210
+ prefix_video = None
211
+ predix_video_latent_length = 0
212
+ if image:
213
+ prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
214
+
215
+ self.text_encoder.to(self.device)
216
+ prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
217
+ if self.do_classifier_free_guidance:
218
+ negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
219
+ if self.offload:
220
+ self.text_encoder.cpu()
221
+ torch.cuda.empty_cache()
222
+
223
+ self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
224
+ init_timesteps = self.scheduler.timesteps
225
+ if causal_block_size is None:
226
+ causal_block_size = self.transformer.num_frame_per_block
227
+ fps_embeds = [fps] * prompt_embeds.shape[0]
228
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
229
+ transformer_dtype = self.transformer.dtype
230
+ # with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
231
+ if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
232
+ # short video generation
233
+ latent_shape = [16, latent_length, latent_height, latent_width]
234
+ latents = self.prepare_latents(
235
+ latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
236
+ )
237
+ latents = [latents]
238
+ if prefix_video is not None:
239
+ latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
240
+ base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
241
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
242
+ latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
243
+ )
244
+ sample_schedulers = []
245
+ for _ in range(latent_length):
246
+ sample_scheduler = FlowUniPCMultistepScheduler(
247
+ num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
248
+ )
249
+ sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
250
+ sample_schedulers.append(sample_scheduler)
251
+ sample_schedulers_counter = [0] * latent_length
252
+ self.transformer.to(self.device)
253
+ for i, timestep_i in enumerate(tqdm(step_matrix)):
254
+ update_mask_i = step_update_mask[i]
255
+ valid_interval_i = valid_interval[i]
256
+ valid_interval_start, valid_interval_end = valid_interval_i
257
+ timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
258
+ latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
259
+ if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
260
+ noise_factor = 0.001 * addnoise_condition
261
+ timestep_for_noised_condition = addnoise_condition
262
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
263
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
264
+ + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
265
+ * noise_factor
266
+ )
267
+ timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
268
+ if not self.do_classifier_free_guidance:
269
+ noise_pred = self.transformer(
270
+ torch.stack([latent_model_input[0]]),
271
+ t=timestep,
272
+ context=prompt_embeds,
273
+ fps=fps_embeds,
274
+ **i2v_extra_kwrags,
275
+ )[0]
276
+ else:
277
+ noise_pred_cond = self.transformer(
278
+ torch.stack([latent_model_input[0]]),
279
+ t=timestep,
280
+ context=prompt_embeds,
281
+ fps=fps_embeds,
282
+ **i2v_extra_kwrags,
283
+ )[0]
284
+ noise_pred_uncond = self.transformer(
285
+ torch.stack([latent_model_input[0]]),
286
+ t=timestep,
287
+ context=negative_prompt_embeds,
288
+ fps=fps_embeds,
289
+ **i2v_extra_kwrags,
290
+ )[0]
291
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
292
+ for idx in range(valid_interval_start, valid_interval_end):
293
+ if update_mask_i[idx].item():
294
+ latents[0][:, idx] = sample_schedulers[idx].step(
295
+ noise_pred[:, idx - valid_interval_start],
296
+ timestep_i[idx],
297
+ latents[0][:, idx],
298
+ return_dict=False,
299
+ generator=generator,
300
+ )[0]
301
+ sample_schedulers_counter[idx] += 1
302
+ if self.offload:
303
+ self.transformer.cpu()
304
+ torch.cuda.empty_cache()
305
+ x0 = latents[0].unsqueeze(0)
306
+ videos = self.vae.decode(x0)
307
+ videos = (videos / 2 + 0.5).clamp(0, 1)
308
+ videos = [video for video in videos]
309
+ videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
310
+ videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
311
+ return videos
312
+ else:
313
+ # long video generation
314
+ base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
315
+ overlap_history_frames = (overlap_history - 1) // 4 + 1
316
+ n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
317
+ print(f"n_iter:{n_iter}")
318
+ output_video = None
319
+ for i in range(n_iter):
320
+ if output_video is not None: # i !=0
321
+ prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
322
+ prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
323
+ if prefix_video[0].shape[1] % causal_block_size != 0:
324
+ truncate_len = prefix_video[0].shape[1] % causal_block_size
325
+ print("the length of prefix video is truncated for the casual block size alignment.")
326
+ prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
327
+ predix_video_latent_length = prefix_video[0].shape[1]
328
+ finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
329
+ left_frame_num = latent_length - finished_frame_num
330
+ base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
331
+ if ar_step > 0 and self.transformer.enable_teacache:
332
+ num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
333
+ self.transformer.num_steps = num_steps
334
+ else: # i == 0
335
+ base_num_frames_iter = base_num_frames
336
+ latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
337
+ latents = self.prepare_latents(
338
+ latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
339
+ )
340
+ latents = [latents]
341
+ if prefix_video is not None:
342
+ latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
343
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
344
+ base_num_frames_iter,
345
+ init_timesteps,
346
+ base_num_frames_iter,
347
+ ar_step,
348
+ predix_video_latent_length,
349
+ causal_block_size,
350
+ )
351
+ sample_schedulers = []
352
+ for _ in range(base_num_frames_iter):
353
+ sample_scheduler = FlowUniPCMultistepScheduler(
354
+ num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
355
+ )
356
+ sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
357
+ sample_schedulers.append(sample_scheduler)
358
+ sample_schedulers_counter = [0] * base_num_frames_iter
359
+ self.transformer.to(self.device)
360
+ for i, timestep_i in enumerate(tqdm(step_matrix)):
361
+ update_mask_i = step_update_mask[i]
362
+ valid_interval_i = valid_interval[i]
363
+ valid_interval_start, valid_interval_end = valid_interval_i
364
+ timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
365
+ latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
366
+ if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
367
+ noise_factor = 0.001 * addnoise_condition
368
+ timestep_for_noised_condition = addnoise_condition
369
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
370
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
371
+ * (1.0 - noise_factor)
372
+ + torch.randn_like(
373
+ latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
374
+ )
375
+ * noise_factor
376
+ )
377
+ timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
378
+ if not self.do_classifier_free_guidance:
379
+ noise_pred = self.transformer(
380
+ torch.stack([latent_model_input[0]]),
381
+ t=timestep,
382
+ context=prompt_embeds,
383
+ fps=fps_embeds,
384
+ **i2v_extra_kwrags,
385
+ )[0]
386
+ else:
387
+ noise_pred_cond = self.transformer(
388
+ torch.stack([latent_model_input[0]]),
389
+ t=timestep,
390
+ context=prompt_embeds,
391
+ fps=fps_embeds,
392
+ **i2v_extra_kwrags,
393
+ )[0]
394
+ noise_pred_uncond = self.transformer(
395
+ torch.stack([latent_model_input[0]]),
396
+ t=timestep,
397
+ context=negative_prompt_embeds,
398
+ fps=fps_embeds,
399
+ **i2v_extra_kwrags,
400
+ )[0]
401
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
402
+ for idx in range(valid_interval_start, valid_interval_end):
403
+ if update_mask_i[idx].item():
404
+ latents[0][:, idx] = sample_schedulers[idx].step(
405
+ noise_pred[:, idx - valid_interval_start],
406
+ timestep_i[idx],
407
+ latents[0][:, idx],
408
+ return_dict=False,
409
+ generator=generator,
410
+ )[0]
411
+ sample_schedulers_counter[idx] += 1
412
+ if self.offload:
413
+ self.transformer.cpu()
414
+ torch.cuda.empty_cache()
415
+ x0 = latents[0].unsqueeze(0)
416
+ videos = [self.vae.decode(x0)[0]]
417
+ if output_video is None:
418
+ output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
419
+ else:
420
+ output_video = torch.cat(
421
+ [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
422
+ ) # c, f, h, w
423
+ output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
424
+ output_video = [video for video in output_video]
425
+ output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
426
+ output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
427
+ return output_video
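The diffusion-forcing call above returns a list of uint8 arrays shaped (frames, height, width, 3). Below is a minimal sketch, not part of this diff, of writing that output to disk; it assumes imageio with the ffmpeg plugin is installed, and the path and fps values are placeholders.

import imageio

def save_output_video(frames, path="video_out.mp4", fps=24):
    # frames: one element of the returned list, a (T, H, W, 3) uint8 numpy array
    imageio.mimwrite(path, list(frames), fps=fps, quality=8)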
skyreels_v2_infer/pipelines/image2video_pipeline.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ from typing import List
3
+ from typing import Optional
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.image_processor import PipelineImageInput
9
+ from diffusers.video_processor import VideoProcessor
10
+ from PIL import Image
11
+ from tqdm import tqdm
12
+
13
+ from ..modules import get_image_encoder
14
+ from ..modules import get_text_encoder
15
+ from ..modules import get_transformer
16
+ from ..modules import get_vae
17
+ from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
18
+
19
+
20
+ def resizecrop(image: Image.Image, th, tw):
21
+ w, h = image.size
22
+ if w == tw and h == th:
23
+ return image
24
+ if h / w > th / tw:
25
+ new_w = int(w)
26
+ new_h = int(new_w * th / tw)
27
+ else:
28
+ new_h = int(h)
29
+ new_w = int(new_h * tw / th)
30
+ left = (w - new_w) / 2
31
+ top = (h - new_h) / 2
32
+ right = (w + new_w) / 2
33
+ bottom = (h + new_h) / 2
34
+ image = image.crop((left, top, right, bottom))
35
+ return image
36
+
37
+
38
+ class Image2VideoPipeline:
39
+ def __init__(
40
+ self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
41
+ ):
42
+ load_device = "cpu" if offload else device
43
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype)
44
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
45
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
46
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
47
+ self.clip = get_image_encoder(model_path, load_device, weight_dtype)
48
+ self.sp_size = 1
49
+ self.device = device
50
+ self.offload = offload
51
+ self.video_processor = VideoProcessor(vae_scale_factor=16)
52
+ if use_usp:
53
+ from xfuser.core.distributed import get_sequence_parallel_world_size
54
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
55
+ import types
56
+
57
+ for block in self.transformer.blocks:
58
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
59
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
60
+ self.sp_size = get_sequence_parallel_world_size()
61
+
62
+ self.scheduler = FlowUniPCMultistepScheduler()
63
+ self.vae_stride = (4, 8, 8)
64
+ self.patch_size = (1, 2, 2)
65
+
66
+ @torch.no_grad()
67
+ def __call__(
68
+ self,
69
+ image: PipelineImageInput,
70
+ prompt: Union[str, List[str]] = None,
71
+ negative_prompt: Union[str, List[str]] = None,
72
+ height: int = 544,
73
+ width: int = 960,
74
+ num_frames: int = 97,
75
+ num_inference_steps: int = 50,
76
+ guidance_scale: float = 5.0,
77
+ shift: float = 5.0,
78
+ generator: Optional[torch.Generator] = None,
79
+ ):
80
+ F = num_frames
81
+
82
+ latent_height = height // 8 // 2 * 2
83
+ latent_width = width // 8 // 2 * 2
84
+ latent_length = (F - 1) // 4 + 1
85
+
86
+ h = latent_height * 8
87
+ w = latent_width * 8
88
+
89
+ img = self.video_processor.preprocess(image, height=h, width=w)
90
+
91
+ img = img.to(device=self.device, dtype=self.transformer.dtype)
92
+
93
+ padding_video = torch.zeros(img.shape[0], 3, F - 1, h, w, device=self.device)
94
+
95
+ img = img.unsqueeze(2)
96
+ img_cond = torch.concat([img, padding_video], dim=2)
97
+ img_cond = self.vae.encode(img_cond)
98
+ mask = torch.ones_like(img_cond)
99
+ mask[:, :, 1:] = 0
100
+ y = torch.cat([mask[:, :4], img_cond], dim=1)
101
+ self.clip.to(self.device)
102
+ clip_context = self.clip.encode_video(img)
103
+ if self.offload:
104
+ self.clip.cpu()
105
+ torch.cuda.empty_cache()
106
+
107
+ # preprocess
108
+ self.text_encoder.to(self.device)
109
+ context = self.text_encoder.encode(prompt).to(self.device)
110
+ context_null = self.text_encoder.encode(negative_prompt).to(self.device)
111
+ if self.offload:
112
+ self.text_encoder.cpu()
113
+ torch.cuda.empty_cache()
114
+
115
+ latent = torch.randn(
116
+ 16, latent_length, latent_height, latent_width, dtype=torch.float32, generator=generator, device=self.device
117
+ )
118
+
119
+ self.transformer.to(self.device)
120
+ with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
121
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
122
+ timesteps = self.scheduler.timesteps
123
+
124
+ arg_c = {
125
+ "context": context,
126
+ "clip_fea": clip_context,
127
+ "y": y,
128
+ }
129
+
130
+ arg_null = {
131
+ "context": context_null,
132
+ "clip_fea": clip_context,
133
+ "y": y,
134
+ }
135
+
136
+ self.transformer.to(self.device)
137
+ for _, t in enumerate(tqdm(timesteps)):
138
+ latent_model_input = torch.stack([latent]).to(self.device)
139
+ timestep = torch.stack([t]).to(self.device)
140
+ noise_pred_cond = self.transformer(latent_model_input, t=timestep, **arg_c)[0].to(self.device)
141
+ noise_pred_uncond = self.transformer(latent_model_input, t=timestep, **arg_null)[0].to(self.device)
142
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
143
+
144
+ temp_x0 = self.scheduler.step(
145
+ noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator
146
+ )[0]
147
+ latent = temp_x0.squeeze(0)
148
+ if self.offload:
149
+ self.transformer.cpu()
150
+ torch.cuda.empty_cache()
151
+ videos = self.vae.decode(latent)
152
+ videos = (videos / 2 + 0.5).clamp(0, 1)
153
+ videos = [video for video in videos]
154
+ videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
155
+ videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
156
+ return videos
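A hedged usage sketch of the pipeline defined above (not part of the diff); the checkpoint paths and seed are placeholders, and only arguments visible in the signatures above are used.

import torch
from PIL import Image
from skyreels_v2_infer.pipelines.image2video_pipeline import Image2VideoPipeline, resizecrop

pipe = Image2VideoPipeline(
    model_path="path/to/base_model",      # directory holding Wan2.1_VAE.pth and the encoders
    dit_path="path/to/dit_checkpoint",
    device="cuda",
    weight_dtype=torch.bfloat16,
    offload=True,                         # keep large modules on CPU between uses
)
image = resizecrop(Image.open("input.jpg").convert("RGB"), th=544, tw=960)
frames = pipe(
    image=image,
    prompt="a short text prompt",
    negative_prompt="",
    height=544,
    width=960,
    num_frames=97,
    num_inference_steps=50,
    guidance_scale=5.0,
    shift=5.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]  # uint8 numpy array of shape (frames, height, width, 3)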
skyreels_v2_infer/pipelines/prompt_enhancer.py ADDED
@@ -0,0 +1,65 @@
1
+ import argparse
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ sys_prompt = """
5
+ Transform the short prompt into a detailed video-generation caption using this structure:
6
+ Opening shot type (long/medium/close-up/extreme close-up/full shot)
7
+ Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
8
+ Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
9
+ Scene composition (background, environment, spatial relationships)
10
+ Lighting/atmosphere (natural/artificial, time of day, mood)
11
+ Camera motion (zooms, pans, static/handheld shots) if applicable.
12
+
13
+ Pattern Summary from Examples:
14
+ [Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
15
+
16
+ One case:
17
+ Short prompt: a person is playing football
18
+ Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
19
+
20
+ Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
21
+
22
+ Now expand this short prompt: [{}]. Please only output the final long prompt in English.
23
+ """
24
+
25
+ class PromptEnhancer:
26
+ def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
27
+ self.model = AutoModelForCausalLM.from_pretrained(
28
+ model_name,
29
+ torch_dtype="auto",
30
+ device_map="cuda:0",
31
+ )
32
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
33
+
34
+ def __call__(self, prompt):
35
+ prompt = prompt.strip()
36
+ prompt = sys_prompt.format(prompt)
37
+ messages = [
38
+ {"role": "system", "content": "You are a helpful assistant."},
39
+ {"role": "user", "content": prompt}
40
+ ]
41
+ text = self.tokenizer.apply_chat_template(
42
+ messages,
43
+ tokenize=False,
44
+ add_generation_prompt=True
45
+ )
46
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
47
+ generated_ids = self.model.generate(
48
+ **model_inputs,
49
+ max_new_tokens=2048,
50
+ )
51
+ generated_ids = [
52
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
53
+ ]
54
+ rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
55
+ return rewritten_prompt
56
+
57
+ if __name__ == '__main__':
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
60
+ args = parser.parse_args()
61
+
62
+ prompt_enhancer = PromptEnhancer()
63
+ enhanced_prompt = prompt_enhancer(args.prompt)
64
+ print(f'Original prompt: {args.prompt}')
65
+ print(f'Enhanced prompt: {enhanced_prompt}')
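A brief sketch (not in the diff) of using the enhancer programmatically rather than through the CLI entry point above; note that the default Qwen2.5-32B-Instruct model requires substantial GPU memory.

from skyreels_v2_infer.pipelines.prompt_enhancer import PromptEnhancer

enhancer = PromptEnhancer()                              # defaults to Qwen/Qwen2.5-32B-Instruct
long_prompt = enhancer("a person is playing football")
# long_prompt can then be passed as `prompt=` to the text2video or image2video pipelines.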
skyreels_v2_infer/pipelines/text2video_pipeline.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ from typing import List
3
+ from typing import Optional
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers.video_processor import VideoProcessor
9
+ from tqdm import tqdm
10
+
11
+ from ..modules import get_text_encoder
12
+ from ..modules import get_transformer
13
+ from ..modules import get_vae
14
+ from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
15
+
16
+
17
+ class Text2VideoPipeline:
18
+ def __init__(
19
+ self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
20
+ ):
21
+ load_device = "cpu" if offload else device
22
+ self.transformer = get_transformer(dit_path, load_device, weight_dtype)
23
+ vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
24
+ self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
25
+ self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
26
+ self.video_processor = VideoProcessor(vae_scale_factor=16)
27
+ self.sp_size = 1
28
+ self.device = device
29
+ self.offload = offload
30
+ if use_usp:
31
+ from xfuser.core.distributed import get_sequence_parallel_world_size
32
+ from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
33
+ import types
34
+
35
+ for block in self.transformer.blocks:
36
+ block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
37
+ self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
38
+ self.sp_size = get_sequence_parallel_world_size()
39
+
40
+ self.scheduler = FlowUniPCMultistepScheduler()
41
+ self.vae_stride = (4, 8, 8)
42
+ self.patch_size = (1, 2, 2)
43
+
44
+ @torch.no_grad()
45
+ def __call__(
46
+ self,
47
+ prompt: Union[str, List[str]] = None,
48
+ negative_prompt: Union[str, List[str]] = None,
49
+ width: int = 544,
50
+ height: int = 960,
51
+ num_frames: int = 97,
52
+ num_inference_steps: int = 50,
53
+ guidance_scale: float = 5.0,
54
+ shift: float = 5.0,
55
+ generator: Optional[torch.Generator] = None,
56
+ ):
57
+ # preprocess
58
+ F = num_frames
59
+ target_shape = (
60
+ self.vae.vae.z_dim,
61
+ (F - 1) // self.vae_stride[0] + 1,
62
+ height // self.vae_stride[1],
63
+ width // self.vae_stride[2],
64
+ )
65
+ self.text_encoder.to(self.device)
66
+ context = self.text_encoder.encode(prompt).to(self.device)
67
+ context_null = self.text_encoder.encode(negative_prompt).to(self.device)
68
+ if self.offload:
69
+ self.text_encoder.cpu()
70
+ torch.cuda.empty_cache()
71
+
72
+ latents = [
73
+ torch.randn(
74
+ target_shape[0],
75
+ target_shape[1],
76
+ target_shape[2],
77
+ target_shape[3],
78
+ dtype=torch.float32,
79
+ device=self.device,
80
+ generator=generator,
81
+ )
82
+ ]
83
+
84
+ # evaluation mode
85
+ self.transformer.to(self.device)
86
+ with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
87
+ self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
88
+ timesteps = self.scheduler.timesteps
89
+
90
+ for _, t in enumerate(tqdm(timesteps)):
91
+ latent_model_input = torch.stack(latents)
92
+ timestep = torch.stack([t])
93
+ noise_pred_cond = self.transformer(latent_model_input, t=timestep, context=context)[0]
94
+ noise_pred_uncond = self.transformer(latent_model_input, t=timestep, context=context_null)[0]
95
+
96
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
97
+
98
+ temp_x0 = self.scheduler.step(
99
+ noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=generator
100
+ )[0]
101
+ latents = [temp_x0.squeeze(0)]
102
+ if self.offload:
103
+ self.transformer.cpu()
104
+ torch.cuda.empty_cache()
105
+ videos = self.vae.decode(latents[0])
106
+ videos = (videos / 2 + 0.5).clamp(0, 1)
107
+ videos = [video for video in videos]
108
+ videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
109
+ videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
110
+ return videos
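A hedged usage sketch of the text-to-video pipeline above; the paths are placeholders, and width/height are passed explicitly since the signature defaults to width=544, height=960.

import torch
from skyreels_v2_infer.pipelines.text2video_pipeline import Text2VideoPipeline

pipe = Text2VideoPipeline(
    model_path="path/to/base_model",
    dit_path="path/to/dit_checkpoint",
    device="cuda",
    weight_dtype=torch.bfloat16,
    offload=True,
)
frames = pipe(
    prompt="a detailed video-generation caption",
    negative_prompt="",
    width=960,
    height=544,
    num_frames=97,
    num_inference_steps=50,
    guidance_scale=5.0,
    shift=5.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)[0]  # uint8 numpy array of shape (frames, height, width, 3)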
skyreels_v2_infer/scheduler/__init__.py ADDED
File without changes
skyreels_v2_infer/scheduler/fm_solvers_unipc.py ADDED
@@ -0,0 +1,759 @@
1
+ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
+ # Convert unipc for flow matching
3
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
+ import math
5
+ from typing import List
6
+ from typing import Optional
7
+ from typing import Tuple
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from diffusers.configuration_utils import ConfigMixin
13
+ from diffusers.configuration_utils import register_to_config
14
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
15
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
16
+ from diffusers.schedulers.scheduling_utils import SchedulerOutput
17
+ from diffusers.utils import deprecate
18
+
19
+
20
+ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
21
+ """
22
+ `FlowUniPCMultistepScheduler` is a UniPC scheduler adapted to flow matching: a training-free framework designed for the fast sampling of diffusion models.
23
+
24
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
25
+ methods the library implements for all schedulers such as loading and saving.
26
+
27
+ Args:
28
+ num_train_timesteps (`int`, defaults to 1000):
29
+ The number of diffusion steps to train the model.
30
+ solver_order (`int`, default `2`):
31
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
32
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
33
+ unconditional sampling.
34
+ prediction_type (`str`, defaults to "flow_prediction"):
35
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
36
+ the flow of the diffusion process.
37
+ thresholding (`bool`, defaults to `False`):
38
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
39
+ as Stable Diffusion.
40
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
41
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
42
+ sample_max_value (`float`, defaults to 1.0):
43
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
44
+ predict_x0 (`bool`, defaults to `True`):
45
+ Whether to use the updating algorithm on the predicted x0.
46
+ solver_type (`str`, default `bh2`):
47
+ Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
48
+ otherwise.
49
+ lower_order_final (`bool`, default `True`):
50
+ Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
51
+ stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
52
+ disable_corrector (`list`, default `[]`):
53
+ Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
54
+ and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
55
+ usually disabled during the first few steps.
56
+ solver_p (`SchedulerMixin`, default `None`):
57
+ Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
58
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
59
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
60
+ the sigmas are determined according to a sequence of noise levels {σi}.
61
+ use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
62
+ Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
63
+ timestep_spacing (`str`, defaults to `"linspace"`):
64
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
65
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
66
+ steps_offset (`int`, defaults to 0):
67
+ An offset added to the inference steps, as required by some model families.
68
+ final_sigmas_type (`str`, defaults to `"zero"`):
69
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
70
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
71
+ """
72
+
73
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
74
+ order = 1
75
+
76
+ @register_to_config
77
+ def __init__(
78
+ self,
79
+ num_train_timesteps: int = 1000,
80
+ solver_order: int = 2,
81
+ prediction_type: str = "flow_prediction",
82
+ shift: Optional[float] = 1.0,
83
+ use_dynamic_shifting=False,
84
+ thresholding: bool = False,
85
+ dynamic_thresholding_ratio: float = 0.995,
86
+ sample_max_value: float = 1.0,
87
+ predict_x0: bool = True,
88
+ solver_type: str = "bh2",
89
+ lower_order_final: bool = True,
90
+ disable_corrector: List[int] = [],
91
+ solver_p: SchedulerMixin = None,
92
+ timestep_spacing: str = "linspace",
93
+ steps_offset: int = 0,
94
+ final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
95
+ ):
96
+
97
+ if solver_type not in ["bh1", "bh2"]:
98
+ if solver_type in ["midpoint", "heun", "logrho"]:
99
+ self.register_to_config(solver_type="bh2")
100
+ else:
101
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
102
+
103
+ self.predict_x0 = predict_x0
104
+ # setable values
105
+ self.num_inference_steps = None
106
+ alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
107
+ sigmas = 1.0 - alphas
108
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
109
+
110
+ if not use_dynamic_shifting:
111
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
112
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
113
+
114
+ self.sigmas = sigmas
115
+ self.timesteps = sigmas * num_train_timesteps
116
+
117
+ self.model_outputs = [None] * solver_order
118
+ self.timestep_list = [None] * solver_order
119
+ self.lower_order_nums = 0
120
+ self.disable_corrector = disable_corrector
121
+ self.solver_p = solver_p
122
+ self.last_sample = None
123
+ self._step_index = None
124
+ self._begin_index = None
125
+
126
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
127
+ self.sigma_min = self.sigmas[-1].item()
128
+ self.sigma_max = self.sigmas[0].item()
129
+
130
+ @property
131
+ def step_index(self):
132
+ """
133
+ The index counter for current timestep. It will increase 1 after each scheduler step.
134
+ """
135
+ return self._step_index
136
+
137
+ @property
138
+ def begin_index(self):
139
+ """
140
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
141
+ """
142
+ return self._begin_index
143
+
144
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
145
+ def set_begin_index(self, begin_index: int = 0):
146
+ """
147
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
148
+
149
+ Args:
150
+ begin_index (`int`):
151
+ The begin index for the scheduler.
152
+ """
153
+ self._begin_index = begin_index
154
+
155
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
156
+ def set_timesteps(
157
+ self,
158
+ num_inference_steps: Union[int, None] = None,
159
+ device: Union[str, torch.device] = None,
160
+ sigmas: Optional[List[float]] = None,
161
+ mu: Optional[Union[float, None]] = None,
162
+ shift: Optional[Union[float, None]] = None,
163
+ ):
164
+ """
165
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
166
+ Args:
167
+ num_inference_steps (`int`):
168
+ The number of diffusion steps used when generating samples.
169
+ device (`str` or `torch.device`, *optional*):
170
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
171
+ """
172
+
173
+ if self.config.use_dynamic_shifting and mu is None:
174
+ raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")
175
+
176
+ if sigmas is None:
177
+ sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
178
+
179
+ if self.config.use_dynamic_shifting:
180
+ sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
181
+ else:
182
+ if shift is None:
183
+ shift = self.config.shift
184
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
185
+
186
+ if self.config.final_sigmas_type == "sigma_min":
187
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
188
+ elif self.config.final_sigmas_type == "zero":
189
+ sigma_last = 0
190
+ else:
191
+ raise ValueError(
192
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
193
+ )
194
+
195
+ timesteps = sigmas * self.config.num_train_timesteps
196
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
197
+
198
+ self.sigmas = torch.from_numpy(sigmas)
199
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
200
+
201
+ self.num_inference_steps = len(timesteps)
202
+
203
+ self.model_outputs = [
204
+ None,
205
+ ] * self.config.solver_order
206
+ self.lower_order_nums = 0
207
+ self.last_sample = None
208
+ if self.solver_p:
209
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
210
+
211
+ # add an index counter for schedulers that allow duplicated timesteps
212
+ self._step_index = None
213
+ self._begin_index = None
214
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
215
+
216
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
217
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
218
+ """
219
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
220
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
221
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
222
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
223
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
224
+
225
+ https://arxiv.org/abs/2205.11487
226
+ """
227
+ dtype = sample.dtype
228
+ batch_size, channels, *remaining_dims = sample.shape
229
+
230
+ if dtype not in (torch.float32, torch.float64):
231
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
232
+
233
+ # Flatten sample for doing quantile calculation along each image
234
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
235
+
236
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
237
+
238
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
239
+ s = torch.clamp(
240
+ s, min=1, max=self.config.sample_max_value
241
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
242
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
243
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
244
+
245
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
246
+ sample = sample.to(dtype)
247
+
248
+ return sample
249
+
250
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
251
+ def _sigma_to_t(self, sigma):
252
+ return sigma * self.config.num_train_timesteps
253
+
254
+ def _sigma_to_alpha_sigma_t(self, sigma):
255
+ return 1 - sigma, sigma
256
+
257
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
258
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
259
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
260
+
261
+ def convert_model_output(
262
+ self,
263
+ model_output: torch.Tensor,
264
+ *args,
265
+ sample: torch.Tensor = None,
266
+ **kwargs,
267
+ ) -> torch.Tensor:
268
+ r"""
269
+ Convert the model output to the corresponding type the UniPC algorithm needs.
270
+
271
+ Args:
272
+ model_output (`torch.Tensor`):
273
+ The direct output from the learned diffusion model.
274
+ timestep (`int`):
275
+ The current discrete timestep in the diffusion chain.
276
+ sample (`torch.Tensor`):
277
+ A current instance of a sample created by the diffusion process.
278
+
279
+ Returns:
280
+ `torch.Tensor`:
281
+ The converted model output.
282
+ """
283
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
284
+ if sample is None:
285
+ if len(args) > 1:
286
+ sample = args[1]
287
+ else:
288
+ raise ValueError("missing `sample` as a required keyword argument")
289
+ if timestep is not None:
290
+ deprecate(
291
+ "timesteps",
292
+ "1.0.0",
293
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
294
+ )
295
+
296
+ sigma = self.sigmas[self.step_index]
297
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
298
+
299
+ if self.predict_x0:
300
+ if self.config.prediction_type == "flow_prediction":
301
+ sigma_t = self.sigmas[self.step_index]
302
+ x0_pred = sample - sigma_t * model_output
303
+ else:
304
+ raise ValueError(
305
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
306
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
307
+ )
308
+
309
+ if self.config.thresholding:
310
+ x0_pred = self._threshold_sample(x0_pred)
311
+
312
+ return x0_pred
313
+ else:
314
+ if self.config.prediction_type == "flow_prediction":
315
+ sigma_t = self.sigmas[self.step_index]
316
+ epsilon = sample - (1 - sigma_t) * model_output
317
+ else:
318
+ raise ValueError(
319
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
320
+ " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
321
+ )
322
+
323
+ if self.config.thresholding:
324
+ sigma_t = self.sigmas[self.step_index]
325
+ x0_pred = sample - sigma_t * model_output
326
+ x0_pred = self._threshold_sample(x0_pred)
327
+ epsilon = model_output + x0_pred
328
+
329
+ return epsilon
330
+
331
+ def multistep_uni_p_bh_update(
332
+ self,
333
+ model_output: torch.Tensor,
334
+ *args,
335
+ sample: torch.Tensor = None,
336
+ order: int = None, # pyright: ignore
337
+ **kwargs,
338
+ ) -> torch.Tensor:
339
+ """
340
+ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
341
+
342
+ Args:
343
+ model_output (`torch.Tensor`):
344
+ The direct output from the learned diffusion model at the current timestep.
345
+ prev_timestep (`int`):
346
+ The previous discrete timestep in the diffusion chain.
347
+ sample (`torch.Tensor`):
348
+ A current instance of a sample created by the diffusion process.
349
+ order (`int`):
350
+ The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
351
+
352
+ Returns:
353
+ `torch.Tensor`:
354
+ The sample tensor at the previous timestep.
355
+ """
356
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
357
+ if sample is None:
358
+ if len(args) > 1:
359
+ sample = args[1]
360
+ else:
361
+ raise ValueError("missing `sample` as a required keyword argument")
362
+ if order is None:
363
+ if len(args) > 2:
364
+ order = args[2]
365
+ else:
366
+ raise ValueError("missing `order` as a required keyword argument")
367
+ if prev_timestep is not None:
368
+ deprecate(
369
+ "prev_timestep",
370
+ "1.0.0",
371
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
372
+ )
373
+ model_output_list = self.model_outputs
374
+
375
+ s0 = self.timestep_list[-1]
376
+ m0 = model_output_list[-1]
377
+ x = sample
378
+
379
+ if self.solver_p:
380
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
381
+ return x_t
382
+
383
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
384
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
385
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
386
+
387
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
388
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
389
+
390
+ h = lambda_t - lambda_s0
391
+ device = sample.device
392
+
393
+ rks = []
394
+ D1s = []
395
+ for i in range(1, order):
396
+ si = self.step_index - i # pyright: ignore
397
+ mi = model_output_list[-(i + 1)]
398
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
399
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
400
+ rk = (lambda_si - lambda_s0) / h
401
+ rks.append(rk)
402
+ D1s.append((mi - m0) / rk) # pyright: ignore
403
+
404
+ rks.append(1.0)
405
+ rks = torch.tensor(rks, device=device)
406
+
407
+ R = []
408
+ b = []
409
+
410
+ hh = -h if self.predict_x0 else h
411
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
412
+ h_phi_k = h_phi_1 / hh - 1
413
+
414
+ factorial_i = 1
415
+
416
+ if self.config.solver_type == "bh1":
417
+ B_h = hh
418
+ elif self.config.solver_type == "bh2":
419
+ B_h = torch.expm1(hh)
420
+ else:
421
+ raise NotImplementedError()
422
+
423
+ for i in range(1, order + 1):
424
+ R.append(torch.pow(rks, i - 1))
425
+ b.append(h_phi_k * factorial_i / B_h)
426
+ factorial_i *= i + 1
427
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
428
+
429
+ R = torch.stack(R)
430
+ b = torch.tensor(b, device=device)
431
+
432
+ if len(D1s) > 0:
433
+ D1s = torch.stack(D1s, dim=1) # (B, K)
434
+ # for order 2, we use a simplified version
435
+ if order == 2:
436
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
437
+ else:
438
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
439
+ else:
440
+ D1s = None
441
+
442
+ if self.predict_x0:
443
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
444
+ if D1s is not None:
445
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
446
+ else:
447
+ pred_res = 0
448
+ x_t = x_t_ - alpha_t * B_h * pred_res
449
+ else:
450
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
451
+ if D1s is not None:
452
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
453
+ else:
454
+ pred_res = 0
455
+ x_t = x_t_ - sigma_t * B_h * pred_res
456
+
457
+ x_t = x_t.to(x.dtype)
458
+ return x_t
459
+
460
+ def multistep_uni_c_bh_update(
461
+ self,
462
+ this_model_output: torch.Tensor,
463
+ *args,
464
+ last_sample: torch.Tensor = None,
465
+ this_sample: torch.Tensor = None,
466
+ order: int = None, # pyright: ignore
467
+ **kwargs,
468
+ ) -> torch.Tensor:
469
+ """
470
+ One step for the UniC (B(h) version).
471
+
472
+ Args:
473
+ this_model_output (`torch.Tensor`):
474
+ The model outputs at `x_t`.
475
+ this_timestep (`int`):
476
+ The current timestep `t`.
477
+ last_sample (`torch.Tensor`):
478
+ The generated sample before the last predictor `x_{t-1}`.
479
+ this_sample (`torch.Tensor`):
480
+ The generated sample after the last predictor `x_{t}`.
481
+ order (`int`):
482
+ The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
483
+
484
+ Returns:
485
+ `torch.Tensor`:
486
+ The corrected sample tensor at the current timestep.
487
+ """
488
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
489
+ if last_sample is None:
490
+ if len(args) > 1:
491
+ last_sample = args[1]
492
+ else:
493
+ raise ValueError("missing `last_sample` as a required keyword argument")
494
+ if this_sample is None:
495
+ if len(args) > 2:
496
+ this_sample = args[2]
497
+ else:
498
+ raise ValueError(" missing`this_sample` as a required keyward argument")
499
+ if order is None:
500
+ if len(args) > 3:
501
+ order = args[3]
502
+ else:
503
+ raise ValueError("missing `order` as a required keyword argument")
504
+ if this_timestep is not None:
505
+ deprecate(
506
+ "this_timestep",
507
+ "1.0.0",
508
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
509
+ )
510
+
511
+ model_output_list = self.model_outputs
512
+
513
+ m0 = model_output_list[-1]
514
+ x = last_sample
515
+ x_t = this_sample
516
+ model_t = this_model_output
517
+
518
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
519
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
520
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
521
+
522
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
523
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
524
+
525
+ h = lambda_t - lambda_s0
526
+ device = this_sample.device
527
+
528
+ rks = []
529
+ D1s = []
530
+ for i in range(1, order):
531
+ si = self.step_index - (i + 1) # pyright: ignore
532
+ mi = model_output_list[-(i + 1)]
533
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
534
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
535
+ rk = (lambda_si - lambda_s0) / h
536
+ rks.append(rk)
537
+ D1s.append((mi - m0) / rk) # pyright: ignore
538
+
539
+ rks.append(1.0)
540
+ rks = torch.tensor(rks, device=device)
541
+
542
+ R = []
543
+ b = []
544
+
545
+ hh = -h if self.predict_x0 else h
546
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
547
+ h_phi_k = h_phi_1 / hh - 1
548
+
549
+ factorial_i = 1
550
+
551
+ if self.config.solver_type == "bh1":
552
+ B_h = hh
553
+ elif self.config.solver_type == "bh2":
554
+ B_h = torch.expm1(hh)
555
+ else:
556
+ raise NotImplementedError()
557
+
558
+ for i in range(1, order + 1):
559
+ R.append(torch.pow(rks, i - 1))
560
+ b.append(h_phi_k * factorial_i / B_h)
561
+ factorial_i *= i + 1
562
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
563
+
564
+ R = torch.stack(R)
565
+ b = torch.tensor(b, device=device)
566
+
567
+ if len(D1s) > 0:
568
+ D1s = torch.stack(D1s, dim=1)
569
+ else:
570
+ D1s = None
571
+
572
+ # for order 1, we use a simplified version
573
+ if order == 1:
574
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
575
+ else:
576
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
577
+
578
+ if self.predict_x0:
579
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
580
+ if D1s is not None:
581
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
582
+ else:
583
+ corr_res = 0
584
+ D1_t = model_t - m0
585
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
586
+ else:
587
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
588
+ if D1s is not None:
589
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
590
+ else:
591
+ corr_res = 0
592
+ D1_t = model_t - m0
593
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
594
+ x_t = x_t.to(x.dtype)
595
+ return x_t
596
+
597
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
598
+ if schedule_timesteps is None:
599
+ schedule_timesteps = self.timesteps
600
+
601
+ indices = (schedule_timesteps == timestep).nonzero()
602
+
603
+ # The sigma index that is taken for the **very** first `step`
604
+ # is always the second index (or the last index if there is only 1)
605
+ # This way we can ensure we don't accidentally skip a sigma in
606
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
607
+ pos = 1 if len(indices) > 1 else 0
608
+
609
+ return indices[pos].item()
610
+
611
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
612
+ def _init_step_index(self, timestep):
613
+ """
614
+ Initialize the step_index counter for the scheduler.
615
+ """
616
+
617
+ if self.begin_index is None:
618
+ if isinstance(timestep, torch.Tensor):
619
+ timestep = timestep.to(self.timesteps.device)
620
+ self._step_index = self.index_for_timestep(timestep)
621
+ else:
622
+ self._step_index = self._begin_index
623
+
624
+ def step(
625
+ self,
626
+ model_output: torch.Tensor,
627
+ timestep: Union[int, torch.Tensor],
628
+ sample: torch.Tensor,
629
+ return_dict: bool = True,
630
+ generator=None,
631
+ ) -> Union[SchedulerOutput, Tuple]:
632
+ """
633
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
634
+ the multistep UniPC.
635
+
636
+ Args:
637
+ model_output (`torch.Tensor`):
638
+ The direct output from the learned diffusion model.
639
+ timestep (`int`):
640
+ The current discrete timestep in the diffusion chain.
641
+ sample (`torch.Tensor`):
642
+ A current instance of a sample created by the diffusion process.
643
+ return_dict (`bool`):
644
+ Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
645
+
646
+ Returns:
647
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
648
+ If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
649
+ tuple is returned where the first element is the sample tensor.
650
+
651
+ """
652
+ if self.num_inference_steps is None:
653
+ raise ValueError(
654
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
655
+ )
656
+
657
+ if self.step_index is None:
658
+ self._init_step_index(timestep)
659
+
660
+ use_corrector = (
661
+ self.step_index > 0
662
+ and self.step_index - 1 not in self.disable_corrector
663
+ and self.last_sample is not None # pyright: ignore
664
+ )
665
+
666
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
667
+ if use_corrector:
668
+ sample = self.multistep_uni_c_bh_update(
669
+ this_model_output=model_output_convert,
670
+ last_sample=self.last_sample,
671
+ this_sample=sample,
672
+ order=self.this_order,
673
+ )
674
+
675
+ for i in range(self.config.solver_order - 1):
676
+ self.model_outputs[i] = self.model_outputs[i + 1]
677
+ self.timestep_list[i] = self.timestep_list[i + 1]
678
+
679
+ self.model_outputs[-1] = model_output_convert
680
+ self.timestep_list[-1] = timestep # pyright: ignore
681
+
682
+ if self.config.lower_order_final:
683
+ this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
684
+ else:
685
+ this_order = self.config.solver_order
686
+
687
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
688
+ assert self.this_order > 0
689
+
690
+ self.last_sample = sample
691
+ prev_sample = self.multistep_uni_p_bh_update(
692
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
693
+ sample=sample,
694
+ order=self.this_order,
695
+ )
696
+
697
+ if self.lower_order_nums < self.config.solver_order:
698
+ self.lower_order_nums += 1
699
+
700
+ # upon completion increase step index by one
701
+ self._step_index += 1 # pyright: ignore
702
+
703
+ if not return_dict:
704
+ return (prev_sample,)
705
+
706
+ return SchedulerOutput(prev_sample=prev_sample)
707
+
708
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
709
+ """
710
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
711
+ current timestep.
712
+
713
+ Args:
714
+ sample (`torch.Tensor`):
715
+ The input sample.
716
+
717
+ Returns:
718
+ `torch.Tensor`:
719
+ A scaled input sample.
720
+ """
721
+ return sample
722
+
723
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
724
+ def add_noise(
725
+ self,
726
+ original_samples: torch.Tensor,
727
+ noise: torch.Tensor,
728
+ timesteps: torch.IntTensor,
729
+ ) -> torch.Tensor:
730
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
731
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
732
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
733
+ # mps does not support float64
734
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
735
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
736
+ else:
737
+ schedule_timesteps = self.timesteps.to(original_samples.device)
738
+ timesteps = timesteps.to(original_samples.device)
739
+
740
+ # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
741
+ if self.begin_index is None:
742
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
743
+ elif self.step_index is not None:
744
+ # add_noise is called after first denoising step (for inpainting)
745
+ step_indices = [self.step_index] * timesteps.shape[0]
746
+ else:
747
+ # add noise is called before first denoising step to create initial latent(img2img)
748
+ step_indices = [self.begin_index] * timesteps.shape[0]
749
+
750
+ sigma = sigmas[step_indices].flatten()
751
+ while len(sigma.shape) < len(original_samples.shape):
752
+ sigma = sigma.unsqueeze(-1)
753
+
754
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
755
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
756
+ return noisy_samples
757
+
758
+ def __len__(self):
759
+ return self.config.num_train_timesteps
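A minimal sketch of the scheduler contract that the pipelines above rely on (call set_timesteps once, then one step per timestep). The latent shape and the zero model output are illustrative stand-ins, not values from this repository.

import torch

scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
scheduler.set_timesteps(num_inference_steps=50, device="cpu", shift=5.0)

latent = torch.randn(1, 16, 13, 68, 120)       # illustrative (batch, channels, frames, h, w) latent
for t in scheduler.timesteps:
    model_output = torch.zeros_like(latent)    # stand-in for the transformer's flow prediction
    latent = scheduler.step(model_output, t, latent, return_dict=False)[0]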