Spaces:
Running
Running
Migrated from GitHub
Browse files- .gitattributes +5 -0
- .pre-commit-config.yaml +25 -0
- LICENSE.txt +38 -0
- ORIGINAL_README.md +717 -0
- assets/logo2.png +0 -0
- assets/main_pipeline.jpg +3 -0
- generate_video.py +161 -0
- generate_video_df.py +186 -0
- requirements.txt +16 -0
- skycaptioner_v1/README.md +262 -0
- skycaptioner_v1/examples/data/1.mp4 +3 -0
- skycaptioner_v1/examples/data/2.mp4 +3 -0
- skycaptioner_v1/examples/data/3.mp4 +3 -0
- skycaptioner_v1/examples/data/4.mp4 +3 -0
- skycaptioner_v1/examples/test.csv +5 -0
- skycaptioner_v1/examples/test_result.csv +5 -0
- skycaptioner_v1/infer_fusion_caption.sh +9 -0
- skycaptioner_v1/infer_struct_caption.sh +8 -0
- skycaptioner_v1/requirements.txt +3 -0
- skycaptioner_v1/scripts/utils.py +19 -0
- skycaptioner_v1/scripts/vllm_fusion_caption.py +250 -0
- skycaptioner_v1/scripts/vllm_struct_caption.py +153 -0
- skyreels_v2_infer/__init__.py +1 -0
- skyreels_v2_infer/distributed/__init__.py +0 -0
- skyreels_v2_infer/distributed/xdit_context_parallel.py +286 -0
- skyreels_v2_infer/modules/__init__.py +69 -0
- skyreels_v2_infer/modules/attention.py +179 -0
- skyreels_v2_infer/modules/clip.py +525 -0
- skyreels_v2_infer/modules/t5.py +454 -0
- skyreels_v2_infer/modules/tokenizers.py +78 -0
- skyreels_v2_infer/modules/transformer.py +839 -0
- skyreels_v2_infer/modules/vae.py +639 -0
- skyreels_v2_infer/modules/xlm_roberta.py +165 -0
- skyreels_v2_infer/pipelines/__init__.py +5 -0
- skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py +427 -0
- skyreels_v2_infer/pipelines/image2video_pipeline.py +156 -0
- skyreels_v2_infer/pipelines/prompt_enhancer.py +65 -0
- skyreels_v2_infer/pipelines/text2video_pipeline.py +110 -0
- skyreels_v2_infer/scheduler/__init__.py +0 -0
- skyreels_v2_infer/scheduler/fm_solvers_unipc.py +759 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/main_pipeline.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
skycaptioner_v1/examples/data/1.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
skycaptioner_v1/examples/data/2.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
skycaptioner_v1/examples/data/3.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
skycaptioner_v1/examples/data/4.mp4 filter=lfs diff=lfs merge=lfs -text
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/asottile/reorder-python-imports.git
|
3 |
+
rev: v3.8.3
|
4 |
+
hooks:
|
5 |
+
- id: reorder-python-imports
|
6 |
+
name: Reorder Python imports
|
7 |
+
types: [file, python]
|
8 |
+
- repo: https://github.com/psf/black.git
|
9 |
+
rev: 22.8.0
|
10 |
+
hooks:
|
11 |
+
- id: black
|
12 |
+
additional_dependencies: ['click==8.0.4']
|
13 |
+
args: [--line-length=120]
|
14 |
+
types: [file, python]
|
15 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks.git
|
16 |
+
rev: v4.3.0
|
17 |
+
hooks:
|
18 |
+
- id: check-byte-order-marker
|
19 |
+
types: [file, python]
|
20 |
+
- id: trailing-whitespace
|
21 |
+
types: [file, python]
|
22 |
+
- id: end-of-file-fixer
|
23 |
+
types: [file, python]
|
24 |
+
|
25 |
+
|
LICENSE.txt
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
- zh
|
5 |
+
license: other
|
6 |
+
tasks:
|
7 |
+
- text-generation
|
8 |
+
|
9 |
+
---
|
10 |
+
|
11 |
+
<!-- markdownlint-disable first-line-h1 -->
|
12 |
+
<!-- markdownlint-disable html -->
|
13 |
+
|
14 |
+
# <span id="Terms">声明与协议/Terms and Conditions</span>
|
15 |
+
|
16 |
+
## 声明
|
17 |
+
|
18 |
+
我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
|
19 |
+
|
20 |
+
我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
|
21 |
+
|
22 |
+
We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
|
23 |
+
|
24 |
+
We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
|
25 |
+
|
26 |
+
## 协议
|
27 |
+
|
28 |
+
社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
|
29 |
+
|
30 |
+
|
31 |
+
The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
[《Skywork 模型社区许可协议》》]:https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
|
36 |
+
|
37 |
+
|
38 |
+
[[email protected]]: mailto:[email protected]
|
ORIGINAL_README.md
ADDED
@@ -0,0 +1,717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="assets/logo2.png" alt="SkyReels Logo" width="50%">
|
3 |
+
</p>
|
4 |
+
|
5 |
+
<h1 align="center">SkyReels V2: Infinite-Length Film Generative Model</h1>
|
6 |
+
|
7 |
+
<p align="center">
|
8 |
+
📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9" target="_blank">Hugging Face</a> · 🤖 <a href="https://www.modelscope.cn/collections/SkyReels-V2-f665650130b144" target="_blank">ModelScope</a>
|
9 |
+
</p>
|
10 |
+
|
11 |
+
---
|
12 |
+
Welcome to the **SkyReels V2** repository! Here, you'll find the model weights and inference code for our infinite-length film generative models. To the best of our knowledge, it represents the first open-source video generative model employing **AutoRegressive Diffusion-Forcing architecture** that achieves the **SOTA performance** among publicly available models.
|
13 |
+
|
14 |
+
|
15 |
+
## 🔥🔥🔥 News!!
|
16 |
+
* Apr 24, 2025: 🔥 We release the 720P models, [SkyReels-V2-DF-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P) and [SkyReels-V2-I2V-14B-720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P). The former facilitates infinite-length autoregressive video generation, and the latter focuses on Image2Video synthesis.
|
17 |
+
* Apr 21, 2025: 👋 We release the inference code and model weights of [SkyReels-V2](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) Series Models and the video captioning model [SkyCaptioner-V1](https://huggingface.co/Skywork/SkyCaptioner-V1) .
|
18 |
+
* Apr 3, 2025: 🔥 We also release [SkyReels-A2](https://github.com/SkyworkAI/SkyReels-A2). This is an open-sourced controllable video generation framework capable of assembling arbitrary visual elements.
|
19 |
+
* Feb 18, 2025: 🔥 we released [SkyReels-A1](https://github.com/SkyworkAI/SkyReels-A1). This is an open-sourced and effective framework for portrait image animation.
|
20 |
+
* Feb 18, 2025: 🔥 We released [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
|
21 |
+
|
22 |
+
## 🎥 Demos
|
23 |
+
<table>
|
24 |
+
<tr>
|
25 |
+
<td align="center">
|
26 |
+
<video src="https://github.com/user-attachments/assets/f6f9f9a7-5d5f-433c-9d73-d8d593b7ad25" width="100%"></video>
|
27 |
+
</td>
|
28 |
+
<td align="center">
|
29 |
+
<video src="https://github.com/user-attachments/assets/0eb13415-f4d9-4aaf-bcd3-3031851109b9" width="100%"></video>
|
30 |
+
</td>
|
31 |
+
<td align="center">
|
32 |
+
<video src="https://github.com/user-attachments/assets/dcd16603-5bf4-4786-8e4d-1ed23889d07a" width="100%"></video>
|
33 |
+
</td>
|
34 |
+
</tr>
|
35 |
+
</table>
|
36 |
+
The demos above showcase 30-second videos generated using our SkyReels-V2 Diffusion Forcing model.
|
37 |
+
|
38 |
+
|
39 |
+
## 📑 TODO List
|
40 |
+
|
41 |
+
- [x] <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a>
|
42 |
+
- [x] Checkpoints of the 14B and 1.3B Models Series
|
43 |
+
- [x] Single-GPU & Multi-GPU Inference Code
|
44 |
+
- [x] <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a>: A Video Captioning Model
|
45 |
+
- [x] Prompt Enhancer
|
46 |
+
- [ ] Diffusers integration
|
47 |
+
- [ ] Checkpoints of the 5B Models Series
|
48 |
+
- [ ] Checkpoints of the Camera Director Models
|
49 |
+
- [ ] Checkpoints of the Step & Guidance Distill Model
|
50 |
+
|
51 |
+
|
52 |
+
## 🚀 Quickstart
|
53 |
+
|
54 |
+
#### Installation
|
55 |
+
```shell
|
56 |
+
# clone the repository.
|
57 |
+
git clone https://github.com/SkyworkAI/SkyReels-V2
|
58 |
+
cd SkyReels-V2
|
59 |
+
# Install dependencies. Test environment uses Python 3.10.12.
|
60 |
+
pip install -r requirements.txt
|
61 |
+
```
|
62 |
+
|
63 |
+
#### Model Download
|
64 |
+
You can download our models from Hugging Face:
|
65 |
+
<table>
|
66 |
+
<thead>
|
67 |
+
<tr>
|
68 |
+
<th>Type</th>
|
69 |
+
<th>Model Variant</th>
|
70 |
+
<th>Recommended Height/Width/Frame</th>
|
71 |
+
<th>Link</th>
|
72 |
+
</tr>
|
73 |
+
</thead>
|
74 |
+
<tbody>
|
75 |
+
<tr>
|
76 |
+
<td rowspan="5">Diffusion Forcing</td>
|
77 |
+
<td>1.3B-540P</td>
|
78 |
+
<td>544 * 960 * 97f</td>
|
79 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-1.3B-540P">ModelScope</a></td>
|
80 |
+
</tr>
|
81 |
+
<tr>
|
82 |
+
<td>5B-540P</td>
|
83 |
+
<td>544 * 960 * 97f</td>
|
84 |
+
<td>Coming Soon</td>
|
85 |
+
</tr>
|
86 |
+
<tr>
|
87 |
+
<td>5B-720P</td>
|
88 |
+
<td>720 * 1280 * 121f</td>
|
89 |
+
<td>Coming Soon</td>
|
90 |
+
</tr>
|
91 |
+
<tr>
|
92 |
+
<td>14B-540P</td>
|
93 |
+
<td>544 * 960 * 97f</td>
|
94 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-14B-540P">ModelScope</a></td>
|
95 |
+
</tr>
|
96 |
+
<tr>
|
97 |
+
<td>14B-720P</td>
|
98 |
+
<td>720 * 1280 * 121f</td>
|
99 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-DF-14B-720P">ModelScope</a></td>
|
100 |
+
</tr>
|
101 |
+
<tr>
|
102 |
+
<td rowspan="5">Text-to-Video</td>
|
103 |
+
<td>1.3B-540P</td>
|
104 |
+
<td>544 * 960 * 97f</td>
|
105 |
+
<td>Coming Soon</td>
|
106 |
+
</tr>
|
107 |
+
<tr>
|
108 |
+
<td>5B-540P</td>
|
109 |
+
<td>544 * 960 * 97f</td>
|
110 |
+
<td>Coming Soon</td>
|
111 |
+
</tr>
|
112 |
+
<tr>
|
113 |
+
<td>5B-720P</td>
|
114 |
+
<td>720 * 1280 * 121f</td>
|
115 |
+
<td>Coming Soon</td>
|
116 |
+
</tr>
|
117 |
+
<tr>
|
118 |
+
<td>14B-540P</td>
|
119 |
+
<td>544 * 960 * 97f</td>
|
120 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-T2V-14B-540P">ModelScope</a></td>
|
121 |
+
</tr>
|
122 |
+
<tr>
|
123 |
+
<td>14B-720P</td>
|
124 |
+
<td>720 * 1280 * 121f</td>
|
125 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-T2V-14B-720P">ModelScope</a></td>
|
126 |
+
</tr>
|
127 |
+
<tr>
|
128 |
+
<td rowspan="5">Image-to-Video</td>
|
129 |
+
<td>1.3B-540P</td>
|
130 |
+
<td>544 * 960 * 97f</td>
|
131 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-1.3B-540P">ModelScope</a></td>
|
132 |
+
</tr>
|
133 |
+
<tr>
|
134 |
+
<td>5B-540P</td>
|
135 |
+
<td>544 * 960 * 97f</td>
|
136 |
+
<td>Coming Soon</td>
|
137 |
+
</tr>
|
138 |
+
<tr>
|
139 |
+
<td>5B-720P</td>
|
140 |
+
<td>720 * 1280 * 121f</td>
|
141 |
+
<td>Coming Soon</td>
|
142 |
+
</tr>
|
143 |
+
<tr>
|
144 |
+
<td>14B-540P</td>
|
145 |
+
<td>544 * 960 * 97f</td>
|
146 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-14B-540P">ModelScope</a></td>
|
147 |
+
</tr>
|
148 |
+
<tr>
|
149 |
+
<td>14B-720P</td>
|
150 |
+
<td>720 * 1280 * 121f</td>
|
151 |
+
<td>🤗 <a href="https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P">Huggingface</a> 🤖 <a href="https://www.modelscope.cn/models/Skywork/SkyReels-V2-I2V-14B-720P">ModelScope</a></td>
|
152 |
+
</tr>
|
153 |
+
<tr>
|
154 |
+
<td rowspan="3">Camera Director</td>
|
155 |
+
<td>5B-540P</td>
|
156 |
+
<td>544 * 960 * 97f</td>
|
157 |
+
<td>Coming Soon</td>
|
158 |
+
</tr>
|
159 |
+
<tr>
|
160 |
+
<td>5B-720P</td>
|
161 |
+
<td>720 * 1280 * 121f</td>
|
162 |
+
<td>Coming Soon</td>
|
163 |
+
</tr>
|
164 |
+
<tr>
|
165 |
+
<td>14B-720P</td>
|
166 |
+
<td>720 * 1280 * 121f</td>
|
167 |
+
<td>Coming Soon</td>
|
168 |
+
</tr>
|
169 |
+
</tbody>
|
170 |
+
</table>
|
171 |
+
|
172 |
+
After downloading, set the model path in your generation commands:
|
173 |
+
|
174 |
+
|
175 |
+
#### Single GPU Inference
|
176 |
+
|
177 |
+
- **Diffusion Forcing for Long Video Generation**
|
178 |
+
|
179 |
+
The <a href="https://arxiv.org/abs/2407.01392">**Diffusion Forcing**</a> version model allows us to generate Infinite-Length videos. This model supports both **text-to-video (T2V)** and **image-to-video (I2V)** tasks, and it can perform inference in both synchronous and asynchronous modes. Here we demonstrate 2 running scripts as examples for long video generation. If you want to adjust the inference parameters, e.g., the duration of video, inference mode, read the Note below first.
|
180 |
+
|
181 |
+
synchronous generation for 10s video
|
182 |
+
```shell
|
183 |
+
model_id=Skywork/SkyReels-V2-DF-14B-540P
|
184 |
+
# synchronous inference
|
185 |
+
python3 generate_video_df.py \
|
186 |
+
--model_id ${model_id} \
|
187 |
+
--resolution 540P \
|
188 |
+
--ar_step 0 \
|
189 |
+
--base_num_frames 97 \
|
190 |
+
--num_frames 257 \
|
191 |
+
--overlap_history 17 \
|
192 |
+
--prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
|
193 |
+
--addnoise_condition 20 \
|
194 |
+
--offload \
|
195 |
+
--teacache \
|
196 |
+
--use_ret_steps \
|
197 |
+
--teacache_thresh 0.3
|
198 |
+
```
|
199 |
+
|
200 |
+
asynchronous generation for 30s video
|
201 |
+
```shell
|
202 |
+
model_id=Skywork/SkyReels-V2-DF-14B-540P
|
203 |
+
# asynchronous inference
|
204 |
+
python3 generate_video_df.py \
|
205 |
+
--model_id ${model_id} \
|
206 |
+
--resolution 540P \
|
207 |
+
--ar_step 5 \
|
208 |
+
--causal_block_size 5 \
|
209 |
+
--base_num_frames 97 \
|
210 |
+
--num_frames 737 \
|
211 |
+
--overlap_history 17 \
|
212 |
+
--prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
|
213 |
+
--addnoise_condition 20 \
|
214 |
+
--offload
|
215 |
+
```
|
216 |
+
|
217 |
+
> **Note**:
|
218 |
+
> - If you want to run the **image-to-video (I2V)** task, add `--image ${image_path}` to your command and it is also better to use **text-to-video (T2V)**-like prompt which includes some descriptions of the first-frame image.
|
219 |
+
> - For long video generation, you can just switch the `--num_frames`, e.g., `--num_frames 257` for 10s video, `--num_frames 377` for 15s video, `--num_frames 737` for 30s video, `--num_frames 1457` for 60s video. The number is not strictly aligned with the logical frame number for specified time duration, but it is aligned with some training parameters, which means it may perform better. When you use asynchronous inference with causal_block_size > 1, the `--num_frames` should be carefully set.
|
220 |
+
> - You can use `--ar_step 5` to enable asynchronous inference. When asynchronous inference, `--causal_block_size 5` is recommended while it is not supposed to be set for synchronous generation. REMEMBER that the frame latent number inputted into the model in every iteration, e.g., base frame latent number (e.g., (97-1)//4+1=25 for base_num_frames=97) and (e.g., (237-97-(97-17)x1+17-1)//4+1=20 for base_num_frames=97, num_frames=237, overlap_history=17) for the last iteration, MUST be divided by causal_block_size. If you find it too hard to calculate and set proper values, just use our recommended setting above :). Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
|
221 |
+
> - To reduce peak VRAM, just lower the `--base_num_frames`, e.g., to 77 or 57, while keeping the same generative length `--num_frames` you want to generate. This may slightly reduce video quality, and it should not be set too small.
|
222 |
+
> - `--addnoise_condition` is used to help smooth the long video generation by adding some noise to the clean condition. Too large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger ones, but it is recommended to not exceed 50.
|
223 |
+
> - Generating a 540P video using the 1.3B model requires approximately 14.7GB peak VRAM, while the same resolution video using the 14B model demands around 51.2GB peak VRAM.
|
224 |
+
|
225 |
+
- **Text To Video & Image To Video**
|
226 |
+
|
227 |
+
```shell
|
228 |
+
# run Text-to-Video Generation
|
229 |
+
model_id=Skywork/SkyReels-V2-T2V-14B-540P
|
230 |
+
python3 generate_video.py \
|
231 |
+
--model_id ${model_id} \
|
232 |
+
--resolution 540P \
|
233 |
+
--num_frames 97 \
|
234 |
+
--guidance_scale 6.0 \
|
235 |
+
--shift 8.0 \
|
236 |
+
--fps 24 \
|
237 |
+
--prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
|
238 |
+
--offload \
|
239 |
+
--teacache \
|
240 |
+
--use_ret_steps \
|
241 |
+
--teacache_thresh 0.3
|
242 |
+
```
|
243 |
+
> **Note**:
|
244 |
+
> - When using an **image-to-video (I2V)** model, you must provide an input image using the `--image ${image_path}` parameter. The `--guidance_scale 5.0` and `--shift 3.0` is recommended for I2V model.
|
245 |
+
> - Generating a 540P video using the 1.3B model requires approximately 14.7GB peak VRAM, while the same resolution video using the 14B model demands around 43.4GB peak VRAM.
|
246 |
+
|
247 |
+
|
248 |
+
- **Prompt Enhancer**
|
249 |
+
|
250 |
+
The prompt enhancer is implemented based on <a href="https://huggingface.co/Qwen/Qwen2.5-32B-Instruct">Qwen2.5-32B-Instruct</a> and is utilized via the `--prompt_enhancer` parameter. It works ideally for short prompts, while for long prompts, it might generate an excessively lengthy prompt that could lead to over-saturation in the generative video. Note the peak memory of GPU is 64G+ if you use `--prompt_enhancer`. If you want to obtain the enhanced prompt separately, you can also run the prompt_enhancer script separately for testing. The steps are as follows:
|
251 |
+
|
252 |
+
```shell
|
253 |
+
cd skyreels_v2_infer/pipelines
|
254 |
+
python3 prompt_enhancer.py --prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface."
|
255 |
+
```
|
256 |
+
> **Note**:
|
257 |
+
> - `--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter.
|
258 |
+
|
259 |
+
|
260 |
+
**Advanced Configuration Options**
|
261 |
+
|
262 |
+
Below are the key parameters you can customize for video generation:
|
263 |
+
|
264 |
+
| Parameter | Recommended Value | Description |
|
265 |
+
|:----------------------:|:---------:|:-----------------------------------------:|
|
266 |
+
| --prompt | | Text description for generating your video |
|
267 |
+
| --image | | Path to input image for image-to-video generation |
|
268 |
+
| --resolution | 540P or 720P | Output video resolution (select based on model type) |
|
269 |
+
| --num_frames | 97 or 121 | Total frames to generate (**97 for 540P models**, **121 for 720P models**) |
|
270 |
+
| --inference_steps | 50 | Number of denoising steps |
|
271 |
+
| --fps | 24 | Frames per second in the output video |
|
272 |
+
| --shift | 8.0 or 5.0 | Flow matching scheduler parameter (**8.0 for T2V**, **5.0 for I2V**) |
|
273 |
+
| --guidance_scale | 6.0 or 5.0 | Controls text adherence strength (**6.0 for T2V**, **5.0 for I2V**) |
|
274 |
+
| --seed | | Fixed seed for reproducible results (omit for random generation) |
|
275 |
+
| --offload | True | Offloads model components to CPU to reduce VRAM usage (recommended) |
|
276 |
+
| --use_usp | True | Enables multi-GPU acceleration with xDiT USP |
|
277 |
+
| --outdir | ./video_out | Directory where generated videos will be saved |
|
278 |
+
| --prompt_enhancer | True | Expand the prompt into a more detailed description |
|
279 |
+
| --teacache | False | Enables teacache for faster inference |
|
280 |
+
| --teacache_thresh | 0.2 | Higher speedup will cause to worse quality |
|
281 |
+
| --use_ret_steps | False | Retention Steps for teacache |
|
282 |
+
|
283 |
+
**Diffusion Forcing Additional Parameters**
|
284 |
+
| Parameter | Recommended Value | Description |
|
285 |
+
|:----------------------:|:---------:|:-----------------------------------------:|
|
286 |
+
| --ar_step | 0 | Controls asynchronous inference (0 for synchronous mode) |
|
287 |
+
| --base_num_frames | 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) |
|
288 |
+
| --overlap_history | 17 | Number of frames to overlap for smooth transitions in long videos |
|
289 |
+
| --addnoise_condition | 20 | Improves consistency in long video generation |
|
290 |
+
| --causal_block_size | 5 | Recommended when using asynchronous inference (--ar_step > 0) |
|
291 |
+
|
292 |
+
#### Multi-GPU inference using xDiT USP
|
293 |
+
|
294 |
+
We use [xDiT](https://github.com/xdit-project/xDiT) USP to accelerate inference. For example, to generate a video with 2 GPUs, you can use the following command:
|
295 |
+
- **Diffusion Forcing**
|
296 |
+
```shell
|
297 |
+
model_id=Skywork/SkyReels-V2-DF-14B-540P
|
298 |
+
# diffusion forcing synchronous inference
|
299 |
+
torchrun --nproc_per_node=2 generate_video_df.py \
|
300 |
+
--model_id ${model_id} \
|
301 |
+
--resolution 540P \
|
302 |
+
--ar_step 0 \
|
303 |
+
--base_num_frames 97 \
|
304 |
+
--num_frames 257 \
|
305 |
+
--overlap_history 17 \
|
306 |
+
--prompt "A graceful white swan with a curved neck and delicate feathers swimming in a serene lake at dawn, its reflection perfectly mirrored in the still water as mist rises from the surface, with the swan occasionally dipping its head into the water to feed." \
|
307 |
+
--addnoise_condition 20 \
|
308 |
+
--use_usp \
|
309 |
+
--offload \
|
310 |
+
--seed 42
|
311 |
+
```
|
312 |
+
- **Text To Video & Image To Video**
|
313 |
+
```shell
|
314 |
+
# run Text-to-Video Generation
|
315 |
+
model_id=Skywork/SkyReels-V2-T2V-14B-540P
|
316 |
+
torchrun --nproc_per_node=2 generate_video.py \
|
317 |
+
--model_id ${model_id} \
|
318 |
+
--resolution 540P \
|
319 |
+
--num_frames 97 \
|
320 |
+
--guidance_scale 6.0 \
|
321 |
+
--shift 8.0 \
|
322 |
+
--fps 24 \
|
323 |
+
--offload \
|
324 |
+
--prompt "A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface." \
|
325 |
+
--use_usp \
|
326 |
+
--seed 42
|
327 |
+
```
|
328 |
+
> **Note**:
|
329 |
+
> - When using an **image-to-video (I2V)** model, you must provide an input image using the `--image ${image_path}` parameter. The `--guidance_scale 5.0` and `--shift 3.0` is recommended for I2V model.
|
330 |
+
|
331 |
+
|
332 |
+
## Contents
|
333 |
+
- [Abstract](#abstract)
|
334 |
+
- [Methodology of SkyReels-V2](#methodology-of-skyreels-v2)
|
335 |
+
- [Key Contributions of SkyReels-V2](#key-contributions-of-skyreels-v2)
|
336 |
+
- [Video Captioner](#video-captioner)
|
337 |
+
- [Reinforcement Learning](#reinforcement-learning)
|
338 |
+
- [Diffusion Forcing](#diffusion-forcing)
|
339 |
+
- [High-Quality Supervised Fine-Tuning(SFT)](#high-quality-supervised-fine-tuning-sft)
|
340 |
+
- [Performance](#performance)
|
341 |
+
- [Acknowledgements](#acknowledgements)
|
342 |
+
- [Citation](#citation)
|
343 |
+
---
|
344 |
+
|
345 |
+
## Abstract
|
346 |
+
Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation.
|
347 |
+
|
348 |
+
To address these limitations, we introduce SkyReels-V2, the world's first infinite-length film generative model using a Diffusion Forcing framework. Our approach synergizes Multi-modal Large Language Models (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing techniques to achieve comprehensive optimization. Beyond its technical innovations, SkyReels-V2 enables multiple practical applications, including Story Generation, Image-to-Video Synthesis, Camera Director functionality, and multi-subject consistent video generation through our <a href="https://github.com/SkyworkAI/SkyReels-A2">Skyreels-A2</a> system.
|
349 |
+
|
350 |
+
## Methodology of SkyReels-V2
|
351 |
+
|
352 |
+
The SkyReels-V2 methodology consists of several interconnected components. It starts with a comprehensive data processing pipeline that prepares various quality training data. At its core is the Video Captioner architecture, which provides detailed annotations for video content. The system employs a multi-task pretraining strategy to build fundamental video generation capabilities. Post-training optimization includes Reinforcement Learning to enhance motion quality, Diffusion Forcing Training for generating extended videos, and High-quality Supervised Fine-Tuning (SFT) stages for visual refinement. The model runs on optimized computational infrastructure for efficient training and inference. SkyReels-V2 supports multiple applications, including Story Generation, Image-to-Video Synthesis, Camera Director functionality, and Elements-to-Video Generation.
|
353 |
+
|
354 |
+
<p align="center">
|
355 |
+
<img src="assets/main_pipeline.jpg" alt="mainpipeline" width="100%">
|
356 |
+
</p>
|
357 |
+
|
358 |
+
## Key Contributions of SkyReels-V2
|
359 |
+
|
360 |
+
#### Video Captioner
|
361 |
+
|
362 |
+
<a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> serves as our video captioning model for data annotation. This model is trained on the captioning result from the base model <a href="https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct">Qwen2.5-VL-72B-Instruct</a> and the sub-expert captioners on a balanced video data. The balanced video data is a carefully curated dataset of approximately 2 million videos to ensure conceptual balance and annotation quality. Built upon the <a href="https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct">Qwen2.5-VL-7B-Instruct</a> foundation model, <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> is fine-tuned to enhance performance in domain-specific video captioning tasks. To compare the performance with the SOTA models, we conducted a manual assessment of accuracy across different captioning fields using a test set of 1,000 samples. The proposed <a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</a> achieves the highest average accuracy among the baseline models, and show a dramatic result in the shot related fields
|
363 |
+
|
364 |
+
<p align="center">
|
365 |
+
<table align="center">
|
366 |
+
<thead>
|
367 |
+
<tr>
|
368 |
+
<th>model</th>
|
369 |
+
<th><a href="https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct">Qwen2.5-VL-7B-Ins.</a></th>
|
370 |
+
<th><a href="https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct">Qwen2.5-VL-72B-Ins.</a></th>
|
371 |
+
<th><a href="https://huggingface.co/omni-research/Tarsier2-Recap-7b">Tarsier2-Recap-7b</a></th>
|
372 |
+
<th><a href="https://huggingface.co/Skywork/SkyCaptioner-V1">SkyCaptioner-V1</th>
|
373 |
+
</tr>
|
374 |
+
</thead>
|
375 |
+
<tbody>
|
376 |
+
<tr>
|
377 |
+
<td>Avg accuracy</td>
|
378 |
+
<td>51.4%</td>
|
379 |
+
<td>58.7%</td>
|
380 |
+
<td>49.4%</td>
|
381 |
+
<td><strong>76.3%</strong></td>
|
382 |
+
</tr>
|
383 |
+
<tr>
|
384 |
+
<td>shot type</td>
|
385 |
+
<td>76.8%</td>
|
386 |
+
<td>82.5%</td>
|
387 |
+
<td>60.2%</td>
|
388 |
+
<td><strong>93.7%</strong></td>
|
389 |
+
</tr>
|
390 |
+
<tr>
|
391 |
+
<td>shot angle</td>
|
392 |
+
<td>60.0%</td>
|
393 |
+
<td>73.7%</td>
|
394 |
+
<td>52.4%</td>
|
395 |
+
<td><strong>89.8%</strong></td>
|
396 |
+
</tr>
|
397 |
+
<tr>
|
398 |
+
<td>shot position</td>
|
399 |
+
<td>28.4%</td>
|
400 |
+
<td>32.7%</td>
|
401 |
+
<td>23.6%</td>
|
402 |
+
<td><strong>83.1%</strong></td>
|
403 |
+
</tr>
|
404 |
+
<tr>
|
405 |
+
<td>camera motion</td>
|
406 |
+
<td>62.0%</td>
|
407 |
+
<td>61.2%</td>
|
408 |
+
<td>45.3%</td>
|
409 |
+
<td><strong>85.3%</strong></td>
|
410 |
+
</tr>
|
411 |
+
<tr>
|
412 |
+
<td>expression</td>
|
413 |
+
<td>43.6%</td>
|
414 |
+
<td>51.5%</td>
|
415 |
+
<td>54.3%</td>
|
416 |
+
<td><strong>68.8%</strong></td>
|
417 |
+
</tr>
|
418 |
+
<tr>
|
419 |
+
<td colspan="5" style="text-align: center; border-bottom: 1px solid #ddd; padding: 8px;"></td>
|
420 |
+
</tr>
|
421 |
+
<tr>
|
422 |
+
<td>TYPES_type</td>
|
423 |
+
<td>43.5%</td>
|
424 |
+
<td>49.7%</td>
|
425 |
+
<td>47.6%</td>
|
426 |
+
<td><strong>82.5%</strong></td>
|
427 |
+
</tr>
|
428 |
+
<tr>
|
429 |
+
<td>TYPES_sub_type</td>
|
430 |
+
<td>38.9%</td>
|
431 |
+
<td>44.9%</td>
|
432 |
+
<td>45.9%</td>
|
433 |
+
<td><strong>75.4%</strong></td>
|
434 |
+
</tr>
|
435 |
+
<tr>
|
436 |
+
<td>appearance</td>
|
437 |
+
<td>40.9%</td>
|
438 |
+
<td>52.0%</td>
|
439 |
+
<td>45.6%</td>
|
440 |
+
<td><strong>59.3%</strong></td>
|
441 |
+
</tr>
|
442 |
+
<tr>
|
443 |
+
<td>action</td>
|
444 |
+
<td>32.4%</td>
|
445 |
+
<td>52.0%</td>
|
446 |
+
<td><strong>69.8%</strong></td>
|
447 |
+
<td>68.8%</td>
|
448 |
+
</tr>
|
449 |
+
<tr>
|
450 |
+
<td>position</td>
|
451 |
+
<td>35.4%</td>
|
452 |
+
<td>48.6%</td>
|
453 |
+
<td>45.5%</td>
|
454 |
+
<td><strong>57.5%</strong></td>
|
455 |
+
</tr>
|
456 |
+
<tr>
|
457 |
+
<td>is_main_subject</td>
|
458 |
+
<td>58.5%</td>
|
459 |
+
<td>68.7%</td>
|
460 |
+
<td>69.7%</td>
|
461 |
+
<td><strong>80.9%</strong></td>
|
462 |
+
</tr>
|
463 |
+
<tr>
|
464 |
+
<td>environment</td>
|
465 |
+
<td>70.4%</td>
|
466 |
+
<td><strong>72.7%</strong></td>
|
467 |
+
<td>61.4%</td>
|
468 |
+
<td>70.5%</td>
|
469 |
+
</tr>
|
470 |
+
<tr>
|
471 |
+
<td>lighting</td>
|
472 |
+
<td>77.1%</td>
|
473 |
+
<td><strong>80.0%</strong></td>
|
474 |
+
<td>21.2%</td>
|
475 |
+
<td>76.5%</td>
|
476 |
+
</tr>
|
477 |
+
</tbody>
|
478 |
+
</table>
|
479 |
+
</p>
|
480 |
+
|
481 |
+
#### Reinforcement Learning
|
482 |
+
Inspired by the previous success in LLM, we propose to enhance the performance of the generative model by Reinforcement Learning. Specifically, we focus on the motion quality because we find that the main drawback of our generative model is:
|
483 |
+
|
484 |
+
- the generative model does not handle well with large, deformable motions.
|
485 |
+
- the generated videos may violate the physical law.
|
486 |
+
|
487 |
+
To avoid the degradation in other metrics, such as text alignment and video quality, we ensure the preference data pairs have comparable text alignment and video quality, while only the motion quality varies. This requirement poses greater challenges in obtaining preference annotations due to the inherently higher costs of human annotation. To address this challenge, we propose a semi-automatic pipeline that strategically combines automatically generated motion pairs and human annotation results. This hybrid approach not only enhances the data scale but also improves alignment with human preferences through curated quality control. Leveraging this enhanced dataset, we first train a specialized reward model to capture the generic motion quality differences between paired samples. This learned reward function subsequently guides the sample selection process for Direct Preference Optimization (DPO), enhancing the motion quality of the generative model.
|
488 |
+
|
489 |
+
#### Diffusion Forcing
|
490 |
+
|
491 |
+
We introduce the Diffusion Forcing Transformer to unlock our model’s ability to generate long videos. Diffusion Forcing is a training and sampling strategy where each token is assigned an independent noise level. This allows tokens to be denoised according to arbitrary, per-token schedules. Conceptually, this approach functions as a form of partial masking: a token with zero noise is fully unmasked, while complete noise fully masks it. Diffusion Forcing trains the model to "unmask" any combination of variably noised tokens, using the cleaner tokens as conditional information to guide the recovery of noisy ones. Building on this, our Diffusion Forcing Transformer can extend video generation indefinitely based on the last frames of the previous segment. Note that the synchronous full sequence diffusion is a special case of Diffusion Forcing, where all tokens share the same noise level. This relationship allows us to fine-tune the Diffusion Forcing Transformer from a full-sequence diffusion model.
|
492 |
+
|
493 |
+
#### High-Quality Supervised Fine-Tuning (SFT)
|
494 |
+
|
495 |
+
We implement two sequential high-quality supervised fine-tuning (SFT) stages at 540p and 720p resolutions respectively, with the initial SFT phase conducted immediately after pretraining but prior to reinforcement learning (RL) stage.This first-stage SFT serves as a conceptual equilibrium trainer, building upon the foundation model’s pretraining outcomes that utilized only fps24 video data, while strategically removing FPS embedding components to streamline thearchitecture. Trained with the high-quality concept-balanced samples, this phase establishes optimized initialization parameters for subsequent training processes. Following this, we execute a secondary high-resolution SFT at 720p after completing the diffusion forcing stage, incorporating identical loss formulations and the higher-quality concept-balanced datasets by the manually filter. This final refinement phase focuses on resolution increase such that the overall video quality will be further enhanced.
|
496 |
+
|
497 |
+
## Performance
|
498 |
+
|
499 |
+
To comprehensively evaluate our proposed method, we construct the SkyReels-Bench for human assessment and leveraged the open-source <a href="https://github.com/Vchitect/VBench">V-Bench</a> for automated evaluation. This allows us to compare our model with the state-of-the-art (SOTA) baselines, including both open-source and proprietary models.
|
500 |
+
|
501 |
+
#### Human Evaluation
|
502 |
+
|
503 |
+
For human evaluation, we design SkyReels-Bench with 1,020 text prompts, systematically assessing three dimensions: Instruction Adherence, Motion Quality, Consistency and Visual Quality. This benchmark is designed to evaluate both text-to-video (T2V) and image-to-video (I2V) generation models, providing comprehensive assessment across different generation paradigms. To ensure fairness, all models were evaluated under default settings with consistent resolutions, and no post-generation filtering was applied.
|
504 |
+
|
505 |
+
- Text To Video Models
|
506 |
+
|
507 |
+
<p align="center">
|
508 |
+
<table align="center">
|
509 |
+
<thead>
|
510 |
+
<tr>
|
511 |
+
<th>Model Name</th>
|
512 |
+
<th>Average</th>
|
513 |
+
<th>Instruction Adherence</th>
|
514 |
+
<th>Consistency</th>
|
515 |
+
<th>Visual Quality</th>
|
516 |
+
<th>Motion Quality</th>
|
517 |
+
</tr>
|
518 |
+
</thead>
|
519 |
+
<tbody>
|
520 |
+
<tr>
|
521 |
+
<td><a href="https://runwayml.com/research/introducing-gen-3-alpha">Runway-Gen3 Alpha</a></td>
|
522 |
+
<td>2.53</td>
|
523 |
+
<td>2.19</td>
|
524 |
+
<td>2.57</td>
|
525 |
+
<td>3.23</td>
|
526 |
+
<td>2.11</td>
|
527 |
+
</tr>
|
528 |
+
<tr>
|
529 |
+
<td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
|
530 |
+
<td>2.82</td>
|
531 |
+
<td>2.64</td>
|
532 |
+
<td>2.81</td>
|
533 |
+
<td>3.20</td>
|
534 |
+
<td>2.61</td>
|
535 |
+
</tr>
|
536 |
+
<tr>
|
537 |
+
<td><a href="https://klingai.com">Kling-1.6 STD Mode</a></td>
|
538 |
+
<td>2.99</td>
|
539 |
+
<td>2.77</td>
|
540 |
+
<td>3.05</td>
|
541 |
+
<td>3.39</td>
|
542 |
+
<td><strong>2.76</strong></td>
|
543 |
+
</tr>
|
544 |
+
<tr>
|
545 |
+
<td><a href="https://hailuoai.video">Hailuo-01</a></td>
|
546 |
+
<td>3.0</td>
|
547 |
+
<td>2.8</td>
|
548 |
+
<td>3.08</td>
|
549 |
+
<td>3.29</td>
|
550 |
+
<td>2.74</td>
|
551 |
+
</tr>
|
552 |
+
<tr>
|
553 |
+
<td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
|
554 |
+
<td>3.12</td>
|
555 |
+
<td>2.91</td>
|
556 |
+
<td>3.31</td>
|
557 |
+
<td><strong>3.54</strong></td>
|
558 |
+
<td>2.71</td>
|
559 |
+
</tr>
|
560 |
+
<tr>
|
561 |
+
<td>SkyReels-V2</td>
|
562 |
+
<td><strong>3.14</strong></td>
|
563 |
+
<td><strong>3.15</strong></td>
|
564 |
+
<td><strong>3.35</strong></td>
|
565 |
+
<td>3.34</td>
|
566 |
+
<td>2.74</td>
|
567 |
+
</tr>
|
568 |
+
</tbody>
|
569 |
+
</table>
|
570 |
+
</p>
|
571 |
+
|
572 |
+
The evaluation demonstrates that our model achieves significant advancements in **instruction adherence (3.15)** compared to baseline methods, while maintaining competitive performance in **motion quality (2.74)** without sacrificing the **consistency (3.35)**.
|
573 |
+
|
574 |
+
- Image To Video Models
|
575 |
+
|
576 |
+
<p align="center">
|
577 |
+
<table align="center">
|
578 |
+
<thead>
|
579 |
+
<tr>
|
580 |
+
<th>Model</th>
|
581 |
+
<th>Average</th>
|
582 |
+
<th>Instruction Adherence</th>
|
583 |
+
<th>Consistency</th>
|
584 |
+
<th>Visual Quality</th>
|
585 |
+
<th>Motion Quality</th>
|
586 |
+
</tr>
|
587 |
+
</thead>
|
588 |
+
<tbody>
|
589 |
+
<tr>
|
590 |
+
<td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
|
591 |
+
<td>2.84</td>
|
592 |
+
<td>2.97</td>
|
593 |
+
<td>2.95</td>
|
594 |
+
<td>2.87</td>
|
595 |
+
<td>2.56</td>
|
596 |
+
</tr>
|
597 |
+
<tr>
|
598 |
+
<td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
|
599 |
+
<td>2.85</td>
|
600 |
+
<td>3.10</td>
|
601 |
+
<td>2.81</td>
|
602 |
+
<td>3.00</td>
|
603 |
+
<td>2.48</td>
|
604 |
+
</tr>
|
605 |
+
<tr>
|
606 |
+
<td><a href="https://hailuoai.video">Hailuo-01</a></td>
|
607 |
+
<td>3.05</td>
|
608 |
+
<td>3.31</td>
|
609 |
+
<td>2.58</td>
|
610 |
+
<td>3.55</td>
|
611 |
+
<td>2.74</td>
|
612 |
+
</tr>
|
613 |
+
<tr>
|
614 |
+
<td><a href="https://klingai.com">Kling-1.6 Pro Mode</a></td>
|
615 |
+
<td>3.4</td>
|
616 |
+
<td>3.56</td>
|
617 |
+
<td>3.03</td>
|
618 |
+
<td>3.58</td>
|
619 |
+
<td>3.41</td>
|
620 |
+
</tr>
|
621 |
+
<tr>
|
622 |
+
<td><a href="https://runwayml.com/research/introducing-runway-gen-4">Runway-Gen4</a></td>
|
623 |
+
<td>3.39</td>
|
624 |
+
<td>3.75</td>
|
625 |
+
<td>3.2</td>
|
626 |
+
<td>3.4</td>
|
627 |
+
<td>3.37</td>
|
628 |
+
</tr>
|
629 |
+
<tr>
|
630 |
+
<td>SkyReels-V2-DF</td>
|
631 |
+
<td>3.24</td>
|
632 |
+
<td>3.64</td>
|
633 |
+
<td>3.21</td>
|
634 |
+
<td>3.18</td>
|
635 |
+
<td>2.93</td>
|
636 |
+
</tr>
|
637 |
+
<tr>
|
638 |
+
<td>SkyReels-V2-I2V</td>
|
639 |
+
<td>3.29</td>
|
640 |
+
<td>3.42</td>
|
641 |
+
<td>3.18</td>
|
642 |
+
<td>3.56</td>
|
643 |
+
<td>3.01</td>
|
644 |
+
</tr>
|
645 |
+
</tbody>
|
646 |
+
</table>
|
647 |
+
</p>
|
648 |
+
|
649 |
+
Our results demonstrate that both **SkyReels-V2-I2V (3.29)** and **SkyReels-V2-DF (3.24)** achieve state-of-the-art performance among open-source models, significantly outperforming HunyuanVideo-13B (2.84) and Wan2.1-14B (2.85) across all quality dimensions. With an average score of 3.29, SkyReels-V2-I2V demonstrates comparable performance to proprietary models Kling-1.6 (3.4) and Runway-Gen4 (3.39).
|
650 |
+
|
651 |
+
|
652 |
+
#### VBench
|
653 |
+
To objectively compare SkyReels-V2 Model against other leading open-source Text-To-Video models, we conduct comprehensive evaluations using the public benchmark <a href="https://github.com/Vchitect/VBench">V-Bench</a>. Our evaluation specifically leverages the benchmark’s longer version prompt. For fair comparison with baseline models, we strictly follow their recommended setting for inference.
|
654 |
+
|
655 |
+
<p align="center">
|
656 |
+
<table align="center">
|
657 |
+
<thead>
|
658 |
+
<tr>
|
659 |
+
<th>Model</th>
|
660 |
+
<th>Total Score</th>
|
661 |
+
<th>Quality Score</th>
|
662 |
+
<th>Semantic Score</th>
|
663 |
+
</tr>
|
664 |
+
</thead>
|
665 |
+
<tbody>
|
666 |
+
<tr>
|
667 |
+
<td><a href="https://github.com/hpcaitech/Open-Sora">OpenSora 2.0</a></td>
|
668 |
+
<td>81.5 %</td>
|
669 |
+
<td>82.1 %</td>
|
670 |
+
<td>78.2 %</td>
|
671 |
+
</tr>
|
672 |
+
<tr>
|
673 |
+
<td><a href="https://github.com/THUDM/CogVideo">CogVideoX1.5-5B</a></td>
|
674 |
+
<td>80.3 %</td>
|
675 |
+
<td>80.9 %</td>
|
676 |
+
<td>77.9 %</td>
|
677 |
+
</tr>
|
678 |
+
<tr>
|
679 |
+
<td><a href="https://github.com/Tencent/HunyuanVideo">HunyuanVideo-13B</a></td>
|
680 |
+
<td>82.7 %</td>
|
681 |
+
<td>84.4 %</td>
|
682 |
+
<td>76.2 %</td>
|
683 |
+
</tr>
|
684 |
+
<tr>
|
685 |
+
<td><a href="https://github.com/Wan-Video/Wan2.1">Wan2.1-14B</a></td>
|
686 |
+
<td>83.7 %</td>
|
687 |
+
<td>84.2 %</td>
|
688 |
+
<td><strong>81.4 %</strong></td>
|
689 |
+
</tr>
|
690 |
+
<tr>
|
691 |
+
<td>SkyReels-V2</td>
|
692 |
+
<td><strong>83.9 %</strong></td>
|
693 |
+
<td><strong>84.7 %</strong></td>
|
694 |
+
<td>80.8 %</td>
|
695 |
+
</tr>
|
696 |
+
</tbody>
|
697 |
+
</table>
|
698 |
+
</p>
|
699 |
+
|
700 |
+
The VBench results demonstrate that SkyReels-V2 outperforms all compared models including HunyuanVideo-13B and Wan2.1-14B, With the highest **total score (83.9%)** and **quality score (84.7%)**. In this evaluation, the semantic score is slightly lower than Wan2.1-14B, while we outperform Wan2.1-14B in human evaluations, with the primary gap attributed to V-Bench’s insufficient evaluation of shot-scenario semantic adherence.
|
701 |
+
|
702 |
+
## Acknowledgements
|
703 |
+
We would like to thank the contributors of <a href="https://github.com/Wan-Video/Wan2.1">Wan 2.1</a>, <a href="https://github.com/xdit-project/xDiT">XDit</a> and <a href="https://qwenlm.github.io/blog/qwen2.5/">Qwen 2.5</a> repositories, for their open research and contributions.
|
704 |
+
|
705 |
+
## Citation
|
706 |
+
|
707 |
+
```bibtex
|
708 |
+
@misc{chen2025skyreelsv2infinitelengthfilmgenerative,
|
709 |
+
title={SkyReels-V2: Infinite-length Film Generative Model},
|
710 |
+
author={Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
|
711 |
+
year={2025},
|
712 |
+
eprint={2504.13074},
|
713 |
+
archivePrefix={arXiv},
|
714 |
+
primaryClass={cs.CV},
|
715 |
+
url={https://arxiv.org/abs/2504.13074},
|
716 |
+
}
|
717 |
+
```
|
assets/logo2.png
ADDED
![]() |
assets/main_pipeline.jpg
ADDED
![]() |
Git LFS Details
|
generate_video.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gc
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import time
|
6 |
+
|
7 |
+
import imageio
|
8 |
+
import torch
|
9 |
+
from diffusers.utils import load_image
|
10 |
+
|
11 |
+
from skyreels_v2_infer.modules import download_model
|
12 |
+
from skyreels_v2_infer.pipelines import Image2VideoPipeline
|
13 |
+
from skyreels_v2_infer.pipelines import PromptEnhancer
|
14 |
+
from skyreels_v2_infer.pipelines import resizecrop
|
15 |
+
from skyreels_v2_infer.pipelines import Text2VideoPipeline
|
16 |
+
|
17 |
+
MODEL_ID_CONFIG = {
|
18 |
+
"text2video": [
|
19 |
+
"Skywork/SkyReels-V2-T2V-14B-540P",
|
20 |
+
"Skywork/SkyReels-V2-T2V-14B-720P",
|
21 |
+
],
|
22 |
+
"image2video": [
|
23 |
+
"Skywork/SkyReels-V2-I2V-1.3B-540P",
|
24 |
+
"Skywork/SkyReels-V2-I2V-14B-540P",
|
25 |
+
"Skywork/SkyReels-V2-I2V-14B-720P",
|
26 |
+
],
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == "__main__":
|
31 |
+
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument("--outdir", type=str, default="video_out")
|
34 |
+
parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-T2V-14B-540P")
|
35 |
+
parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
|
36 |
+
parser.add_argument("--num_frames", type=int, default=97)
|
37 |
+
parser.add_argument("--image", type=str, default=None)
|
38 |
+
parser.add_argument("--guidance_scale", type=float, default=6.0)
|
39 |
+
parser.add_argument("--shift", type=float, default=8.0)
|
40 |
+
parser.add_argument("--inference_steps", type=int, default=30)
|
41 |
+
parser.add_argument("--use_usp", action="store_true")
|
42 |
+
parser.add_argument("--offload", action="store_true")
|
43 |
+
parser.add_argument("--fps", type=int, default=24)
|
44 |
+
parser.add_argument("--seed", type=int, default=None)
|
45 |
+
parser.add_argument(
|
46 |
+
"--prompt",
|
47 |
+
type=str,
|
48 |
+
default="A serene lake surrounded by towering mountains, with a few swans gracefully gliding across the water and sunlight dancing on the surface.",
|
49 |
+
)
|
50 |
+
parser.add_argument("--prompt_enhancer", action="store_true")
|
51 |
+
parser.add_argument("--teacache", action="store_true")
|
52 |
+
parser.add_argument(
|
53 |
+
"--teacache_thresh",
|
54 |
+
type=float,
|
55 |
+
default=0.2,
|
56 |
+
help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup")
|
57 |
+
parser.add_argument(
|
58 |
+
"--use_ret_steps",
|
59 |
+
action="store_true",
|
60 |
+
help="Using Retention Steps will result in faster generation speed and better generation quality.")
|
61 |
+
args = parser.parse_args()
|
62 |
+
|
63 |
+
args.model_id = download_model(args.model_id)
|
64 |
+
print("model_id:", args.model_id)
|
65 |
+
|
66 |
+
assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
|
67 |
+
if args.seed is None:
|
68 |
+
random.seed(time.time())
|
69 |
+
args.seed = int(random.randrange(4294967294))
|
70 |
+
|
71 |
+
if args.resolution == "540P":
|
72 |
+
height = 544
|
73 |
+
width = 960
|
74 |
+
elif args.resolution == "720P":
|
75 |
+
height = 720
|
76 |
+
width = 1280
|
77 |
+
else:
|
78 |
+
raise ValueError(f"Invalid resolution: {args.resolution}")
|
79 |
+
|
80 |
+
image = load_image(args.image).convert("RGB") if args.image else None
|
81 |
+
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
82 |
+
local_rank = 0
|
83 |
+
if args.use_usp:
|
84 |
+
assert not args.prompt_enhancer, "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
|
85 |
+
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
|
86 |
+
import torch.distributed as dist
|
87 |
+
|
88 |
+
dist.init_process_group("nccl")
|
89 |
+
local_rank = dist.get_rank()
|
90 |
+
torch.cuda.set_device(dist.get_rank())
|
91 |
+
device = "cuda"
|
92 |
+
|
93 |
+
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
94 |
+
|
95 |
+
initialize_model_parallel(
|
96 |
+
sequence_parallel_degree=dist.get_world_size(),
|
97 |
+
ring_degree=1,
|
98 |
+
ulysses_degree=dist.get_world_size(),
|
99 |
+
)
|
100 |
+
|
101 |
+
prompt_input = args.prompt
|
102 |
+
if args.prompt_enhancer and args.image is None:
|
103 |
+
print(f"init prompt enhancer")
|
104 |
+
prompt_enhancer = PromptEnhancer()
|
105 |
+
prompt_input = prompt_enhancer(prompt_input)
|
106 |
+
print(f"enhanced prompt: {prompt_input}")
|
107 |
+
del prompt_enhancer
|
108 |
+
gc.collect()
|
109 |
+
torch.cuda.empty_cache()
|
110 |
+
|
111 |
+
if image is None:
|
112 |
+
assert "T2V" in args.model_id, f"check model_id:{args.model_id}"
|
113 |
+
print("init text2video pipeline")
|
114 |
+
pipe = Text2VideoPipeline(
|
115 |
+
model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
assert "I2V" in args.model_id, f"check model_id:{args.model_id}"
|
119 |
+
print("init img2video pipeline")
|
120 |
+
pipe = Image2VideoPipeline(
|
121 |
+
model_path=args.model_id, dit_path=args.model_id, use_usp=args.use_usp, offload=args.offload
|
122 |
+
)
|
123 |
+
args.image = load_image(args.image)
|
124 |
+
image_width, image_height = args.image.size
|
125 |
+
if image_height > image_width:
|
126 |
+
height, width = width, height
|
127 |
+
args.image = resizecrop(args.image, height, width)
|
128 |
+
|
129 |
+
if args.teacache:
|
130 |
+
pipe.transformer.initialize_teacache(enable_teacache=True, num_steps=args.inference_steps,
|
131 |
+
teacache_thresh=args.teacache_thresh, use_ret_steps=args.use_ret_steps,
|
132 |
+
ckpt_dir=args.model_id)
|
133 |
+
|
134 |
+
|
135 |
+
kwargs = {
|
136 |
+
"prompt": prompt_input,
|
137 |
+
"negative_prompt": negative_prompt,
|
138 |
+
"num_frames": args.num_frames,
|
139 |
+
"num_inference_steps": args.inference_steps,
|
140 |
+
"guidance_scale": args.guidance_scale,
|
141 |
+
"shift": args.shift,
|
142 |
+
"generator": torch.Generator(device="cuda").manual_seed(args.seed),
|
143 |
+
"height": height,
|
144 |
+
"width": width,
|
145 |
+
}
|
146 |
+
|
147 |
+
if image is not None:
|
148 |
+
kwargs["image"] = args.image.convert("RGB")
|
149 |
+
|
150 |
+
save_dir = os.path.join("result", args.outdir)
|
151 |
+
os.makedirs(save_dir, exist_ok=True)
|
152 |
+
|
153 |
+
with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
|
154 |
+
print(f"infer kwargs:{kwargs}")
|
155 |
+
video_frames = pipe(**kwargs)[0]
|
156 |
+
|
157 |
+
if local_rank == 0:
|
158 |
+
current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
|
159 |
+
video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
|
160 |
+
output_path = os.path.join(save_dir, video_out_file)
|
161 |
+
imageio.mimwrite(output_path, video_frames, fps=args.fps, quality=8, output_params=["-loglevel", "error"])
|
generate_video_df.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gc
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import time
|
6 |
+
|
7 |
+
import imageio
|
8 |
+
import torch
|
9 |
+
from diffusers.utils import load_image
|
10 |
+
|
11 |
+
from skyreels_v2_infer import DiffusionForcingPipeline
|
12 |
+
from skyreels_v2_infer.modules import download_model
|
13 |
+
from skyreels_v2_infer.pipelines import PromptEnhancer
|
14 |
+
from skyreels_v2_infer.pipelines import resizecrop
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument("--outdir", type=str, default="diffusion_forcing")
|
20 |
+
parser.add_argument("--model_id", type=str, default="Skywork/SkyReels-V2-DF-1.3B-540P")
|
21 |
+
parser.add_argument("--resolution", type=str, choices=["540P", "720P"])
|
22 |
+
parser.add_argument("--num_frames", type=int, default=97)
|
23 |
+
parser.add_argument("--image", type=str, default=None)
|
24 |
+
parser.add_argument("--ar_step", type=int, default=0)
|
25 |
+
parser.add_argument("--causal_attention", action="store_true")
|
26 |
+
parser.add_argument("--causal_block_size", type=int, default=1)
|
27 |
+
parser.add_argument("--base_num_frames", type=int, default=97)
|
28 |
+
parser.add_argument("--overlap_history", type=int, default=None)
|
29 |
+
parser.add_argument("--addnoise_condition", type=int, default=0)
|
30 |
+
parser.add_argument("--guidance_scale", type=float, default=6.0)
|
31 |
+
parser.add_argument("--shift", type=float, default=8.0)
|
32 |
+
parser.add_argument("--inference_steps", type=int, default=30)
|
33 |
+
parser.add_argument("--use_usp", action="store_true")
|
34 |
+
parser.add_argument("--offload", action="store_true")
|
35 |
+
parser.add_argument("--fps", type=int, default=24)
|
36 |
+
parser.add_argument("--seed", type=int, default=None)
|
37 |
+
parser.add_argument(
|
38 |
+
"--prompt",
|
39 |
+
type=str,
|
40 |
+
default="A woman in a leather jacket and sunglasses riding a vintage motorcycle through a desert highway at sunset, her hair blowing wildly in the wind as the motorcycle kicks up dust, with the golden sun casting long shadows across the barren landscape.",
|
41 |
+
)
|
42 |
+
parser.add_argument("--prompt_enhancer", action="store_true")
|
43 |
+
parser.add_argument("--teacache", action="store_true")
|
44 |
+
parser.add_argument(
|
45 |
+
"--teacache_thresh",
|
46 |
+
type=float,
|
47 |
+
default=0.2,
|
48 |
+
help="Higher speedup will cause to worse quality -- 0.1 for 2.0x speedup -- 0.2 for 3.0x speedup",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--use_ret_steps",
|
52 |
+
action="store_true",
|
53 |
+
help="Using Retention Steps will result in faster generation speed and better generation quality.",
|
54 |
+
)
|
55 |
+
args = parser.parse_args()
|
56 |
+
|
57 |
+
args.model_id = download_model(args.model_id)
|
58 |
+
print("model_id:", args.model_id)
|
59 |
+
|
60 |
+
assert (args.use_usp and args.seed is not None) or (not args.use_usp), "usp mode need seed"
|
61 |
+
if args.seed is None:
|
62 |
+
random.seed(time.time())
|
63 |
+
args.seed = int(random.randrange(4294967294))
|
64 |
+
|
65 |
+
if args.resolution == "540P":
|
66 |
+
height = 544
|
67 |
+
width = 960
|
68 |
+
elif args.resolution == "720P":
|
69 |
+
height = 720
|
70 |
+
width = 1280
|
71 |
+
else:
|
72 |
+
raise ValueError(f"Invalid resolution: {args.resolution}")
|
73 |
+
|
74 |
+
num_frames = args.num_frames
|
75 |
+
fps = args.fps
|
76 |
+
|
77 |
+
if num_frames > args.base_num_frames:
|
78 |
+
assert (
|
79 |
+
args.overlap_history is not None
|
80 |
+
), 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.'
|
81 |
+
if args.addnoise_condition > 60:
|
82 |
+
print(
|
83 |
+
f'You have set "addnoise_condition" as {args.addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.'
|
84 |
+
)
|
85 |
+
|
86 |
+
guidance_scale = args.guidance_scale
|
87 |
+
shift = args.shift
|
88 |
+
if args.image:
|
89 |
+
args.image = load_image(args.image)
|
90 |
+
image_width, image_height = args.image.size
|
91 |
+
if image_height > image_width:
|
92 |
+
height, width = width, height
|
93 |
+
args.image = resizecrop(args.image, height, width)
|
94 |
+
image = args.image.convert("RGB") if args.image else None
|
95 |
+
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
96 |
+
|
97 |
+
save_dir = os.path.join("result", args.outdir)
|
98 |
+
os.makedirs(save_dir, exist_ok=True)
|
99 |
+
local_rank = 0
|
100 |
+
if args.use_usp:
|
101 |
+
assert (
|
102 |
+
not args.prompt_enhancer
|
103 |
+
), "`--prompt_enhancer` is not allowed if using `--use_usp`. We recommend running the skyreels_v2_infer/pipelines/prompt_enhancer.py script first to generate enhanced prompt before enabling the `--use_usp` parameter."
|
104 |
+
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
|
105 |
+
import torch.distributed as dist
|
106 |
+
|
107 |
+
dist.init_process_group("nccl")
|
108 |
+
local_rank = dist.get_rank()
|
109 |
+
torch.cuda.set_device(dist.get_rank())
|
110 |
+
device = "cuda"
|
111 |
+
|
112 |
+
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
|
113 |
+
|
114 |
+
initialize_model_parallel(
|
115 |
+
sequence_parallel_degree=dist.get_world_size(),
|
116 |
+
ring_degree=1,
|
117 |
+
ulysses_degree=dist.get_world_size(),
|
118 |
+
)
|
119 |
+
|
120 |
+
prompt_input = args.prompt
|
121 |
+
if args.prompt_enhancer and args.image is None:
|
122 |
+
print(f"init prompt enhancer")
|
123 |
+
prompt_enhancer = PromptEnhancer()
|
124 |
+
prompt_input = prompt_enhancer(prompt_input)
|
125 |
+
print(f"enhanced prompt: {prompt_input}")
|
126 |
+
del prompt_enhancer
|
127 |
+
gc.collect()
|
128 |
+
torch.cuda.empty_cache()
|
129 |
+
|
130 |
+
pipe = DiffusionForcingPipeline(
|
131 |
+
args.model_id,
|
132 |
+
dit_path=args.model_id,
|
133 |
+
device=torch.device("cuda"),
|
134 |
+
weight_dtype=torch.bfloat16,
|
135 |
+
use_usp=args.use_usp,
|
136 |
+
offload=args.offload,
|
137 |
+
)
|
138 |
+
|
139 |
+
if args.causal_attention:
|
140 |
+
pipe.transformer.set_ar_attention(args.causal_block_size)
|
141 |
+
|
142 |
+
if args.teacache:
|
143 |
+
if args.ar_step > 0:
|
144 |
+
num_steps = (
|
145 |
+
args.inference_steps
|
146 |
+
+ (((args.base_num_frames - 1) // 4 + 1) // args.causal_block_size - 1) * args.ar_step
|
147 |
+
)
|
148 |
+
print("num_steps:", num_steps)
|
149 |
+
else:
|
150 |
+
num_steps = args.inference_steps
|
151 |
+
pipe.transformer.initialize_teacache(
|
152 |
+
enable_teacache=True,
|
153 |
+
num_steps=num_steps,
|
154 |
+
teacache_thresh=args.teacache_thresh,
|
155 |
+
use_ret_steps=args.use_ret_steps,
|
156 |
+
ckpt_dir=args.model_id,
|
157 |
+
)
|
158 |
+
|
159 |
+
print(f"prompt:{prompt_input}")
|
160 |
+
print(f"guidance_scale:{guidance_scale}")
|
161 |
+
|
162 |
+
with torch.cuda.amp.autocast(dtype=pipe.transformer.dtype), torch.no_grad():
|
163 |
+
video_frames = pipe(
|
164 |
+
prompt=prompt_input,
|
165 |
+
negative_prompt=negative_prompt,
|
166 |
+
image=image,
|
167 |
+
height=height,
|
168 |
+
width=width,
|
169 |
+
num_frames=num_frames,
|
170 |
+
num_inference_steps=args.inference_steps,
|
171 |
+
shift=shift,
|
172 |
+
guidance_scale=guidance_scale,
|
173 |
+
generator=torch.Generator(device="cuda").manual_seed(args.seed),
|
174 |
+
overlap_history=args.overlap_history,
|
175 |
+
addnoise_condition=args.addnoise_condition,
|
176 |
+
base_num_frames=args.base_num_frames,
|
177 |
+
ar_step=args.ar_step,
|
178 |
+
causal_block_size=args.causal_block_size,
|
179 |
+
fps=fps,
|
180 |
+
)[0]
|
181 |
+
|
182 |
+
if local_rank == 0:
|
183 |
+
current_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
|
184 |
+
video_out_file = f"{args.prompt[:100].replace('/','')}_{args.seed}_{current_time}.mp4"
|
185 |
+
output_path = os.path.join(save_dir, video_out_file)
|
186 |
+
imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
|
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.5.1
|
2 |
+
torchvision==0.20.1
|
3 |
+
opencv-python==4.10.0.84
|
4 |
+
diffusers>=0.31.0
|
5 |
+
transformers==4.49.0
|
6 |
+
tokenizers==0.21.1
|
7 |
+
accelerate==1.6.0
|
8 |
+
tqdm
|
9 |
+
imageio
|
10 |
+
easydict
|
11 |
+
ftfy
|
12 |
+
dashscope
|
13 |
+
imageio-ffmpeg
|
14 |
+
flash_attn
|
15 |
+
numpy>=1.23.5,<2
|
16 |
+
xfuser
|
skycaptioner_v1/README.md
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SkyCaptioner-V1: A Structural Video Captioning Model
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
📑 <a href="https://arxiv.org/pdf/2504.13074">Technical Report</a> · 👋 <a href="https://www.skyreels.ai/home?utm_campaign=github_SkyReels_V2" target="_blank">Playground</a> · 💬 <a href="https://discord.gg/PwM6NYtccQ" target="_blank">Discord</a> · 🤗 <a href="https://huggingface.co/Skywork/SkyCaptioner-V1" target="_blank">Hugging Face</a> · 🤖 <a href="https://modelscope.cn/collections/SkyReels-V2-f665650130b144">ModelScope</a></a>
|
5 |
+
</p>
|
6 |
+
|
7 |
+
---
|
8 |
+
|
9 |
+
Welcome to the SkyCaptioner-V1 repository! Here, you'll find the structural video captioning model weights and inference code for our video captioner that labels the video data efficiently and comprehensively.
|
10 |
+
|
11 |
+
## 🔥🔥🔥 News!!
|
12 |
+
|
13 |
+
* Apr 21, 2025: 👋 We release the [vllm](https://github.com/vllm-project/vllm) batch inference code for SkyCaptioner-V1 Model and caption fusion inference code.
|
14 |
+
* Apr 21, 2025: 👋 We release the first shot-aware video captioning model [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1). For more details, please check our [paper](https://arxiv.org/pdf/2504.13074).
|
15 |
+
|
16 |
+
## 📑 TODO List
|
17 |
+
|
18 |
+
- SkyCaptioner-V1
|
19 |
+
|
20 |
+
- [x] Checkpoints
|
21 |
+
- [x] Batch Inference Code
|
22 |
+
- [x] Caption Fusion Method
|
23 |
+
- [ ] Web Demo (Gradio)
|
24 |
+
|
25 |
+
## 🌟 Overview
|
26 |
+
|
27 |
+
SkyCaptioner-V1 is a structural video captioning model designed to generate high-quality, structural descriptions for video data. It integrates specialized sub-expert models and multimodal large language models (MLLMs) with human annotations to address the limitations of general captioners in capturing professional film-related details. Key aspects include:
|
28 |
+
|
29 |
+
1. **Structural Representation**: Combines general video descriptions (from MLLMs) with sub-expert captioner (e.g., shot types,shot angles, shot positions, camera motions.) and human annotations.
|
30 |
+
2. **Knowledge Distillation**: Distills expertise from sub-expert captioners into a unified model.
|
31 |
+
3. **Application Flexibility**: Generates dense captions for text-to-video (T2V) and concise prompts for image-to-video (I2V) tasks.
|
32 |
+
|
33 |
+
## 🔑 Key Features
|
34 |
+
|
35 |
+
### Structural Captioning Framework
|
36 |
+
|
37 |
+
Our Video Captioning model captures multi-dimensional details:
|
38 |
+
|
39 |
+
* **Subjects**: Appearance, action, expression, position, and hierarchical categorization.
|
40 |
+
* **Shot Metadata**: Shot type (e.g., close-up, long shot), shot angle, shot position, camera motion, environment, lighting, etc.
|
41 |
+
|
42 |
+
### Sub-Expert Integration
|
43 |
+
|
44 |
+
* **Shot Captioner**: Classifies shot type, angle, and position with high precision.
|
45 |
+
* **Expression Captioner**: Analyzes facial expressions, emotion intensity, and temporal dynamics.
|
46 |
+
* **Camera Motion Captioner**: Tracks 6DoF camera movements and composite motion types,
|
47 |
+
|
48 |
+
### Training Pipeline
|
49 |
+
|
50 |
+
* Trained on \~2M high-quality, concept-balanced videos curated from 10M raw samples.
|
51 |
+
* Fine-tuned on Qwen2.5-VL-7B-Instruct with a global batch size of 512 across 32 A800 GPUs.
|
52 |
+
* Optimized using AdamW (learning rate: 1e-5) for 2 epochs.
|
53 |
+
|
54 |
+
### Dynamic Caption Fusion:
|
55 |
+
|
56 |
+
* Adapts output length based on application (T2V/I2V).
|
57 |
+
* Employs LLM Model to fusion structural fields to get a natural and fluency caption for downstream tasks.
|
58 |
+
|
59 |
+
## 📊 Benchmark Results
|
60 |
+
|
61 |
+
SkyCaptioner-V1 demonstrates significant improvements over existing models in key film-specific captioning tasks, particularly in **shot-language understanding** and **domain-specific precision**. The differences stem from its structural architecture and expert-guided training:
|
62 |
+
|
63 |
+
1. **Superior Shot-Language Understanding**:
|
64 |
+
* Our Captioner model outperforms Qwen2.5-VL-72B with +11.2% in shot type, +16.1% in shot angle, and +50.4% in shot position accuracy. Because SkyCaptioner-V1’s specialized shot classifiers outperform generalist MLLMs, which lack film-domain fine-tuning.
|
65 |
+
* +28.5% accuracy in camera motion vs. Tarsier2-recap-7B (88.8% vs. 41.5%):
|
66 |
+
Its 6DoF motion analysis and active learning pipeline address ambiguities in composite motions (e.g., tracking + panning) that challenge generic captioners.
|
67 |
+
2. **High domain-specific precision**:
|
68 |
+
* Expression accuracy: 68.8% vs. 54.3% (Tarsier2-recap-7B), leveraging temporal-aware S2D frameworks to capture dynamic facial changes.
|
69 |
+
|
70 |
+
<p align="center">
|
71 |
+
<table align="center">
|
72 |
+
<thead>
|
73 |
+
<tr>
|
74 |
+
<th>Metric</th>
|
75 |
+
<th>Qwen2.5-VL-7B-Ins.</th>
|
76 |
+
<th>Qwen2.5-VL-72B-Ins.</th>
|
77 |
+
<th>Tarsier2-recap-7B</th>
|
78 |
+
<th>SkyCaptioner-V1</th>
|
79 |
+
</tr>
|
80 |
+
</thead>
|
81 |
+
<tbody>
|
82 |
+
<tr>
|
83 |
+
<td>Avg accuracy</td>
|
84 |
+
<td>51.4%</td>
|
85 |
+
<td>58.7%</td>
|
86 |
+
<td>49.4%</td>
|
87 |
+
<td><strong>76.3%</strong></td>
|
88 |
+
</tr>
|
89 |
+
<tr>
|
90 |
+
<td>shot type</td>
|
91 |
+
<td>76.8%</td>
|
92 |
+
<td>82.5%</td>
|
93 |
+
<td>60.2%</td>
|
94 |
+
<td><strong>93.7%</strong></td>
|
95 |
+
</tr>
|
96 |
+
<tr>
|
97 |
+
<td>shot angle</td>
|
98 |
+
<td>60.0%</td>
|
99 |
+
<td>73.7%</td>
|
100 |
+
<td>52.4%</td>
|
101 |
+
<td><strong>89.8%</strong></td>
|
102 |
+
</tr>
|
103 |
+
<tr>
|
104 |
+
<td>shot position</td>
|
105 |
+
<td>28.4%</td>
|
106 |
+
<td>32.7%</td>
|
107 |
+
<td>23.6%</td>
|
108 |
+
<td><strong>83.1%</strong></td>
|
109 |
+
</tr>
|
110 |
+
<tr>
|
111 |
+
<td>camera motion</td>
|
112 |
+
<td>62.0%</td>
|
113 |
+
<td>61.2%</td>
|
114 |
+
<td>45.3%</td>
|
115 |
+
<td><strong>85.3%</strong></td>
|
116 |
+
</tr>
|
117 |
+
<tr>
|
118 |
+
<td>expression</td>
|
119 |
+
<td>43.6%</td>
|
120 |
+
<td>51.5%</td>
|
121 |
+
<td>54.3%</td>
|
122 |
+
<td><strong>68.8%</strong></td>
|
123 |
+
</tr>
|
124 |
+
<tr>
|
125 |
+
<td>TYPES_type</td>
|
126 |
+
<td>43.5%</td>
|
127 |
+
<td>49.7%</td>
|
128 |
+
<td>47.6%</td>
|
129 |
+
<td><strong>82.5%</strong></td>
|
130 |
+
</tr>
|
131 |
+
<tr>
|
132 |
+
<td>TYPES_sub_type</td>
|
133 |
+
<td>38.9%</td>
|
134 |
+
<td>44.9%</td>
|
135 |
+
<td>45.9%</td>
|
136 |
+
<td><strong>75.4%</strong></td>
|
137 |
+
</tr>
|
138 |
+
<tr>
|
139 |
+
<td>appearance</td>
|
140 |
+
<td>40.9%</td>
|
141 |
+
<td>52.0%</td>
|
142 |
+
<td>45.6%</td>
|
143 |
+
<td><strong>59.3%</strong></td>
|
144 |
+
</tr>
|
145 |
+
<tr>
|
146 |
+
<td>action</td>
|
147 |
+
<td>32.4%</td>
|
148 |
+
<td>52.0%</td>
|
149 |
+
<td><strong>69.8%</strong></td>
|
150 |
+
<td>68.8%</td>
|
151 |
+
</tr>
|
152 |
+
<tr>
|
153 |
+
<td>position</td>
|
154 |
+
<td>35.4%</td>
|
155 |
+
<td>48.6%</td>
|
156 |
+
<td>45.5%</td>
|
157 |
+
<td><strong>57.5%</strong></td>
|
158 |
+
</tr>
|
159 |
+
<tr>
|
160 |
+
<td>is_main_subject</td>
|
161 |
+
<td>58.5%</td>
|
162 |
+
<td>68.7%</td>
|
163 |
+
<td>69.7%</td>
|
164 |
+
<td><strong>80.9%</strong></td>
|
165 |
+
</tr>
|
166 |
+
<tr>
|
167 |
+
<td>environment</td>
|
168 |
+
<td>70.4%</td>
|
169 |
+
<td><strong>72.7%</strong></td>
|
170 |
+
<td>61.4%</td>
|
171 |
+
<td>70.5%</td>
|
172 |
+
</tr>
|
173 |
+
<tr>
|
174 |
+
<td>lighting</td>
|
175 |
+
<td>77.1%</td>
|
176 |
+
<td><strong>80.0%</strong></td>
|
177 |
+
<td>21.2%</td>
|
178 |
+
<td>76.5%</td>
|
179 |
+
</tr>
|
180 |
+
</tbody>
|
181 |
+
</table>
|
182 |
+
</p>
|
183 |
+
|
184 |
+
## 📦 Model Downloads
|
185 |
+
|
186 |
+
Our SkyCaptioner-V1 model can be downloaded from [SkyCaptioner-V1 Model](https://huggingface.co/Skywork/SkyCaptioner-V1).
|
187 |
+
We use [Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) as our caption fusion model to intelligently combine structured caption fields, producing either dense or sparse final captions depending on application requirements.
|
188 |
+
|
189 |
+
```shell
|
190 |
+
# download SkyCaptioner-V1
|
191 |
+
huggingface-cli download Skywork/SkyCaptioner-V1 --local-dir /path/to/your_local_model_path
|
192 |
+
# download Qwen2.5-32B-Instruct
|
193 |
+
huggingface-cli download Qwen/Qwen2.5-32B-Instruct --local-dir /path/to/your_local_model_path2
|
194 |
+
```
|
195 |
+
|
196 |
+
## 🛠️ Running Guide
|
197 |
+
|
198 |
+
Begin by cloning the repository:
|
199 |
+
|
200 |
+
```shell
|
201 |
+
git clone https://github.com/SkyworkAI/SkyReels-V2
|
202 |
+
cd skycaptioner_v1
|
203 |
+
```
|
204 |
+
|
205 |
+
### Installation Guide for Linux
|
206 |
+
|
207 |
+
We recommend Python 3.10 and CUDA version 12.2 for the manual installation.
|
208 |
+
|
209 |
+
```shell
|
210 |
+
pip install -r requirements.txt
|
211 |
+
```
|
212 |
+
|
213 |
+
### Running Command
|
214 |
+
|
215 |
+
#### Get Structural Caption by SkyCaptioner-V1
|
216 |
+
|
217 |
+
```shell
|
218 |
+
export SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
|
219 |
+
|
220 |
+
python scripts/vllm_struct_caption.py \
|
221 |
+
--model_path ${SkyCaptioner_V1_Model_PATH} \
|
222 |
+
--input_csv "./examples/test.csv" \
|
223 |
+
--out_csv "./examples/test_result.csv" \
|
224 |
+
--tp 1 \
|
225 |
+
--bs 4
|
226 |
+
```
|
227 |
+
|
228 |
+
#### T2V/I2V Caption Fusion by Qwen2.5-32B-Instruct Model
|
229 |
+
|
230 |
+
```shell
|
231 |
+
export LLM_MODEL_PATH="/path/to/your_local_model_path2"
|
232 |
+
|
233 |
+
python scripts/vllm_fusion_caption.py \
|
234 |
+
--model_path ${LLM_MODEL_PATH} \
|
235 |
+
--input_csv "./examples/test_result.csv" \
|
236 |
+
--out_csv "./examples/test_result_caption.csv" \
|
237 |
+
--bs 4 \
|
238 |
+
--tp 1 \
|
239 |
+
--task t2v
|
240 |
+
```
|
241 |
+
> **Note**:
|
242 |
+
> - If you want to get i2v caption, just change the `--task t2v` to `--task i2v` in your Command.
|
243 |
+
|
244 |
+
## Acknowledgements
|
245 |
+
|
246 |
+
We would like to thank the contributors of <a href="https://github.com/QwenLM/Qwen2.5-VL">Qwen2.5-VL</a>, <a href="https://github.com/bytedance/tarsier">tarsier2</a> and <a href="https://github.com/vllm-project/vllm">vllm</a> repositories, for their open research and contributions.
|
247 |
+
|
248 |
+
## Citation
|
249 |
+
|
250 |
+
```bibtex
|
251 |
+
@misc{chen2025skyreelsv2infinitelengthfilmgenerative,
|
252 |
+
author = {Guibin Chen and Dixuan Lin and Jiangping Yang and Chunze Lin and Junchen Zhu and Mingyuan Fan and Hao Zhang and Sheng Chen and Zheng Chen and Chengcheng Ma and Weiming Xiong and Wei Wang and Nuo Pang and Kang Kang and Zhiheng Xu and Yuzhe Jin and Yupeng Liang and Yubing Song and Peng Zhao and Boyuan Xu and Di Qiu and Debang Li and Zhengcong Fei and Yang Li and Yahui Zhou},
|
253 |
+
title = {Skyreels V2:Infinite-Length Film Generative Model},
|
254 |
+
year = {2025},
|
255 |
+
eprint={2504.13074},
|
256 |
+
archivePrefix={arXiv},
|
257 |
+
primaryClass={cs.CV},
|
258 |
+
url={https://arxiv.org/abs/2504.13074}
|
259 |
+
}
|
260 |
+
```
|
261 |
+
|
262 |
+
|
skycaptioner_v1/examples/data/1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:16becc00811e427d9e3f5da5a977239907ea3a6d45b5482b9bcfea2abc3c6b7f
|
3 |
+
size 4821469
|
skycaptioner_v1/examples/data/2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:109b954b7b1e36addef840554af53241415492e71972c87b7bfb03f77bf0d68a
|
3 |
+
size 1030652
|
skycaptioner_v1/examples/data/3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0fd544787b24265605cf16c954f495fe9d73680e0529f1438e47c02803a9c2bf
|
3 |
+
size 1661366
|
skycaptioner_v1/examples/data/4.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f5cc4038924aae963a54779d713f7f9680259ae81a4446fb2775cc6bb51d907c
|
3 |
+
size 2523332
|
skycaptioner_v1/examples/test.csv
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
path
|
2 |
+
./examples/data/1.mp4
|
3 |
+
./examples/data/2.mp4
|
4 |
+
./examples/data/3.mp4
|
5 |
+
./examples/data/4.mp4
|
skycaptioner_v1/examples/test_result.csv
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
path,structural_caption
|
2 |
+
./examples/data/1.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Wearing winter sports gear, including a helmet and goggles."", ""action"": ""The video shows a snowy mountain landscape with a ski slope surrounded by dense trees and distant mountains. A skier is seen descending the slope, moving from the top left to the bottom right of the frame. The skier maintains a steady pace, navigating the curves of the slope. The background includes a ski lift with chairs moving along the cables, and the slope is marked with red and white lines indicating the path for skiers. The skier continues to descend, gradually getting closer to the bottom of the slope."", ""expression"": """", ""position"": ""Centered in the frame, moving downwards on the slope."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""White, covering the ground and trees."", ""action"": """", ""expression"": """", ""position"": ""Surrounding the skier, covering the entire visible area."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Plant"", ""sub_type"": ""Tree""}, ""appearance"": ""Tall, evergreen, covered in snow."", ""action"": """", ""expression"": """", ""position"": ""Scattered throughout the scene, both in the foreground and background."", ""is_main_subject"": false}, {""TYPES"": {""type"": ""Scenery"", ""sub_type"": ""Mountain""}, ""appearance"": ""Snow-covered, with ski lifts and structures."", ""action"": """", ""expression"": """", ""position"": ""In the background, providing context to the location."", ""is_main_subject"": false}], ""shot_type"": ""long_shot"", ""shot_angle"": ""high_angle"", ""shot_position"": ""overlooking_view"", ""camera_motion"": ""the camera moves toward zooms in"", ""environment"": ""A snowy mountain landscape with a ski slope, trees, and a ski resort in the background."", ""lighting"": ""Bright daylight, casting shadows and highlighting the snow's texture.""}"
|
3 |
+
./examples/data/2.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Human"", ""sub_type"": ""Woman""}, ""appearance"": ""Long, straight black hair, wearing a sparkling choker necklace with a diamond-like texture, light-colored top, subtle makeup with pink lipstick, stud earrings."", ""action"": ""A woman wearing a sparkling choker necklace and earrings is sitting in a car, looking to her left and speaking. A man, dressed in a suit, is sitting next to her, attentively watching her. The background outside the car is green, indicating a possible outdoor setting."", ""expression"": ""The individual in the video exhibits a neutral facial expression, characterized by slightly open lips and a gentle, soft-focus gaze. There are no noticeable signs of sadness or distress evident in their demeanor."", ""position"": ""Seated in the foreground of the car, facing slightly to the right."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Human"", ""sub_type"": ""Man""}, ""appearance"": ""Short hair, wearing a dark-colored suit with a white shirt."", ""action"": """", ""expression"": """", ""position"": ""Seated in the background of the car, facing the woman."", ""is_main_subject"": false}], ""shot_type"": ""close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""side_view"", ""camera_motion"": """", ""environment"": ""Interior of a car with dark upholstery."", ""lighting"": ""Soft and natural lighting, suggesting daytime.""}"
|
4 |
+
./examples/data/3.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Animal"", ""sub_type"": ""Insect""}, ""appearance"": ""The spider has a spherical, yellowish-green body with darker green stripes and spots. It has eight slender legs with visible joints and segments."", ""action"": ""A spider with a yellow and green body and black and white striped legs is hanging from its web in a natural setting with a blurred background of green and brown hues. The spider remains mostly still, with slight movements in its legs and body, indicating a gentle swaying motion."", ""expression"": """", ""position"": ""The spider is centrally positioned in the frame, hanging from a web."", ""is_main_subject"": true}], ""shot_type"": ""extreme_close_up"", ""shot_angle"": ""eye_level"", ""shot_position"": ""front_view"", ""camera_motion"": """", ""environment"": ""The background consists of vertical, out-of-focus lines in shades of green and brown, suggesting a natural environment with vegetation."", ""lighting"": ""The lighting is soft and diffused, with no harsh shadows, indicating an overcast sky or a shaded area.""}"
|
5 |
+
./examples/data/4.mp4,"{""subjects"": [{""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Football""}, ""appearance"": ""Wearing a dark-colored jersey, black shorts, and bright blue soccer shoes with white soles."", ""action"": ""A man is on a grassy field with orange cones placed in a line. He is wearing a gray shirt, black shorts, and black socks with blue shoes. The man starts by standing still with his feet apart, then begins to move forward while keeping his eyes on the cones. He continues to run forward, maintaining his focus on the cones, and his feet move in a coordinated manner to navigate around them. The background shows a clear sky, some trees, and a few buildings in the distance."", ""expression"": """", ""position"": ""Centered in the frame, with the soccer ball positioned between the cones."", ""is_main_subject"": true}, {""TYPES"": {""type"": ""Sport"", ""sub_type"": ""Other""}, ""appearance"": ""Bright orange, conical shape."", ""action"": """", ""expression"": """", ""position"": ""Placed on the grass, creating a path for the soccer player."", ""is_main_subject"": false}], ""shot_type"": ""full_shot"", ""shot_angle"": ""low_angle"", ""shot_position"": ""front_view"", ""camera_motion"": ""use a tracking shot, the camera moves toward zooms in"", ""environment"": ""Outdoor sports field with well-maintained grass, trees, and a clear blue sky."", ""lighting"": ""Bright and natural, suggesting a sunny day.""}"
|
skycaptioner_v1/infer_fusion_caption.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
expor LLM_MODEL_PATH="/path/to/your_local_model_path2"
|
2 |
+
|
3 |
+
python scripts/vllm_fusion_caption.py \
|
4 |
+
--model_path ${LLM_MODEL_PATH} \
|
5 |
+
--input_csv "./examples/test_result.csv" \
|
6 |
+
--out_csv "./examples/test_result_caption.csv" \
|
7 |
+
--bs 4 \
|
8 |
+
--tp 1 \
|
9 |
+
--task t2v
|
skycaptioner_v1/infer_struct_caption.sh
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
expor SkyCaptioner_V1_Model_PATH="/path/to/your_local_model_path"
|
2 |
+
|
3 |
+
python scripts/vllm_struct_caption.py \
|
4 |
+
--model_path ${SkyCaptioner_V1_Model_PATH} \
|
5 |
+
--input_csv "./examples/test.csv" \
|
6 |
+
--out_csv "./examepls/test_result.csv" \
|
7 |
+
--tp 1 \
|
8 |
+
--bs 32
|
skycaptioner_v1/requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
decord==0.6.0
|
2 |
+
transformers>=4.49.0
|
3 |
+
vllm==0.8.4
|
skycaptioner_v1/scripts/utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
|
5 |
+
flat_indices = []
|
6 |
+
for x in zip(indices_list):
|
7 |
+
flat_indices.extend(x)
|
8 |
+
flat_results = []
|
9 |
+
for x in zip(result_list):
|
10 |
+
flat_results.extend(x)
|
11 |
+
|
12 |
+
flat_indices = np.array(flat_indices)
|
13 |
+
flat_results = np.array(flat_results)
|
14 |
+
|
15 |
+
unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
|
16 |
+
meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
|
17 |
+
|
18 |
+
meta = meta.loc[unique_indices]
|
19 |
+
return meta
|
skycaptioner_v1/scripts/vllm_fusion_caption.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import argparse
|
4 |
+
import glob
|
5 |
+
import time
|
6 |
+
import gc
|
7 |
+
from tqdm import tqdm
|
8 |
+
import torch
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
import pandas as pd
|
11 |
+
from vllm import LLM, SamplingParams
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
import json
|
14 |
+
import random
|
15 |
+
from utils import result_writer
|
16 |
+
|
17 |
+
SYSTEM_PROMPT_I2V = """
|
18 |
+
You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
|
19 |
+
|
20 |
+
## Structured Input
|
21 |
+
{structured_input}
|
22 |
+
|
23 |
+
## Notes
|
24 |
+
1. If there has an empty field, just ignore it and do not mention it in the output.
|
25 |
+
2. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
|
26 |
+
3. If the action field is not empty, eliminate the irrelevant information in the action field that is not related to the timing action(such as wearings, background and environment information) to make a pure action field.
|
27 |
+
|
28 |
+
## Output Principles and Orders
|
29 |
+
1. First, eliminate the static information in the action field that is not related to the timing action, such as background or environment information.
|
30 |
+
2. Second, describe each subject with its pure action and expression if these fields exist.
|
31 |
+
|
32 |
+
## Output
|
33 |
+
Please directly output the final composed caption without any additional information.
|
34 |
+
"""
|
35 |
+
|
36 |
+
SYSTEM_PROMPT_T2V = """
|
37 |
+
You are an expert in video captioning. You are given a structured video caption and you need to compose it to be more natural and fluent in English.
|
38 |
+
|
39 |
+
## Structured Input
|
40 |
+
{structured_input}
|
41 |
+
|
42 |
+
## Notes
|
43 |
+
1. According to the action field information, change its name field to the subject pronoun in the action.
|
44 |
+
2. If there has an empty field, just ignore it and do not mention it in the output.
|
45 |
+
3. Do not make any semantic changes to the original fields. Please be sure to follow the original meaning.
|
46 |
+
|
47 |
+
## Output Principles and Orders
|
48 |
+
1. First, declare the shot_type, then declare the shot_angle and the shot_position fields.
|
49 |
+
2. Second, eliminate information in the action field that is not related to the timing action, such as background or environment information if action is not empty.
|
50 |
+
3. Third, describe each subject with its pure action, appearance, expression, position if these fields exist.
|
51 |
+
4. Finally, declare the environment and lighting if the environment and lighting fields are not empty.
|
52 |
+
|
53 |
+
## Output
|
54 |
+
Please directly output the final composed caption without any additional information.
|
55 |
+
"""
|
56 |
+
|
57 |
+
SHOT_TYPE_LIST = [
|
58 |
+
'close-up shot',
|
59 |
+
'extreme close-up shot',
|
60 |
+
'medium shot',
|
61 |
+
'long shot',
|
62 |
+
'full shot',
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
class StructuralCaptionDataset(torch.utils.data.Dataset):
|
67 |
+
def __init__(self, input_csv, model_path):
|
68 |
+
self.meta = pd.read_csv(input_csv)
|
69 |
+
self.task = args.task
|
70 |
+
self.system_prompt = SYSTEM_PROMPT_T2V if self.task == 't2v' else SYSTEM_PROMPT_I2V
|
71 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
72 |
+
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.meta)
|
76 |
+
|
77 |
+
def __getitem__(self, index):
|
78 |
+
row = self.meta.iloc[index]
|
79 |
+
real_index = self.meta.index[index]
|
80 |
+
|
81 |
+
struct_caption = json.loads(row["structural_caption"])
|
82 |
+
|
83 |
+
camera_movement = struct_caption.get('camera_motion', '')
|
84 |
+
if camera_movement != '':
|
85 |
+
camera_movement += '.'
|
86 |
+
camera_movement = camera_movement.capitalize()
|
87 |
+
|
88 |
+
fusion_by_llm = False
|
89 |
+
cleaned_struct_caption = self.clean_struct_caption(struct_caption, self.task)
|
90 |
+
if cleaned_struct_caption.get('num_subjects', 0) > 0:
|
91 |
+
new_struct_caption = json.dumps(cleaned_struct_caption, indent=4, ensure_ascii=False)
|
92 |
+
conversation = [
|
93 |
+
{
|
94 |
+
"role": "system",
|
95 |
+
"content": self.system_prompt.format(structured_input=new_struct_caption),
|
96 |
+
},
|
97 |
+
]
|
98 |
+
text = self.tokenizer.apply_chat_template(
|
99 |
+
conversation,
|
100 |
+
tokenize=False,
|
101 |
+
add_generation_prompt=True
|
102 |
+
)
|
103 |
+
fusion_by_llm = True
|
104 |
+
else:
|
105 |
+
text = '-'
|
106 |
+
return real_index, fusion_by_llm, text, '-', camera_movement
|
107 |
+
|
108 |
+
def clean_struct_caption(self, struct_caption, task):
|
109 |
+
raw_subjects = struct_caption.get('subjects', [])
|
110 |
+
subjects = []
|
111 |
+
for subject in raw_subjects:
|
112 |
+
subject_type = subject.get("TYPES", {}).get('type', '')
|
113 |
+
subject_sub_type = subject.get("TYPES", {}).get('sub_type', '')
|
114 |
+
if subject_type not in ["Human", "Animal"]:
|
115 |
+
subject['expression'] = ''
|
116 |
+
if subject_type == 'Human' and subject_sub_type == 'Accessory':
|
117 |
+
subject['expression'] = ''
|
118 |
+
if subject_sub_type != '':
|
119 |
+
subject['name'] = subject_sub_type
|
120 |
+
if 'TYPES' in subject:
|
121 |
+
del subject['TYPES']
|
122 |
+
if 'is_main_subject' in subject:
|
123 |
+
del subject['is_main_subject']
|
124 |
+
subjects.append(subject)
|
125 |
+
|
126 |
+
to_del_subject_ids = []
|
127 |
+
for idx, subject in enumerate(subjects):
|
128 |
+
action = subject.get('action', '').strip()
|
129 |
+
subject['action'] = action
|
130 |
+
if random.random() > 0.9 and 'appearance' in subject:
|
131 |
+
del subject['appearance']
|
132 |
+
if random.random() > 0.9 and 'position' in subject:
|
133 |
+
del subject['position']
|
134 |
+
if task == 'i2v':
|
135 |
+
# just keep name and action, expression in subjects
|
136 |
+
dropped_keys = ['appearance', 'position']
|
137 |
+
for key in dropped_keys:
|
138 |
+
if key in subject:
|
139 |
+
del subject[key]
|
140 |
+
if subject['action'] == '' and ('expression' not in subject or subject['expression'] == ''):
|
141 |
+
to_del_subject_ids.append(idx)
|
142 |
+
|
143 |
+
# delete the subjects according to the to_del_subject_ids
|
144 |
+
for idx in sorted(to_del_subject_ids, reverse=True):
|
145 |
+
del subjects[idx]
|
146 |
+
|
147 |
+
|
148 |
+
shot_type = struct_caption.get('shot_type', '').replace('_', ' ')
|
149 |
+
if shot_type not in SHOT_TYPE_LIST:
|
150 |
+
struct_caption['shot_type'] = ''
|
151 |
+
|
152 |
+
new_struct_caption = {
|
153 |
+
'num_subjects': len(subjects),
|
154 |
+
'subjects': subjects,
|
155 |
+
'shot_type': struct_caption.get('shot_type', ''),
|
156 |
+
'shot_angle': struct_caption.get('shot_angle', ''),
|
157 |
+
'shot_position': struct_caption.get('shot_position', ''),
|
158 |
+
'environment': struct_caption.get('environment', ''),
|
159 |
+
'lighting': struct_caption.get('lighting', ''),
|
160 |
+
}
|
161 |
+
|
162 |
+
if task == 't2v' and random.random() > 0.9:
|
163 |
+
del new_struct_caption['lighting']
|
164 |
+
|
165 |
+
if task == 'i2v':
|
166 |
+
drop_keys = ['environment', 'lighting', 'shot_type', 'shot_angle', 'shot_position']
|
167 |
+
for drop_key in drop_keys:
|
168 |
+
del new_struct_caption[drop_key]
|
169 |
+
return new_struct_caption
|
170 |
+
|
171 |
+
def custom_collate_fn(batch):
|
172 |
+
real_indices, fusion_by_llm, texts, original_texts, camera_movements = zip(*batch)
|
173 |
+
return list(real_indices), list(fusion_by_llm), list(texts), list(original_texts), list(camera_movements)
|
174 |
+
|
175 |
+
|
176 |
+
if __name__ == "__main__":
|
177 |
+
parser = argparse.ArgumentParser(description="Caption Fusion by LLM")
|
178 |
+
parser.add_argument("--input_csv", default="./examples/test_result.csv")
|
179 |
+
parser.add_argument("--out_csv", default="./examples/test_result_caption.csv")
|
180 |
+
parser.add_argument("--bs", type=int, default=4)
|
181 |
+
parser.add_argument("--tp", type=int, default=1)
|
182 |
+
parser.add_argument("--model_path", required=True, type=str, help="LLM model path")
|
183 |
+
parser.add_argument("--task", default='t2v', help="t2v or i2v")
|
184 |
+
|
185 |
+
args = parser.parse_args()
|
186 |
+
|
187 |
+
sampling_params = SamplingParams(
|
188 |
+
temperature=0.1,
|
189 |
+
max_tokens=512,
|
190 |
+
stop=['\n\n']
|
191 |
+
)
|
192 |
+
# model_path = "/maindata/data/shared/public/Common-Models/Qwen2.5-32B-Instruct/"
|
193 |
+
|
194 |
+
|
195 |
+
llm = LLM(
|
196 |
+
model=args.model_path,
|
197 |
+
gpu_memory_utilization=0.9,
|
198 |
+
max_model_len=4096,
|
199 |
+
tensor_parallel_size = args.tp
|
200 |
+
)
|
201 |
+
|
202 |
+
|
203 |
+
dataset = StructuralCaptionDataset(input_csv=args.input_csv, model_path=args.model_path)
|
204 |
+
|
205 |
+
dataloader = DataLoader(
|
206 |
+
dataset,
|
207 |
+
batch_size=args.bs,
|
208 |
+
num_workers=8,
|
209 |
+
collate_fn=custom_collate_fn,
|
210 |
+
shuffle=False,
|
211 |
+
drop_last=False,
|
212 |
+
)
|
213 |
+
|
214 |
+
indices_list = []
|
215 |
+
result_list = []
|
216 |
+
for indices, fusion_by_llms, texts, original_texts, camera_movements in tqdm(dataloader):
|
217 |
+
llm_indices, llm_texts, llm_original_texts, llm_camera_movements = [], [], [], []
|
218 |
+
for idx, fusion_by_llm, text, original_text, camera_movement in zip(indices, fusion_by_llms, texts, original_texts, camera_movements):
|
219 |
+
if fusion_by_llm:
|
220 |
+
llm_indices.append(idx)
|
221 |
+
llm_texts.append(text)
|
222 |
+
llm_original_texts.append(original_text)
|
223 |
+
llm_camera_movements.append(camera_movement)
|
224 |
+
else:
|
225 |
+
indices_list.append(idx)
|
226 |
+
caption = original_text + " " + camera_movement
|
227 |
+
result_list.append(caption)
|
228 |
+
if len(llm_texts) > 0:
|
229 |
+
try:
|
230 |
+
outputs = llm.generate(llm_texts, sampling_params, use_tqdm=False)
|
231 |
+
results = []
|
232 |
+
for output in outputs:
|
233 |
+
result = output.outputs[0].text.strip()
|
234 |
+
results.append(result)
|
235 |
+
indices_list.extend(llm_indices)
|
236 |
+
except Exception as e:
|
237 |
+
print(f"Error at {llm_indices}: {str(e)}")
|
238 |
+
indices_list.extend(llm_indices)
|
239 |
+
results = llm_original_texts
|
240 |
+
|
241 |
+
for result, camera_movement in zip(results, llm_camera_movements):
|
242 |
+
# concat camera movement to fusion_caption
|
243 |
+
llm_caption = result + " " + camera_movement
|
244 |
+
result_list.append(llm_caption)
|
245 |
+
torch.cuda.empty_cache()
|
246 |
+
gc.collect()
|
247 |
+
gathered_list = [indices_list, result_list]
|
248 |
+
meta_new = result_writer(indices_list, result_list, dataset.meta, column=[f"{args.task}_fusion_caption"])
|
249 |
+
meta_new.to_csv(args.out_csv, index=False)
|
250 |
+
|
skycaptioner_v1/scripts/vllm_struct_caption.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import decord
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
from vllm import LLM, SamplingParams
|
11 |
+
from transformers import AutoTokenizer, AutoProcessor
|
12 |
+
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
|
15 |
+
SYSTEM_PROMPT = "I need you to generate a structured and detailed caption for the provided video. The structured output and the requirements for each field are as shown in the following JSON content: {\"subjects\": [{\"appearance\": \"Main subject appearance description\", \"action\": \"Main subject action\", \"expression\": \"Main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Subject position in the video (Can be relative position to other objects or spatial description)\", \"TYPES\": {\"type\": \"Main category (e.g., Human)\", \"sub_type\": \"Sub-category (e.g., Man)\"}, \"is_main_subject\": true}, {\"appearance\": \"Non-main subject appearance description\", \"action\": \"Non-main subject action\", \"expression\": \"Non-main subject expression (Only for human/animal categories, empty otherwise)\", \"position\": \"Position of non-main subject 1\", \"TYPES\": {\"type\": \"Main category (e.g., Vehicles)\", \"sub_type\": \"Sub-category (e.g., Ship)\"}, \"is_main_subject\": false}], \"shot_type\": \"Shot type(Options: long_shot/full_shot/medium_shot/close_up/extreme_close_up/other)\", \"shot_angle\": \"Camera angle(Options: eye_level/high_angle/low_angle/other)\", \"shot_position\": \"Camera position(Options: front_view/back_view/side_view/over_the_shoulder/overhead_view/point_of_view/aerial_view/overlooking_view/other)\", \"camera_motion\": \"Camera movement description\", \"environment\": \"Video background/environment description\", \"lighting\": \"Lighting information in the video\"}"
|
16 |
+
|
17 |
+
|
18 |
+
class VideoTextDataset(torch.utils.data.Dataset):
|
19 |
+
def __init__(self, csv_path, model_path):
|
20 |
+
self.meta = pd.read_csv(csv_path)
|
21 |
+
self._path = 'path'
|
22 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
23 |
+
self.processor = AutoProcessor.from_pretrained(model_path)
|
24 |
+
|
25 |
+
def __getitem__(self, index):
|
26 |
+
row = self.meta.iloc[index]
|
27 |
+
path = row[self._path]
|
28 |
+
real_index = self.meta.index[index]
|
29 |
+
vr = decord.VideoReader(path, ctx=decord.cpu(0), width=360, height=420)
|
30 |
+
start = 0
|
31 |
+
end = len(vr)
|
32 |
+
# avg_fps = vr.get_avg_fps()
|
33 |
+
index = self.get_index(end-start, 16, st=start)
|
34 |
+
frames = vr.get_batch(index).asnumpy() # n h w c
|
35 |
+
video_inputs = [torch.from_numpy(frames).permute(0, 3, 1, 2)]
|
36 |
+
conversation = {
|
37 |
+
"role": "user",
|
38 |
+
"content": [
|
39 |
+
{
|
40 |
+
"type": "video",
|
41 |
+
"video": row['path'],
|
42 |
+
"max_pixels": 360 * 420, # 460800
|
43 |
+
"fps": 2.0,
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"type": "text",
|
47 |
+
"text": SYSTEM_PROMPT
|
48 |
+
},
|
49 |
+
],
|
50 |
+
}
|
51 |
+
|
52 |
+
# 生成 user_input
|
53 |
+
user_input = self.processor.apply_chat_template(
|
54 |
+
[conversation],
|
55 |
+
tokenize=False,
|
56 |
+
add_generation_prompt=True
|
57 |
+
)
|
58 |
+
results = dict()
|
59 |
+
inputs = {
|
60 |
+
'prompt': user_input,
|
61 |
+
'multi_modal_data': {'video': video_inputs}
|
62 |
+
}
|
63 |
+
results["index"] = real_index
|
64 |
+
results['input'] = inputs
|
65 |
+
return results
|
66 |
+
|
67 |
+
def __len__(self):
|
68 |
+
return len(self.meta)
|
69 |
+
|
70 |
+
def get_index(self, video_size, num_frames, st=0):
|
71 |
+
seg_size = max(0., float(video_size - 1) / num_frames)
|
72 |
+
max_frame = int(video_size) - 1
|
73 |
+
seq = []
|
74 |
+
# index from 1, must add 1
|
75 |
+
for i in range(num_frames):
|
76 |
+
start = int(np.round(seg_size * i))
|
77 |
+
# end = int(np.round(seg_size * (i + 1)))
|
78 |
+
idx = min(start, max_frame)
|
79 |
+
seq.append(idx+st)
|
80 |
+
return seq
|
81 |
+
|
82 |
+
def result_writer(indices_list: list, result_list: list, meta: pd.DataFrame, column):
|
83 |
+
flat_indices = []
|
84 |
+
for x in zip(indices_list):
|
85 |
+
flat_indices.extend(x)
|
86 |
+
flat_results = []
|
87 |
+
for x in zip(result_list):
|
88 |
+
flat_results.extend(x)
|
89 |
+
|
90 |
+
flat_indices = np.array(flat_indices)
|
91 |
+
flat_results = np.array(flat_results)
|
92 |
+
|
93 |
+
unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
|
94 |
+
meta.loc[unique_indices, column[0]] = flat_results[unique_indices_idx]
|
95 |
+
|
96 |
+
meta = meta.loc[unique_indices]
|
97 |
+
return meta
|
98 |
+
|
99 |
+
|
100 |
+
def worker_init_fn(worker_id):
|
101 |
+
# Set different seed for each worker
|
102 |
+
worker_seed = torch.initial_seed() % 2**32
|
103 |
+
np.random.seed(worker_seed)
|
104 |
+
# Prevent deadlocks by setting timeout
|
105 |
+
torch.set_num_threads(1)
|
106 |
+
|
107 |
+
def main():
|
108 |
+
parser = argparse.ArgumentParser(description="SkyCaptioner-V1 vllm batch inference")
|
109 |
+
parser.add_argument("--input_csv", default="./examples/test.csv")
|
110 |
+
parser.add_argument("--out_csv", default="./examples/test_result.csv")
|
111 |
+
parser.add_argument("--bs", type=int, default=4)
|
112 |
+
parser.add_argument("--tp", type=int, default=1)
|
113 |
+
parser.add_argument("--model_path", required=True, type=str, help="skycaptioner-v1 model path")
|
114 |
+
args = parser.parse_args()
|
115 |
+
|
116 |
+
dataset = VideoTextDataset(csv_path=args.input_csv, model_path=args.model_path)
|
117 |
+
dataloader = DataLoader(
|
118 |
+
dataset,
|
119 |
+
batch_size=args.bs,
|
120 |
+
num_workers=4,
|
121 |
+
worker_init_fn=worker_init_fn,
|
122 |
+
persistent_workers=True,
|
123 |
+
timeout=180,
|
124 |
+
)
|
125 |
+
|
126 |
+
sampling_params = SamplingParams(temperature=0.05, max_tokens=2048)
|
127 |
+
|
128 |
+
llm = LLM(model=args.model_path,
|
129 |
+
gpu_memory_utilization=0.6,
|
130 |
+
max_model_len=31920,
|
131 |
+
tensor_parallel_size=args.tp)
|
132 |
+
|
133 |
+
indices_list = []
|
134 |
+
caption_save = []
|
135 |
+
for video_batch in tqdm(dataloader):
|
136 |
+
indices = video_batch["index"]
|
137 |
+
inputs = video_batch["input"]
|
138 |
+
batch_user_inputs = []
|
139 |
+
for prompt, video in zip(inputs['prompt'], inputs['multi_modal_data']['video'][0]):
|
140 |
+
usi={'prompt':prompt, 'multi_modal_data':{'video':video}}
|
141 |
+
batch_user_inputs.append(usi)
|
142 |
+
outputs = llm.generate(batch_user_inputs, sampling_params, use_tqdm=False)
|
143 |
+
struct_outputs = [output.outputs[0].text for output in outputs]
|
144 |
+
|
145 |
+
indices_list.extend(indices.tolist())
|
146 |
+
caption_save.extend(struct_outputs)
|
147 |
+
|
148 |
+
meta_new = result_writer(indices_list, caption_save, dataset.meta, column=["structural_caption"])
|
149 |
+
meta_new.to_csv(args.out_csv, index=False)
|
150 |
+
print(f'Saved structural_caption to {args.out_csv}')
|
151 |
+
|
152 |
+
if __name__ == '__main__':
|
153 |
+
main()
|
skyreels_v2_infer/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .pipelines import DiffusionForcingPipeline
|
skyreels_v2_infer/distributed/__init__.py
ADDED
File without changes
|
skyreels_v2_infer/distributed/xdit_context_parallel.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.amp as amp
|
4 |
+
from torch.backends.cuda import sdp_kernel
|
5 |
+
from xfuser.core.distributed import get_sequence_parallel_rank
|
6 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
7 |
+
from xfuser.core.distributed import get_sp_group
|
8 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
9 |
+
|
10 |
+
from ..modules.transformer import sinusoidal_embedding_1d
|
11 |
+
|
12 |
+
|
13 |
+
def pad_freqs(original_tensor, target_len):
|
14 |
+
seq_len, s1, s2 = original_tensor.shape
|
15 |
+
pad_size = target_len - seq_len
|
16 |
+
padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
|
17 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
18 |
+
return padded_tensor
|
19 |
+
|
20 |
+
|
21 |
+
@amp.autocast("cuda", enabled=False)
|
22 |
+
def rope_apply(x, grid_sizes, freqs):
|
23 |
+
"""
|
24 |
+
x: [B, L, N, C].
|
25 |
+
grid_sizes: [B, 3].
|
26 |
+
freqs: [M, C // 2].
|
27 |
+
"""
|
28 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
29 |
+
# split freqs
|
30 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
31 |
+
|
32 |
+
# loop over samples
|
33 |
+
output = []
|
34 |
+
grid = [grid_sizes.tolist()] * x.size(0)
|
35 |
+
for i, (f, h, w) in enumerate(grid):
|
36 |
+
seq_len = f * h * w
|
37 |
+
|
38 |
+
# precompute multipliers
|
39 |
+
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
|
40 |
+
freqs_i = torch.cat(
|
41 |
+
[
|
42 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
43 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
44 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
|
45 |
+
],
|
46 |
+
dim=-1,
|
47 |
+
).reshape(seq_len, 1, -1)
|
48 |
+
|
49 |
+
# apply rotary embedding
|
50 |
+
sp_size = get_sequence_parallel_world_size()
|
51 |
+
sp_rank = get_sequence_parallel_rank()
|
52 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
53 |
+
s_per_rank = s
|
54 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
|
55 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
|
56 |
+
x_i = torch.cat([x_i, x[i, s:]])
|
57 |
+
|
58 |
+
# append to collection
|
59 |
+
output.append(x_i)
|
60 |
+
return torch.stack(output).float()
|
61 |
+
|
62 |
+
|
63 |
+
def broadcast_should_calc(should_calc: bool) -> bool:
|
64 |
+
import torch.distributed as dist
|
65 |
+
|
66 |
+
device = torch.cuda.current_device()
|
67 |
+
int_should_calc = 1 if should_calc else 0
|
68 |
+
tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
|
69 |
+
dist.broadcast(tensor, src=0)
|
70 |
+
should_calc = tensor.item() == 1
|
71 |
+
return should_calc
|
72 |
+
|
73 |
+
|
74 |
+
def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
|
75 |
+
"""
|
76 |
+
x: A list of videos each with shape [C, T, H, W].
|
77 |
+
t: [B].
|
78 |
+
context: A list of text embeddings each with shape [L, C].
|
79 |
+
"""
|
80 |
+
if self.model_type == "i2v":
|
81 |
+
assert clip_fea is not None and y is not None
|
82 |
+
# params
|
83 |
+
device = self.patch_embedding.weight.device
|
84 |
+
if self.freqs.device != device:
|
85 |
+
self.freqs = self.freqs.to(device)
|
86 |
+
|
87 |
+
if y is not None:
|
88 |
+
x = torch.cat([x, y], dim=1)
|
89 |
+
|
90 |
+
# embeddings
|
91 |
+
x = self.patch_embedding(x)
|
92 |
+
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
|
93 |
+
x = x.flatten(2).transpose(1, 2)
|
94 |
+
|
95 |
+
if self.flag_causal_attention:
|
96 |
+
frame_num = grid_sizes[0]
|
97 |
+
height = grid_sizes[1]
|
98 |
+
width = grid_sizes[2]
|
99 |
+
block_num = frame_num // self.num_frame_per_block
|
100 |
+
range_tensor = torch.arange(block_num).view(-1, 1)
|
101 |
+
range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
|
102 |
+
casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
|
103 |
+
casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
|
104 |
+
casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
|
105 |
+
casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
|
106 |
+
self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
|
107 |
+
|
108 |
+
# time embeddings
|
109 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
110 |
+
if t.dim() == 2:
|
111 |
+
b, f = t.shape
|
112 |
+
_flag_df = True
|
113 |
+
else:
|
114 |
+
_flag_df = False
|
115 |
+
e = self.time_embedding(
|
116 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
|
117 |
+
) # b, dim
|
118 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
|
119 |
+
|
120 |
+
if self.inject_sample_info:
|
121 |
+
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
122 |
+
|
123 |
+
fps_emb = self.fps_embedding(fps).float()
|
124 |
+
if _flag_df:
|
125 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
126 |
+
else:
|
127 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
|
128 |
+
|
129 |
+
if _flag_df:
|
130 |
+
e = e.view(b, f, 1, 1, self.dim)
|
131 |
+
e0 = e0.view(b, f, 1, 1, 6, self.dim)
|
132 |
+
e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
|
133 |
+
e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
|
134 |
+
e0 = e0.transpose(1, 2).contiguous()
|
135 |
+
|
136 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
137 |
+
|
138 |
+
# context
|
139 |
+
context = self.text_embedding(context)
|
140 |
+
|
141 |
+
if clip_fea is not None:
|
142 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
143 |
+
context = torch.concat([context_clip, context], dim=1)
|
144 |
+
|
145 |
+
# arguments
|
146 |
+
if e0.ndim == 4:
|
147 |
+
e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
|
148 |
+
kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
|
149 |
+
|
150 |
+
if self.enable_teacache:
|
151 |
+
modulated_inp = e0 if self.use_ref_steps else e
|
152 |
+
# teacache
|
153 |
+
if self.cnt % 2 == 0: # even -> conditon
|
154 |
+
self.is_even = True
|
155 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
156 |
+
should_calc_even = True
|
157 |
+
self.accumulated_rel_l1_distance_even = 0
|
158 |
+
else:
|
159 |
+
rescale_func = np.poly1d(self.coefficients)
|
160 |
+
self.accumulated_rel_l1_distance_even += rescale_func(
|
161 |
+
((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
|
162 |
+
.cpu()
|
163 |
+
.item()
|
164 |
+
)
|
165 |
+
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
|
166 |
+
should_calc_even = False
|
167 |
+
else:
|
168 |
+
should_calc_even = True
|
169 |
+
self.accumulated_rel_l1_distance_even = 0
|
170 |
+
self.previous_e0_even = modulated_inp.clone()
|
171 |
+
else: # odd -> unconditon
|
172 |
+
self.is_even = False
|
173 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
174 |
+
should_calc_odd = True
|
175 |
+
self.accumulated_rel_l1_distance_odd = 0
|
176 |
+
else:
|
177 |
+
rescale_func = np.poly1d(self.coefficients)
|
178 |
+
self.accumulated_rel_l1_distance_odd += rescale_func(
|
179 |
+
((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
|
180 |
+
.cpu()
|
181 |
+
.item()
|
182 |
+
)
|
183 |
+
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
|
184 |
+
should_calc_odd = False
|
185 |
+
else:
|
186 |
+
should_calc_odd = True
|
187 |
+
self.accumulated_rel_l1_distance_odd = 0
|
188 |
+
self.previous_e0_odd = modulated_inp.clone()
|
189 |
+
|
190 |
+
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
191 |
+
if self.enable_teacache:
|
192 |
+
if self.is_even:
|
193 |
+
should_calc_even = broadcast_should_calc(should_calc_even)
|
194 |
+
if not should_calc_even:
|
195 |
+
x += self.previous_residual_even
|
196 |
+
else:
|
197 |
+
ori_x = x.clone()
|
198 |
+
for block in self.blocks:
|
199 |
+
x = block(x, **kwargs)
|
200 |
+
ori_x.mul_(-1)
|
201 |
+
ori_x.add_(x)
|
202 |
+
self.previous_residual_even = ori_x
|
203 |
+
else:
|
204 |
+
should_calc_odd = broadcast_should_calc(should_calc_odd)
|
205 |
+
if not should_calc_odd:
|
206 |
+
x += self.previous_residual_odd
|
207 |
+
else:
|
208 |
+
ori_x = x.clone()
|
209 |
+
for block in self.blocks:
|
210 |
+
x = block(x, **kwargs)
|
211 |
+
ori_x.mul_(-1)
|
212 |
+
ori_x.add_(x)
|
213 |
+
self.previous_residual_odd = ori_x
|
214 |
+
self.cnt += 1
|
215 |
+
if self.cnt >= self.num_steps:
|
216 |
+
self.cnt = 0
|
217 |
+
else:
|
218 |
+
# Context Parallel
|
219 |
+
for block in self.blocks:
|
220 |
+
x = block(x, **kwargs)
|
221 |
+
|
222 |
+
# head
|
223 |
+
if e.ndim == 3:
|
224 |
+
e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
225 |
+
x = self.head(x, e)
|
226 |
+
# Context Parallel
|
227 |
+
x = get_sp_group().all_gather(x, dim=1)
|
228 |
+
# unpatchify
|
229 |
+
x = self.unpatchify(x, grid_sizes)
|
230 |
+
return x.float()
|
231 |
+
|
232 |
+
|
233 |
+
def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
|
234 |
+
|
235 |
+
r"""
|
236 |
+
Args:
|
237 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
238 |
+
seq_lens(Tensor): Shape [B]
|
239 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
240 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
241 |
+
"""
|
242 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
243 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
244 |
+
|
245 |
+
def half(x):
|
246 |
+
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
|
247 |
+
|
248 |
+
# query, key, value function
|
249 |
+
def qkv_fn(x):
|
250 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
251 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
252 |
+
v = self.v(x).view(b, s, n, d)
|
253 |
+
return q, k, v
|
254 |
+
|
255 |
+
x = x.to(self.q.weight.dtype)
|
256 |
+
q, k, v = qkv_fn(x)
|
257 |
+
|
258 |
+
if not self._flag_ar_attention:
|
259 |
+
q = rope_apply(q, grid_sizes, freqs)
|
260 |
+
k = rope_apply(k, grid_sizes, freqs)
|
261 |
+
else:
|
262 |
+
|
263 |
+
q = rope_apply(q, grid_sizes, freqs)
|
264 |
+
k = rope_apply(k, grid_sizes, freqs)
|
265 |
+
q = q.to(torch.bfloat16)
|
266 |
+
k = k.to(torch.bfloat16)
|
267 |
+
v = v.to(torch.bfloat16)
|
268 |
+
# x = torch.nn.functional.scaled_dot_product_attention(
|
269 |
+
# q.transpose(1, 2),
|
270 |
+
# k.transpose(1, 2),
|
271 |
+
# v.transpose(1, 2),
|
272 |
+
# ).transpose(1, 2).contiguous()
|
273 |
+
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
274 |
+
x = (
|
275 |
+
torch.nn.functional.scaled_dot_product_attention(
|
276 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
277 |
+
)
|
278 |
+
.transpose(1, 2)
|
279 |
+
.contiguous()
|
280 |
+
)
|
281 |
+
x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
|
282 |
+
|
283 |
+
# output
|
284 |
+
x = x.flatten(2)
|
285 |
+
x = self.o(x)
|
286 |
+
return x
|
skyreels_v2_infer/modules/__init__.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from safetensors.torch import load_file
|
6 |
+
|
7 |
+
from .clip import CLIPModel
|
8 |
+
from .t5 import T5EncoderModel
|
9 |
+
from .transformer import WanModel
|
10 |
+
from .vae import WanVAE
|
11 |
+
|
12 |
+
|
13 |
+
def download_model(model_id):
|
14 |
+
if not os.path.exists(model_id):
|
15 |
+
from huggingface_hub import snapshot_download
|
16 |
+
|
17 |
+
model_id = snapshot_download(repo_id=model_id)
|
18 |
+
return model_id
|
19 |
+
|
20 |
+
|
21 |
+
def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
|
22 |
+
vae = WanVAE(model_path).to(device).to(weight_dtype)
|
23 |
+
vae.vae.requires_grad_(False)
|
24 |
+
vae.vae.eval()
|
25 |
+
gc.collect()
|
26 |
+
torch.cuda.empty_cache()
|
27 |
+
return vae
|
28 |
+
|
29 |
+
|
30 |
+
def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
|
31 |
+
config_path = os.path.join(model_path, "config.json")
|
32 |
+
transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)
|
33 |
+
|
34 |
+
for file in os.listdir(model_path):
|
35 |
+
if file.endswith(".safetensors"):
|
36 |
+
file_path = os.path.join(model_path, file)
|
37 |
+
state_dict = load_file(file_path)
|
38 |
+
transformer.load_state_dict(state_dict, strict=False)
|
39 |
+
del state_dict
|
40 |
+
gc.collect()
|
41 |
+
torch.cuda.empty_cache()
|
42 |
+
|
43 |
+
transformer.requires_grad_(False)
|
44 |
+
transformer.eval()
|
45 |
+
gc.collect()
|
46 |
+
torch.cuda.empty_cache()
|
47 |
+
return transformer
|
48 |
+
|
49 |
+
|
50 |
+
def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
|
51 |
+
t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
|
52 |
+
tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
|
53 |
+
text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
|
54 |
+
text_encoder.requires_grad_(False)
|
55 |
+
text_encoder.eval()
|
56 |
+
gc.collect()
|
57 |
+
torch.cuda.empty_cache()
|
58 |
+
return text_encoder
|
59 |
+
|
60 |
+
|
61 |
+
def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
|
62 |
+
checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
63 |
+
tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
|
64 |
+
image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
|
65 |
+
image_enc.requires_grad_(False)
|
66 |
+
image_enc.eval()
|
67 |
+
gc.collect()
|
68 |
+
torch.cuda.empty_cache()
|
69 |
+
return image_enc
|
skyreels_v2_infer/modules/attention.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import torch
|
3 |
+
|
4 |
+
try:
|
5 |
+
import flash_attn_interface
|
6 |
+
|
7 |
+
FLASH_ATTN_3_AVAILABLE = True
|
8 |
+
except ModuleNotFoundError:
|
9 |
+
FLASH_ATTN_3_AVAILABLE = False
|
10 |
+
|
11 |
+
try:
|
12 |
+
import flash_attn
|
13 |
+
|
14 |
+
FLASH_ATTN_2_AVAILABLE = True
|
15 |
+
except ModuleNotFoundError:
|
16 |
+
FLASH_ATTN_2_AVAILABLE = False
|
17 |
+
|
18 |
+
import warnings
|
19 |
+
|
20 |
+
__all__ = [
|
21 |
+
"flash_attention",
|
22 |
+
"attention",
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
def flash_attention(
|
27 |
+
q,
|
28 |
+
k,
|
29 |
+
v,
|
30 |
+
q_lens=None,
|
31 |
+
k_lens=None,
|
32 |
+
dropout_p=0.0,
|
33 |
+
softmax_scale=None,
|
34 |
+
q_scale=None,
|
35 |
+
causal=False,
|
36 |
+
window_size=(-1, -1),
|
37 |
+
deterministic=False,
|
38 |
+
dtype=torch.bfloat16,
|
39 |
+
version=None,
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
q: [B, Lq, Nq, C1].
|
43 |
+
k: [B, Lk, Nk, C1].
|
44 |
+
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
|
45 |
+
q_lens: [B].
|
46 |
+
k_lens: [B].
|
47 |
+
dropout_p: float. Dropout probability.
|
48 |
+
softmax_scale: float. The scaling of QK^T before applying softmax.
|
49 |
+
causal: bool. Whether to apply causal attention mask.
|
50 |
+
window_size: (left right). If not (-1, -1), apply sliding window local attention.
|
51 |
+
deterministic: bool. If True, slightly slower and uses more memory.
|
52 |
+
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
|
53 |
+
"""
|
54 |
+
half_dtypes = (torch.float16, torch.bfloat16)
|
55 |
+
assert dtype in half_dtypes
|
56 |
+
assert q.device.type == "cuda" and q.size(-1) <= 256
|
57 |
+
|
58 |
+
# params
|
59 |
+
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
|
60 |
+
|
61 |
+
def half(x):
|
62 |
+
return x if x.dtype in half_dtypes else x.to(dtype)
|
63 |
+
|
64 |
+
# preprocess query
|
65 |
+
|
66 |
+
q = half(q.flatten(0, 1))
|
67 |
+
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(device=q.device, non_blocking=True)
|
68 |
+
|
69 |
+
# preprocess key, value
|
70 |
+
|
71 |
+
k = half(k.flatten(0, 1))
|
72 |
+
v = half(v.flatten(0, 1))
|
73 |
+
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(device=k.device, non_blocking=True)
|
74 |
+
|
75 |
+
q = q.to(v.dtype)
|
76 |
+
k = k.to(v.dtype)
|
77 |
+
|
78 |
+
if q_scale is not None:
|
79 |
+
q = q * q_scale
|
80 |
+
|
81 |
+
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
|
82 |
+
warnings.warn("Flash attention 3 is not available, use flash attention 2 instead.")
|
83 |
+
|
84 |
+
torch.cuda.nvtx.range_push(f"{list(q.shape)}-{list(k.shape)}-{list(v.shape)}-{q.dtype}-{k.dtype}-{v.dtype}")
|
85 |
+
# apply attention
|
86 |
+
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
|
87 |
+
# Note: dropout_p, window_size are not supported in FA3 now.
|
88 |
+
x = flash_attn_interface.flash_attn_varlen_func(
|
89 |
+
q=q,
|
90 |
+
k=k,
|
91 |
+
v=v,
|
92 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
|
93 |
+
.cumsum(0, dtype=torch.int32)
|
94 |
+
.to(q.device, non_blocking=True),
|
95 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
|
96 |
+
.cumsum(0, dtype=torch.int32)
|
97 |
+
.to(q.device, non_blocking=True),
|
98 |
+
seqused_q=None,
|
99 |
+
seqused_k=None,
|
100 |
+
max_seqlen_q=lq,
|
101 |
+
max_seqlen_k=lk,
|
102 |
+
softmax_scale=softmax_scale,
|
103 |
+
causal=causal,
|
104 |
+
deterministic=deterministic,
|
105 |
+
)[0].unflatten(0, (b, lq))
|
106 |
+
else:
|
107 |
+
assert FLASH_ATTN_2_AVAILABLE
|
108 |
+
x = flash_attn.flash_attn_varlen_func(
|
109 |
+
q=q,
|
110 |
+
k=k,
|
111 |
+
v=v,
|
112 |
+
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
|
113 |
+
.cumsum(0, dtype=torch.int32)
|
114 |
+
.to(q.device, non_blocking=True),
|
115 |
+
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
|
116 |
+
.cumsum(0, dtype=torch.int32)
|
117 |
+
.to(q.device, non_blocking=True),
|
118 |
+
max_seqlen_q=lq,
|
119 |
+
max_seqlen_k=lk,
|
120 |
+
dropout_p=dropout_p,
|
121 |
+
softmax_scale=softmax_scale,
|
122 |
+
causal=causal,
|
123 |
+
window_size=window_size,
|
124 |
+
deterministic=deterministic,
|
125 |
+
).unflatten(0, (b, lq))
|
126 |
+
torch.cuda.nvtx.range_pop()
|
127 |
+
|
128 |
+
# output
|
129 |
+
return x
|
130 |
+
|
131 |
+
|
132 |
+
def attention(
|
133 |
+
q,
|
134 |
+
k,
|
135 |
+
v,
|
136 |
+
q_lens=None,
|
137 |
+
k_lens=None,
|
138 |
+
dropout_p=0.0,
|
139 |
+
softmax_scale=None,
|
140 |
+
q_scale=None,
|
141 |
+
causal=False,
|
142 |
+
window_size=(-1, -1),
|
143 |
+
deterministic=False,
|
144 |
+
dtype=torch.bfloat16,
|
145 |
+
fa_version=None,
|
146 |
+
):
|
147 |
+
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
|
148 |
+
return flash_attention(
|
149 |
+
q=q,
|
150 |
+
k=k,
|
151 |
+
v=v,
|
152 |
+
q_lens=q_lens,
|
153 |
+
k_lens=k_lens,
|
154 |
+
dropout_p=dropout_p,
|
155 |
+
softmax_scale=softmax_scale,
|
156 |
+
q_scale=q_scale,
|
157 |
+
causal=causal,
|
158 |
+
window_size=window_size,
|
159 |
+
deterministic=deterministic,
|
160 |
+
dtype=dtype,
|
161 |
+
version=fa_version,
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
if q_lens is not None or k_lens is not None:
|
165 |
+
warnings.warn(
|
166 |
+
"Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
|
167 |
+
)
|
168 |
+
attn_mask = None
|
169 |
+
|
170 |
+
q = q.transpose(1, 2).to(dtype)
|
171 |
+
k = k.transpose(1, 2).to(dtype)
|
172 |
+
v = v.transpose(1, 2).to(dtype)
|
173 |
+
|
174 |
+
out = torch.nn.functional.scaled_dot_product_attention(
|
175 |
+
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
|
176 |
+
)
|
177 |
+
|
178 |
+
out = out.transpose(1, 2).contiguous()
|
179 |
+
return out
|
skyreels_v2_infer/modules/clip.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
|
2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from diffusers.models import ModelMixin
|
11 |
+
|
12 |
+
from .attention import flash_attention
|
13 |
+
from .tokenizers import HuggingfaceTokenizer
|
14 |
+
from .xlm_roberta import XLMRoberta
|
15 |
+
|
16 |
+
__all__ = [
|
17 |
+
"XLMRobertaCLIP",
|
18 |
+
"clip_xlm_roberta_vit_h_14",
|
19 |
+
"CLIPModel",
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
def pos_interpolate(pos, seq_len):
|
24 |
+
if pos.size(1) == seq_len:
|
25 |
+
return pos
|
26 |
+
else:
|
27 |
+
src_grid = int(math.sqrt(pos.size(1)))
|
28 |
+
tar_grid = int(math.sqrt(seq_len))
|
29 |
+
n = pos.size(1) - src_grid * src_grid
|
30 |
+
return torch.cat(
|
31 |
+
[
|
32 |
+
pos[:, :n],
|
33 |
+
F.interpolate(
|
34 |
+
pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2),
|
35 |
+
size=(tar_grid, tar_grid),
|
36 |
+
mode="bicubic",
|
37 |
+
align_corners=False,
|
38 |
+
)
|
39 |
+
.flatten(2)
|
40 |
+
.transpose(1, 2),
|
41 |
+
],
|
42 |
+
dim=1,
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
class QuickGELU(nn.Module):
|
47 |
+
def forward(self, x):
|
48 |
+
return x * torch.sigmoid(1.702 * x)
|
49 |
+
|
50 |
+
|
51 |
+
class LayerNorm(nn.LayerNorm):
|
52 |
+
def forward(self, x):
|
53 |
+
return super().forward(x.float()).type_as(x)
|
54 |
+
|
55 |
+
|
56 |
+
class SelfAttention(nn.Module):
|
57 |
+
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
|
58 |
+
assert dim % num_heads == 0
|
59 |
+
super().__init__()
|
60 |
+
self.dim = dim
|
61 |
+
self.num_heads = num_heads
|
62 |
+
self.head_dim = dim // num_heads
|
63 |
+
self.causal = causal
|
64 |
+
self.attn_dropout = attn_dropout
|
65 |
+
self.proj_dropout = proj_dropout
|
66 |
+
|
67 |
+
# layers
|
68 |
+
self.to_qkv = nn.Linear(dim, dim * 3)
|
69 |
+
self.proj = nn.Linear(dim, dim)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
"""
|
73 |
+
x: [B, L, C].
|
74 |
+
"""
|
75 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
76 |
+
|
77 |
+
# compute query, key, value
|
78 |
+
q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
|
79 |
+
|
80 |
+
# compute attention
|
81 |
+
p = self.attn_dropout if self.training else 0.0
|
82 |
+
x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
|
83 |
+
x = x.reshape(b, s, c)
|
84 |
+
|
85 |
+
# output
|
86 |
+
x = self.proj(x)
|
87 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
88 |
+
return x
|
89 |
+
|
90 |
+
|
91 |
+
class SwiGLU(nn.Module):
|
92 |
+
def __init__(self, dim, mid_dim):
|
93 |
+
super().__init__()
|
94 |
+
self.dim = dim
|
95 |
+
self.mid_dim = mid_dim
|
96 |
+
|
97 |
+
# layers
|
98 |
+
self.fc1 = nn.Linear(dim, mid_dim)
|
99 |
+
self.fc2 = nn.Linear(dim, mid_dim)
|
100 |
+
self.fc3 = nn.Linear(mid_dim, dim)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
x = F.silu(self.fc1(x)) * self.fc2(x)
|
104 |
+
x = self.fc3(x)
|
105 |
+
return x
|
106 |
+
|
107 |
+
|
108 |
+
class AttentionBlock(nn.Module):
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
dim,
|
112 |
+
mlp_ratio,
|
113 |
+
num_heads,
|
114 |
+
post_norm=False,
|
115 |
+
causal=False,
|
116 |
+
activation="quick_gelu",
|
117 |
+
attn_dropout=0.0,
|
118 |
+
proj_dropout=0.0,
|
119 |
+
norm_eps=1e-5,
|
120 |
+
):
|
121 |
+
assert activation in ["quick_gelu", "gelu", "swi_glu"]
|
122 |
+
super().__init__()
|
123 |
+
self.dim = dim
|
124 |
+
self.mlp_ratio = mlp_ratio
|
125 |
+
self.num_heads = num_heads
|
126 |
+
self.post_norm = post_norm
|
127 |
+
self.causal = causal
|
128 |
+
self.norm_eps = norm_eps
|
129 |
+
|
130 |
+
# layers
|
131 |
+
self.norm1 = LayerNorm(dim, eps=norm_eps)
|
132 |
+
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
|
133 |
+
self.norm2 = LayerNorm(dim, eps=norm_eps)
|
134 |
+
if activation == "swi_glu":
|
135 |
+
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
|
136 |
+
else:
|
137 |
+
self.mlp = nn.Sequential(
|
138 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
139 |
+
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
|
140 |
+
nn.Linear(int(dim * mlp_ratio), dim),
|
141 |
+
nn.Dropout(proj_dropout),
|
142 |
+
)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
if self.post_norm:
|
146 |
+
x = x + self.norm1(self.attn(x))
|
147 |
+
x = x + self.norm2(self.mlp(x))
|
148 |
+
else:
|
149 |
+
x = x + self.attn(self.norm1(x))
|
150 |
+
x = x + self.mlp(self.norm2(x))
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
class AttentionPool(nn.Module):
|
155 |
+
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
|
156 |
+
assert dim % num_heads == 0
|
157 |
+
super().__init__()
|
158 |
+
self.dim = dim
|
159 |
+
self.mlp_ratio = mlp_ratio
|
160 |
+
self.num_heads = num_heads
|
161 |
+
self.head_dim = dim // num_heads
|
162 |
+
self.proj_dropout = proj_dropout
|
163 |
+
self.norm_eps = norm_eps
|
164 |
+
|
165 |
+
# layers
|
166 |
+
gain = 1.0 / math.sqrt(dim)
|
167 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
168 |
+
self.to_q = nn.Linear(dim, dim)
|
169 |
+
self.to_kv = nn.Linear(dim, dim * 2)
|
170 |
+
self.proj = nn.Linear(dim, dim)
|
171 |
+
self.norm = LayerNorm(dim, eps=norm_eps)
|
172 |
+
self.mlp = nn.Sequential(
|
173 |
+
nn.Linear(dim, int(dim * mlp_ratio)),
|
174 |
+
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
|
175 |
+
nn.Linear(int(dim * mlp_ratio), dim),
|
176 |
+
nn.Dropout(proj_dropout),
|
177 |
+
)
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
"""
|
181 |
+
x: [B, L, C].
|
182 |
+
"""
|
183 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
184 |
+
|
185 |
+
# compute query, key, value
|
186 |
+
q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
|
187 |
+
k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
|
188 |
+
|
189 |
+
# compute attention
|
190 |
+
x = flash_attention(q, k, v, version=2)
|
191 |
+
x = x.reshape(b, 1, c)
|
192 |
+
|
193 |
+
# output
|
194 |
+
x = self.proj(x)
|
195 |
+
x = F.dropout(x, self.proj_dropout, self.training)
|
196 |
+
|
197 |
+
# mlp
|
198 |
+
x = x + self.mlp(self.norm(x))
|
199 |
+
return x[:, 0]
|
200 |
+
|
201 |
+
|
202 |
+
class VisionTransformer(nn.Module):
|
203 |
+
def __init__(
|
204 |
+
self,
|
205 |
+
image_size=224,
|
206 |
+
patch_size=16,
|
207 |
+
dim=768,
|
208 |
+
mlp_ratio=4,
|
209 |
+
out_dim=512,
|
210 |
+
num_heads=12,
|
211 |
+
num_layers=12,
|
212 |
+
pool_type="token",
|
213 |
+
pre_norm=True,
|
214 |
+
post_norm=False,
|
215 |
+
activation="quick_gelu",
|
216 |
+
attn_dropout=0.0,
|
217 |
+
proj_dropout=0.0,
|
218 |
+
embedding_dropout=0.0,
|
219 |
+
norm_eps=1e-5,
|
220 |
+
):
|
221 |
+
if image_size % patch_size != 0:
|
222 |
+
print("[WARNING] image_size is not divisible by patch_size", flush=True)
|
223 |
+
assert pool_type in ("token", "token_fc", "attn_pool")
|
224 |
+
out_dim = out_dim or dim
|
225 |
+
super().__init__()
|
226 |
+
self.image_size = image_size
|
227 |
+
self.patch_size = patch_size
|
228 |
+
self.num_patches = (image_size // patch_size) ** 2
|
229 |
+
self.dim = dim
|
230 |
+
self.mlp_ratio = mlp_ratio
|
231 |
+
self.out_dim = out_dim
|
232 |
+
self.num_heads = num_heads
|
233 |
+
self.num_layers = num_layers
|
234 |
+
self.pool_type = pool_type
|
235 |
+
self.post_norm = post_norm
|
236 |
+
self.norm_eps = norm_eps
|
237 |
+
|
238 |
+
# embeddings
|
239 |
+
gain = 1.0 / math.sqrt(dim)
|
240 |
+
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
|
241 |
+
if pool_type in ("token", "token_fc"):
|
242 |
+
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
|
243 |
+
self.pos_embedding = nn.Parameter(
|
244 |
+
gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim)
|
245 |
+
)
|
246 |
+
self.dropout = nn.Dropout(embedding_dropout)
|
247 |
+
|
248 |
+
# transformer
|
249 |
+
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
|
250 |
+
self.transformer = nn.Sequential(
|
251 |
+
*[
|
252 |
+
AttentionBlock(
|
253 |
+
dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps
|
254 |
+
)
|
255 |
+
for _ in range(num_layers)
|
256 |
+
]
|
257 |
+
)
|
258 |
+
self.post_norm = LayerNorm(dim, eps=norm_eps)
|
259 |
+
|
260 |
+
# head
|
261 |
+
if pool_type == "token":
|
262 |
+
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
|
263 |
+
elif pool_type == "token_fc":
|
264 |
+
self.head = nn.Linear(dim, out_dim)
|
265 |
+
elif pool_type == "attn_pool":
|
266 |
+
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
|
267 |
+
|
268 |
+
def forward(self, x, interpolation=False, use_31_block=False):
|
269 |
+
b = x.size(0)
|
270 |
+
|
271 |
+
# embeddings
|
272 |
+
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
|
273 |
+
if self.pool_type in ("token", "token_fc"):
|
274 |
+
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
|
275 |
+
if interpolation:
|
276 |
+
e = pos_interpolate(self.pos_embedding, x.size(1))
|
277 |
+
else:
|
278 |
+
e = self.pos_embedding
|
279 |
+
x = self.dropout(x + e)
|
280 |
+
if self.pre_norm is not None:
|
281 |
+
x = self.pre_norm(x)
|
282 |
+
|
283 |
+
# transformer
|
284 |
+
if use_31_block:
|
285 |
+
x = self.transformer[:-1](x)
|
286 |
+
return x
|
287 |
+
else:
|
288 |
+
x = self.transformer(x)
|
289 |
+
return x
|
290 |
+
|
291 |
+
|
292 |
+
class XLMRobertaWithHead(XLMRoberta):
|
293 |
+
def __init__(self, **kwargs):
|
294 |
+
self.out_dim = kwargs.pop("out_dim")
|
295 |
+
super().__init__(**kwargs)
|
296 |
+
|
297 |
+
# head
|
298 |
+
mid_dim = (self.dim + self.out_dim) // 2
|
299 |
+
self.head = nn.Sequential(
|
300 |
+
nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False)
|
301 |
+
)
|
302 |
+
|
303 |
+
def forward(self, ids):
|
304 |
+
# xlm-roberta
|
305 |
+
x = super().forward(ids)
|
306 |
+
|
307 |
+
# average pooling
|
308 |
+
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
|
309 |
+
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
|
310 |
+
|
311 |
+
# head
|
312 |
+
x = self.head(x)
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
class XLMRobertaCLIP(nn.Module):
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
embed_dim=1024,
|
320 |
+
image_size=224,
|
321 |
+
patch_size=14,
|
322 |
+
vision_dim=1280,
|
323 |
+
vision_mlp_ratio=4,
|
324 |
+
vision_heads=16,
|
325 |
+
vision_layers=32,
|
326 |
+
vision_pool="token",
|
327 |
+
vision_pre_norm=True,
|
328 |
+
vision_post_norm=False,
|
329 |
+
activation="gelu",
|
330 |
+
vocab_size=250002,
|
331 |
+
max_text_len=514,
|
332 |
+
type_size=1,
|
333 |
+
pad_id=1,
|
334 |
+
text_dim=1024,
|
335 |
+
text_heads=16,
|
336 |
+
text_layers=24,
|
337 |
+
text_post_norm=True,
|
338 |
+
text_dropout=0.1,
|
339 |
+
attn_dropout=0.0,
|
340 |
+
proj_dropout=0.0,
|
341 |
+
embedding_dropout=0.0,
|
342 |
+
norm_eps=1e-5,
|
343 |
+
):
|
344 |
+
super().__init__()
|
345 |
+
self.embed_dim = embed_dim
|
346 |
+
self.image_size = image_size
|
347 |
+
self.patch_size = patch_size
|
348 |
+
self.vision_dim = vision_dim
|
349 |
+
self.vision_mlp_ratio = vision_mlp_ratio
|
350 |
+
self.vision_heads = vision_heads
|
351 |
+
self.vision_layers = vision_layers
|
352 |
+
self.vision_pre_norm = vision_pre_norm
|
353 |
+
self.vision_post_norm = vision_post_norm
|
354 |
+
self.activation = activation
|
355 |
+
self.vocab_size = vocab_size
|
356 |
+
self.max_text_len = max_text_len
|
357 |
+
self.type_size = type_size
|
358 |
+
self.pad_id = pad_id
|
359 |
+
self.text_dim = text_dim
|
360 |
+
self.text_heads = text_heads
|
361 |
+
self.text_layers = text_layers
|
362 |
+
self.text_post_norm = text_post_norm
|
363 |
+
self.norm_eps = norm_eps
|
364 |
+
|
365 |
+
# models
|
366 |
+
self.visual = VisionTransformer(
|
367 |
+
image_size=image_size,
|
368 |
+
patch_size=patch_size,
|
369 |
+
dim=vision_dim,
|
370 |
+
mlp_ratio=vision_mlp_ratio,
|
371 |
+
out_dim=embed_dim,
|
372 |
+
num_heads=vision_heads,
|
373 |
+
num_layers=vision_layers,
|
374 |
+
pool_type=vision_pool,
|
375 |
+
pre_norm=vision_pre_norm,
|
376 |
+
post_norm=vision_post_norm,
|
377 |
+
activation=activation,
|
378 |
+
attn_dropout=attn_dropout,
|
379 |
+
proj_dropout=proj_dropout,
|
380 |
+
embedding_dropout=embedding_dropout,
|
381 |
+
norm_eps=norm_eps,
|
382 |
+
)
|
383 |
+
self.textual = XLMRobertaWithHead(
|
384 |
+
vocab_size=vocab_size,
|
385 |
+
max_seq_len=max_text_len,
|
386 |
+
type_size=type_size,
|
387 |
+
pad_id=pad_id,
|
388 |
+
dim=text_dim,
|
389 |
+
out_dim=embed_dim,
|
390 |
+
num_heads=text_heads,
|
391 |
+
num_layers=text_layers,
|
392 |
+
post_norm=text_post_norm,
|
393 |
+
dropout=text_dropout,
|
394 |
+
)
|
395 |
+
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
|
396 |
+
|
397 |
+
def forward(self, imgs, txt_ids):
|
398 |
+
"""
|
399 |
+
imgs: [B, 3, H, W] of torch.float32.
|
400 |
+
- mean: [0.48145466, 0.4578275, 0.40821073]
|
401 |
+
- std: [0.26862954, 0.26130258, 0.27577711]
|
402 |
+
txt_ids: [B, L] of torch.long.
|
403 |
+
Encoded by data.CLIPTokenizer.
|
404 |
+
"""
|
405 |
+
xi = self.visual(imgs)
|
406 |
+
xt = self.textual(txt_ids)
|
407 |
+
return xi, xt
|
408 |
+
|
409 |
+
def param_groups(self):
|
410 |
+
groups = [
|
411 |
+
{
|
412 |
+
"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")],
|
413 |
+
"weight_decay": 0.0,
|
414 |
+
},
|
415 |
+
{"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
|
416 |
+
]
|
417 |
+
return groups
|
418 |
+
|
419 |
+
|
420 |
+
def _clip(
|
421 |
+
pretrained=False,
|
422 |
+
pretrained_name=None,
|
423 |
+
model_cls=XLMRobertaCLIP,
|
424 |
+
return_transforms=False,
|
425 |
+
return_tokenizer=False,
|
426 |
+
tokenizer_padding="eos",
|
427 |
+
dtype=torch.float32,
|
428 |
+
device="cpu",
|
429 |
+
**kwargs,
|
430 |
+
):
|
431 |
+
# init a model on device
|
432 |
+
with torch.device(device):
|
433 |
+
model = model_cls(**kwargs)
|
434 |
+
|
435 |
+
# set device
|
436 |
+
model = model.to(dtype=dtype, device=device)
|
437 |
+
output = (model,)
|
438 |
+
|
439 |
+
# init transforms
|
440 |
+
if return_transforms:
|
441 |
+
# mean and std
|
442 |
+
if "siglip" in pretrained_name.lower():
|
443 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
444 |
+
else:
|
445 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
446 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
447 |
+
|
448 |
+
# transforms
|
449 |
+
transforms = T.Compose(
|
450 |
+
[
|
451 |
+
T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC),
|
452 |
+
T.ToTensor(),
|
453 |
+
T.Normalize(mean=mean, std=std),
|
454 |
+
]
|
455 |
+
)
|
456 |
+
output += (transforms,)
|
457 |
+
return output[0] if len(output) == 1 else output
|
458 |
+
|
459 |
+
|
460 |
+
def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
|
461 |
+
cfg = dict(
|
462 |
+
embed_dim=1024,
|
463 |
+
image_size=224,
|
464 |
+
patch_size=14,
|
465 |
+
vision_dim=1280,
|
466 |
+
vision_mlp_ratio=4,
|
467 |
+
vision_heads=16,
|
468 |
+
vision_layers=32,
|
469 |
+
vision_pool="token",
|
470 |
+
activation="gelu",
|
471 |
+
vocab_size=250002,
|
472 |
+
max_text_len=514,
|
473 |
+
type_size=1,
|
474 |
+
pad_id=1,
|
475 |
+
text_dim=1024,
|
476 |
+
text_heads=16,
|
477 |
+
text_layers=24,
|
478 |
+
text_post_norm=True,
|
479 |
+
text_dropout=0.1,
|
480 |
+
attn_dropout=0.0,
|
481 |
+
proj_dropout=0.0,
|
482 |
+
embedding_dropout=0.0,
|
483 |
+
)
|
484 |
+
cfg.update(**kwargs)
|
485 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
486 |
+
|
487 |
+
|
488 |
+
class CLIPModel(ModelMixin):
|
489 |
+
def __init__(self, checkpoint_path, tokenizer_path):
|
490 |
+
self.checkpoint_path = checkpoint_path
|
491 |
+
self.tokenizer_path = tokenizer_path
|
492 |
+
|
493 |
+
super().__init__()
|
494 |
+
# init model
|
495 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
496 |
+
pretrained=False, return_transforms=True, return_tokenizer=False
|
497 |
+
)
|
498 |
+
self.model = self.model.eval().requires_grad_(False)
|
499 |
+
logging.info(f"loading {checkpoint_path}")
|
500 |
+
self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
|
501 |
+
|
502 |
+
# init tokenizer
|
503 |
+
self.tokenizer = HuggingfaceTokenizer(
|
504 |
+
name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace"
|
505 |
+
)
|
506 |
+
|
507 |
+
def encode_video(self, video):
|
508 |
+
# preprocess
|
509 |
+
b, c, t, h, w = video.shape
|
510 |
+
video = video.transpose(1, 2)
|
511 |
+
video = video.reshape(b * t, c, h, w)
|
512 |
+
size = (self.model.image_size,) * 2
|
513 |
+
video = F.interpolate(
|
514 |
+
video,
|
515 |
+
size=size,
|
516 |
+
mode='bicubic',
|
517 |
+
align_corners=False)
|
518 |
+
|
519 |
+
video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
|
520 |
+
|
521 |
+
# forward
|
522 |
+
with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
|
523 |
+
out = self.model.visual(video, use_31_block=True)
|
524 |
+
|
525 |
+
return out
|
skyreels_v2_infer/modules/t5.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from transformers.models.t5.modeling_t5
|
2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from diffusers.models import ModelMixin
|
10 |
+
|
11 |
+
from .tokenizers import HuggingfaceTokenizer
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"T5Model",
|
15 |
+
"T5Encoder",
|
16 |
+
"T5Decoder",
|
17 |
+
"T5EncoderModel",
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def fp16_clamp(x):
|
22 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
23 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
24 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
25 |
+
return x
|
26 |
+
|
27 |
+
|
28 |
+
def init_weights(m):
|
29 |
+
if isinstance(m, T5LayerNorm):
|
30 |
+
nn.init.ones_(m.weight)
|
31 |
+
elif isinstance(m, T5Model):
|
32 |
+
nn.init.normal_(m.token_embedding.weight, std=1.0)
|
33 |
+
elif isinstance(m, T5FeedForward):
|
34 |
+
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
35 |
+
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
36 |
+
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
37 |
+
elif isinstance(m, T5Attention):
|
38 |
+
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5)
|
39 |
+
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
40 |
+
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
41 |
+
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5)
|
42 |
+
elif isinstance(m, T5RelativeEmbedding):
|
43 |
+
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5)
|
44 |
+
|
45 |
+
|
46 |
+
class GELU(nn.Module):
|
47 |
+
def forward(self, x):
|
48 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
49 |
+
|
50 |
+
|
51 |
+
class T5LayerNorm(nn.Module):
|
52 |
+
def __init__(self, dim, eps=1e-6):
|
53 |
+
super(T5LayerNorm, self).__init__()
|
54 |
+
self.dim = dim
|
55 |
+
self.eps = eps
|
56 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
60 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
61 |
+
x = x.type_as(self.weight)
|
62 |
+
return self.weight * x
|
63 |
+
|
64 |
+
|
65 |
+
class T5Attention(nn.Module):
|
66 |
+
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
67 |
+
assert dim_attn % num_heads == 0
|
68 |
+
super(T5Attention, self).__init__()
|
69 |
+
self.dim = dim
|
70 |
+
self.dim_attn = dim_attn
|
71 |
+
self.num_heads = num_heads
|
72 |
+
self.head_dim = dim_attn // num_heads
|
73 |
+
|
74 |
+
# layers
|
75 |
+
self.q = nn.Linear(dim, dim_attn, bias=False)
|
76 |
+
self.k = nn.Linear(dim, dim_attn, bias=False)
|
77 |
+
self.v = nn.Linear(dim, dim_attn, bias=False)
|
78 |
+
self.o = nn.Linear(dim_attn, dim, bias=False)
|
79 |
+
self.dropout = nn.Dropout(dropout)
|
80 |
+
|
81 |
+
def forward(self, x, context=None, mask=None, pos_bias=None):
|
82 |
+
"""
|
83 |
+
x: [B, L1, C].
|
84 |
+
context: [B, L2, C] or None.
|
85 |
+
mask: [B, L2] or [B, L1, L2] or None.
|
86 |
+
"""
|
87 |
+
# check inputs
|
88 |
+
context = x if context is None else context
|
89 |
+
b, n, c = x.size(0), self.num_heads, self.head_dim
|
90 |
+
|
91 |
+
# compute query, key, value
|
92 |
+
q = self.q(x).view(b, -1, n, c)
|
93 |
+
k = self.k(context).view(b, -1, n, c)
|
94 |
+
v = self.v(context).view(b, -1, n, c)
|
95 |
+
|
96 |
+
# attention bias
|
97 |
+
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
98 |
+
if pos_bias is not None:
|
99 |
+
attn_bias += pos_bias
|
100 |
+
if mask is not None:
|
101 |
+
assert mask.ndim in [2, 3]
|
102 |
+
mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1)
|
103 |
+
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
104 |
+
|
105 |
+
# compute attention (T5 does not use scaling)
|
106 |
+
attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
|
107 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
108 |
+
x = torch.einsum("bnij,bjnc->binc", attn, v)
|
109 |
+
|
110 |
+
# output
|
111 |
+
x = x.reshape(b, -1, n * c)
|
112 |
+
x = self.o(x)
|
113 |
+
x = self.dropout(x)
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class T5FeedForward(nn.Module):
|
118 |
+
def __init__(self, dim, dim_ffn, dropout=0.1):
|
119 |
+
super(T5FeedForward, self).__init__()
|
120 |
+
self.dim = dim
|
121 |
+
self.dim_ffn = dim_ffn
|
122 |
+
|
123 |
+
# layers
|
124 |
+
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
125 |
+
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
126 |
+
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
127 |
+
self.dropout = nn.Dropout(dropout)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
x = self.fc1(x) * self.gate(x)
|
131 |
+
x = self.dropout(x)
|
132 |
+
x = self.fc2(x)
|
133 |
+
x = self.dropout(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class T5SelfAttention(nn.Module):
|
138 |
+
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
|
139 |
+
super(T5SelfAttention, self).__init__()
|
140 |
+
self.dim = dim
|
141 |
+
self.dim_attn = dim_attn
|
142 |
+
self.dim_ffn = dim_ffn
|
143 |
+
self.num_heads = num_heads
|
144 |
+
self.num_buckets = num_buckets
|
145 |
+
self.shared_pos = shared_pos
|
146 |
+
|
147 |
+
# layers
|
148 |
+
self.norm1 = T5LayerNorm(dim)
|
149 |
+
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
150 |
+
self.norm2 = T5LayerNorm(dim)
|
151 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
152 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
153 |
+
|
154 |
+
def forward(self, x, mask=None, pos_bias=None):
|
155 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
|
156 |
+
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
157 |
+
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
class T5CrossAttention(nn.Module):
|
162 |
+
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1):
|
163 |
+
super(T5CrossAttention, self).__init__()
|
164 |
+
self.dim = dim
|
165 |
+
self.dim_attn = dim_attn
|
166 |
+
self.dim_ffn = dim_ffn
|
167 |
+
self.num_heads = num_heads
|
168 |
+
self.num_buckets = num_buckets
|
169 |
+
self.shared_pos = shared_pos
|
170 |
+
|
171 |
+
# layers
|
172 |
+
self.norm1 = T5LayerNorm(dim)
|
173 |
+
self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
174 |
+
self.norm2 = T5LayerNorm(dim)
|
175 |
+
self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
176 |
+
self.norm3 = T5LayerNorm(dim)
|
177 |
+
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
178 |
+
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False)
|
179 |
+
|
180 |
+
def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None):
|
181 |
+
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
|
182 |
+
x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
|
183 |
+
x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask))
|
184 |
+
x = fp16_clamp(x + self.ffn(self.norm3(x)))
|
185 |
+
return x
|
186 |
+
|
187 |
+
|
188 |
+
class T5RelativeEmbedding(nn.Module):
|
189 |
+
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
190 |
+
super(T5RelativeEmbedding, self).__init__()
|
191 |
+
self.num_buckets = num_buckets
|
192 |
+
self.num_heads = num_heads
|
193 |
+
self.bidirectional = bidirectional
|
194 |
+
self.max_dist = max_dist
|
195 |
+
|
196 |
+
# layers
|
197 |
+
self.embedding = nn.Embedding(num_buckets, num_heads)
|
198 |
+
|
199 |
+
def forward(self, lq, lk):
|
200 |
+
device = self.embedding.weight.device
|
201 |
+
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
202 |
+
# torch.arange(lq).unsqueeze(1).to(device)
|
203 |
+
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1)
|
204 |
+
rel_pos = self._relative_position_bucket(rel_pos)
|
205 |
+
rel_pos_embeds = self.embedding(rel_pos)
|
206 |
+
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk]
|
207 |
+
return rel_pos_embeds.contiguous()
|
208 |
+
|
209 |
+
def _relative_position_bucket(self, rel_pos):
|
210 |
+
# preprocess
|
211 |
+
if self.bidirectional:
|
212 |
+
num_buckets = self.num_buckets // 2
|
213 |
+
rel_buckets = (rel_pos > 0).long() * num_buckets
|
214 |
+
rel_pos = torch.abs(rel_pos)
|
215 |
+
else:
|
216 |
+
num_buckets = self.num_buckets
|
217 |
+
rel_buckets = 0
|
218 |
+
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
219 |
+
|
220 |
+
# embeddings for small and large positions
|
221 |
+
max_exact = num_buckets // 2
|
222 |
+
rel_pos_large = (
|
223 |
+
max_exact
|
224 |
+
+ (
|
225 |
+
torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)
|
226 |
+
).long()
|
227 |
+
)
|
228 |
+
rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
229 |
+
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
230 |
+
return rel_buckets
|
231 |
+
|
232 |
+
|
233 |
+
class T5Encoder(nn.Module):
|
234 |
+
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
|
235 |
+
super(T5Encoder, self).__init__()
|
236 |
+
self.dim = dim
|
237 |
+
self.dim_attn = dim_attn
|
238 |
+
self.dim_ffn = dim_ffn
|
239 |
+
self.num_heads = num_heads
|
240 |
+
self.num_layers = num_layers
|
241 |
+
self.num_buckets = num_buckets
|
242 |
+
self.shared_pos = shared_pos
|
243 |
+
|
244 |
+
# layers
|
245 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
|
246 |
+
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
247 |
+
self.dropout = nn.Dropout(dropout)
|
248 |
+
self.blocks = nn.ModuleList(
|
249 |
+
[
|
250 |
+
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
|
251 |
+
for _ in range(num_layers)
|
252 |
+
]
|
253 |
+
)
|
254 |
+
self.norm = T5LayerNorm(dim)
|
255 |
+
|
256 |
+
# initialize weights
|
257 |
+
self.apply(init_weights)
|
258 |
+
|
259 |
+
def forward(self, ids, mask=None):
|
260 |
+
x = self.token_embedding(ids)
|
261 |
+
x = self.dropout(x)
|
262 |
+
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
|
263 |
+
for block in self.blocks:
|
264 |
+
x = block(x, mask, pos_bias=e)
|
265 |
+
x = self.norm(x)
|
266 |
+
x = self.dropout(x)
|
267 |
+
return x
|
268 |
+
|
269 |
+
|
270 |
+
class T5Decoder(nn.Module):
|
271 |
+
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1):
|
272 |
+
super(T5Decoder, self).__init__()
|
273 |
+
self.dim = dim
|
274 |
+
self.dim_attn = dim_attn
|
275 |
+
self.dim_ffn = dim_ffn
|
276 |
+
self.num_heads = num_heads
|
277 |
+
self.num_layers = num_layers
|
278 |
+
self.num_buckets = num_buckets
|
279 |
+
self.shared_pos = shared_pos
|
280 |
+
|
281 |
+
# layers
|
282 |
+
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
|
283 |
+
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None
|
284 |
+
self.dropout = nn.Dropout(dropout)
|
285 |
+
self.blocks = nn.ModuleList(
|
286 |
+
[
|
287 |
+
T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
|
288 |
+
for _ in range(num_layers)
|
289 |
+
]
|
290 |
+
)
|
291 |
+
self.norm = T5LayerNorm(dim)
|
292 |
+
|
293 |
+
# initialize weights
|
294 |
+
self.apply(init_weights)
|
295 |
+
|
296 |
+
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
|
297 |
+
b, s = ids.size()
|
298 |
+
|
299 |
+
# causal mask
|
300 |
+
if mask is None:
|
301 |
+
mask = torch.tril(torch.ones(1, s, s).to(ids.device))
|
302 |
+
elif mask.ndim == 2:
|
303 |
+
mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
|
304 |
+
|
305 |
+
# layers
|
306 |
+
x = self.token_embedding(ids)
|
307 |
+
x = self.dropout(x)
|
308 |
+
e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
|
309 |
+
for block in self.blocks:
|
310 |
+
x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
|
311 |
+
x = self.norm(x)
|
312 |
+
x = self.dropout(x)
|
313 |
+
return x
|
314 |
+
|
315 |
+
|
316 |
+
class T5Model(nn.Module):
|
317 |
+
def __init__(
|
318 |
+
self,
|
319 |
+
vocab_size,
|
320 |
+
dim,
|
321 |
+
dim_attn,
|
322 |
+
dim_ffn,
|
323 |
+
num_heads,
|
324 |
+
encoder_layers,
|
325 |
+
decoder_layers,
|
326 |
+
num_buckets,
|
327 |
+
shared_pos=True,
|
328 |
+
dropout=0.1,
|
329 |
+
):
|
330 |
+
super(T5Model, self).__init__()
|
331 |
+
self.vocab_size = vocab_size
|
332 |
+
self.dim = dim
|
333 |
+
self.dim_attn = dim_attn
|
334 |
+
self.dim_ffn = dim_ffn
|
335 |
+
self.num_heads = num_heads
|
336 |
+
self.encoder_layers = encoder_layers
|
337 |
+
self.decoder_layers = decoder_layers
|
338 |
+
self.num_buckets = num_buckets
|
339 |
+
|
340 |
+
# layers
|
341 |
+
self.token_embedding = nn.Embedding(vocab_size, dim)
|
342 |
+
self.encoder = T5Encoder(
|
343 |
+
self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout
|
344 |
+
)
|
345 |
+
self.decoder = T5Decoder(
|
346 |
+
self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout
|
347 |
+
)
|
348 |
+
self.head = nn.Linear(dim, vocab_size, bias=False)
|
349 |
+
|
350 |
+
# initialize weights
|
351 |
+
self.apply(init_weights)
|
352 |
+
|
353 |
+
def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
|
354 |
+
x = self.encoder(encoder_ids, encoder_mask)
|
355 |
+
x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
|
356 |
+
x = self.head(x)
|
357 |
+
return x
|
358 |
+
|
359 |
+
|
360 |
+
def _t5(
|
361 |
+
name,
|
362 |
+
encoder_only=False,
|
363 |
+
decoder_only=False,
|
364 |
+
return_tokenizer=False,
|
365 |
+
tokenizer_kwargs={},
|
366 |
+
dtype=torch.float32,
|
367 |
+
device="cpu",
|
368 |
+
**kwargs,
|
369 |
+
):
|
370 |
+
# sanity check
|
371 |
+
assert not (encoder_only and decoder_only)
|
372 |
+
|
373 |
+
# params
|
374 |
+
if encoder_only:
|
375 |
+
model_cls = T5Encoder
|
376 |
+
kwargs["vocab"] = kwargs.pop("vocab_size")
|
377 |
+
kwargs["num_layers"] = kwargs.pop("encoder_layers")
|
378 |
+
_ = kwargs.pop("decoder_layers")
|
379 |
+
elif decoder_only:
|
380 |
+
model_cls = T5Decoder
|
381 |
+
kwargs["vocab"] = kwargs.pop("vocab_size")
|
382 |
+
kwargs["num_layers"] = kwargs.pop("decoder_layers")
|
383 |
+
_ = kwargs.pop("encoder_layers")
|
384 |
+
else:
|
385 |
+
model_cls = T5Model
|
386 |
+
|
387 |
+
# init model
|
388 |
+
with torch.device(device):
|
389 |
+
model = model_cls(**kwargs)
|
390 |
+
|
391 |
+
# set device
|
392 |
+
model = model.to(dtype=dtype, device=device)
|
393 |
+
|
394 |
+
# init tokenizer
|
395 |
+
if return_tokenizer:
|
396 |
+
from .tokenizers import HuggingfaceTokenizer
|
397 |
+
|
398 |
+
tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs)
|
399 |
+
return model, tokenizer
|
400 |
+
else:
|
401 |
+
return model
|
402 |
+
|
403 |
+
|
404 |
+
def umt5_xxl(**kwargs):
|
405 |
+
cfg = dict(
|
406 |
+
vocab_size=256384,
|
407 |
+
dim=4096,
|
408 |
+
dim_attn=4096,
|
409 |
+
dim_ffn=10240,
|
410 |
+
num_heads=64,
|
411 |
+
encoder_layers=24,
|
412 |
+
decoder_layers=24,
|
413 |
+
num_buckets=32,
|
414 |
+
shared_pos=False,
|
415 |
+
dropout=0.1,
|
416 |
+
)
|
417 |
+
cfg.update(**kwargs)
|
418 |
+
return _t5("umt5-xxl", **cfg)
|
419 |
+
|
420 |
+
|
421 |
+
class T5EncoderModel(ModelMixin):
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
checkpoint_path=None,
|
425 |
+
tokenizer_path=None,
|
426 |
+
text_len=512,
|
427 |
+
shard_fn=None,
|
428 |
+
):
|
429 |
+
self.text_len = text_len
|
430 |
+
self.checkpoint_path = checkpoint_path
|
431 |
+
self.tokenizer_path = tokenizer_path
|
432 |
+
|
433 |
+
super().__init__()
|
434 |
+
# init model
|
435 |
+
model = umt5_xxl(encoder_only=True, return_tokenizer=False)
|
436 |
+
logging.info(f"loading {checkpoint_path}")
|
437 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
|
438 |
+
self.model = model
|
439 |
+
if shard_fn is not None:
|
440 |
+
self.model = shard_fn(self.model, sync_module_states=False)
|
441 |
+
else:
|
442 |
+
self.model.eval().requires_grad_(False)
|
443 |
+
# init tokenizer
|
444 |
+
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace")
|
445 |
+
|
446 |
+
def encode(self, texts):
|
447 |
+
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
|
448 |
+
ids = ids.to(self.device)
|
449 |
+
mask = mask.to(self.device)
|
450 |
+
# seq_lens = mask.gt(0).sum(dim=1).long()
|
451 |
+
context = self.model(ids, mask)
|
452 |
+
context = context * mask.unsqueeze(-1).cuda()
|
453 |
+
|
454 |
+
return context
|
skyreels_v2_infer/modules/tokenizers.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import html
|
3 |
+
import string
|
4 |
+
|
5 |
+
import ftfy
|
6 |
+
import regex as re
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
|
9 |
+
__all__ = ["HuggingfaceTokenizer"]
|
10 |
+
|
11 |
+
|
12 |
+
def basic_clean(text):
|
13 |
+
text = ftfy.fix_text(text)
|
14 |
+
text = html.unescape(html.unescape(text))
|
15 |
+
return text.strip()
|
16 |
+
|
17 |
+
|
18 |
+
def whitespace_clean(text):
|
19 |
+
text = re.sub(r"\s+", " ", text)
|
20 |
+
text = text.strip()
|
21 |
+
return text
|
22 |
+
|
23 |
+
|
24 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
|
25 |
+
text = text.replace("_", " ")
|
26 |
+
if keep_punctuation_exact_string:
|
27 |
+
text = keep_punctuation_exact_string.join(
|
28 |
+
part.translate(str.maketrans("", "", string.punctuation))
|
29 |
+
for part in text.split(keep_punctuation_exact_string)
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
text = text.translate(str.maketrans("", "", string.punctuation))
|
33 |
+
text = text.lower()
|
34 |
+
text = re.sub(r"\s+", " ", text)
|
35 |
+
return text.strip()
|
36 |
+
|
37 |
+
|
38 |
+
class HuggingfaceTokenizer:
|
39 |
+
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
40 |
+
assert clean in (None, "whitespace", "lower", "canonicalize")
|
41 |
+
self.name = name
|
42 |
+
self.seq_len = seq_len
|
43 |
+
self.clean = clean
|
44 |
+
|
45 |
+
# init tokenizer
|
46 |
+
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
47 |
+
self.vocab_size = self.tokenizer.vocab_size
|
48 |
+
|
49 |
+
def __call__(self, sequence, **kwargs):
|
50 |
+
return_mask = kwargs.pop("return_mask", False)
|
51 |
+
|
52 |
+
# arguments
|
53 |
+
_kwargs = {"return_tensors": "pt"}
|
54 |
+
if self.seq_len is not None:
|
55 |
+
_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
|
56 |
+
_kwargs.update(**kwargs)
|
57 |
+
|
58 |
+
# tokenization
|
59 |
+
if isinstance(sequence, str):
|
60 |
+
sequence = [sequence]
|
61 |
+
if self.clean:
|
62 |
+
sequence = [self._clean(u) for u in sequence]
|
63 |
+
ids = self.tokenizer(sequence, **_kwargs)
|
64 |
+
|
65 |
+
# output
|
66 |
+
if return_mask:
|
67 |
+
return ids.input_ids, ids.attention_mask
|
68 |
+
else:
|
69 |
+
return ids.input_ids
|
70 |
+
|
71 |
+
def _clean(self, text):
|
72 |
+
if self.clean == "whitespace":
|
73 |
+
text = whitespace_clean(basic_clean(text))
|
74 |
+
elif self.clean == "lower":
|
75 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
76 |
+
elif self.clean == "canonicalize":
|
77 |
+
text = canonicalize(basic_clean(text))
|
78 |
+
return text
|
skyreels_v2_infer/modules/transformer.py
ADDED
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.amp as amp
|
6 |
+
import torch.nn as nn
|
7 |
+
from diffusers.configuration_utils import ConfigMixin
|
8 |
+
from diffusers.configuration_utils import register_to_config
|
9 |
+
from diffusers.loaders import PeftAdapterMixin
|
10 |
+
from diffusers.models.modeling_utils import ModelMixin
|
11 |
+
from torch.backends.cuda import sdp_kernel
|
12 |
+
from torch.nn.attention.flex_attention import BlockMask
|
13 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
14 |
+
from torch.nn.attention.flex_attention import flex_attention
|
15 |
+
|
16 |
+
from .attention import flash_attention
|
17 |
+
|
18 |
+
|
19 |
+
flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune")
|
20 |
+
|
21 |
+
DISABLE_COMPILE = False # get os env
|
22 |
+
|
23 |
+
__all__ = ["WanModel"]
|
24 |
+
|
25 |
+
|
26 |
+
def sinusoidal_embedding_1d(dim, position):
|
27 |
+
# preprocess
|
28 |
+
assert dim % 2 == 0
|
29 |
+
half = dim // 2
|
30 |
+
position = position.type(torch.float64)
|
31 |
+
|
32 |
+
# calculation
|
33 |
+
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
34 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
35 |
+
return x
|
36 |
+
|
37 |
+
|
38 |
+
@amp.autocast("cuda", enabled=False)
|
39 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
40 |
+
assert dim % 2 == 0
|
41 |
+
freqs = torch.outer(
|
42 |
+
torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
|
43 |
+
)
|
44 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
45 |
+
return freqs
|
46 |
+
|
47 |
+
|
48 |
+
@amp.autocast("cuda", enabled=False)
|
49 |
+
def rope_apply(x, grid_sizes, freqs):
|
50 |
+
n, c = x.size(2), x.size(3) // 2
|
51 |
+
bs = x.size(0)
|
52 |
+
|
53 |
+
# split freqs
|
54 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
55 |
+
|
56 |
+
# loop over samples
|
57 |
+
f, h, w = grid_sizes.tolist()
|
58 |
+
seq_len = f * h * w
|
59 |
+
|
60 |
+
# precompute multipliers
|
61 |
+
|
62 |
+
x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2))
|
63 |
+
freqs_i = torch.cat(
|
64 |
+
[
|
65 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
66 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
67 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
|
68 |
+
],
|
69 |
+
dim=-1,
|
70 |
+
).reshape(seq_len, 1, -1)
|
71 |
+
|
72 |
+
# apply rotary embedding
|
73 |
+
x = torch.view_as_real(x * freqs_i).flatten(3)
|
74 |
+
|
75 |
+
return x
|
76 |
+
|
77 |
+
|
78 |
+
@torch.compile(dynamic=True, disable=DISABLE_COMPILE)
|
79 |
+
def fast_rms_norm(x, weight, eps):
|
80 |
+
x = x.float()
|
81 |
+
x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
|
82 |
+
x = x.type_as(x) * weight
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class WanRMSNorm(nn.Module):
|
87 |
+
def __init__(self, dim, eps=1e-5):
|
88 |
+
super().__init__()
|
89 |
+
self.dim = dim
|
90 |
+
self.eps = eps
|
91 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
r"""
|
95 |
+
Args:
|
96 |
+
x(Tensor): Shape [B, L, C]
|
97 |
+
"""
|
98 |
+
return fast_rms_norm(x, self.weight, self.eps)
|
99 |
+
|
100 |
+
def _norm(self, x):
|
101 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
102 |
+
|
103 |
+
|
104 |
+
class WanLayerNorm(nn.LayerNorm):
|
105 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
106 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
r"""
|
110 |
+
Args:
|
111 |
+
x(Tensor): Shape [B, L, C]
|
112 |
+
"""
|
113 |
+
return super().forward(x)
|
114 |
+
|
115 |
+
|
116 |
+
class WanSelfAttention(nn.Module):
|
117 |
+
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
|
118 |
+
assert dim % num_heads == 0
|
119 |
+
super().__init__()
|
120 |
+
self.dim = dim
|
121 |
+
self.num_heads = num_heads
|
122 |
+
self.head_dim = dim // num_heads
|
123 |
+
self.window_size = window_size
|
124 |
+
self.qk_norm = qk_norm
|
125 |
+
self.eps = eps
|
126 |
+
|
127 |
+
# layers
|
128 |
+
self.q = nn.Linear(dim, dim)
|
129 |
+
self.k = nn.Linear(dim, dim)
|
130 |
+
self.v = nn.Linear(dim, dim)
|
131 |
+
self.o = nn.Linear(dim, dim)
|
132 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
133 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
134 |
+
|
135 |
+
self._flag_ar_attention = False
|
136 |
+
|
137 |
+
def set_ar_attention(self):
|
138 |
+
self._flag_ar_attention = True
|
139 |
+
|
140 |
+
def forward(self, x, grid_sizes, freqs, block_mask):
|
141 |
+
r"""
|
142 |
+
Args:
|
143 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
144 |
+
seq_lens(Tensor): Shape [B]
|
145 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
146 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
147 |
+
"""
|
148 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
149 |
+
|
150 |
+
# query, key, value function
|
151 |
+
def qkv_fn(x):
|
152 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
153 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
154 |
+
v = self.v(x).view(b, s, n, d)
|
155 |
+
return q, k, v
|
156 |
+
|
157 |
+
x = x.to(self.q.weight.dtype)
|
158 |
+
q, k, v = qkv_fn(x)
|
159 |
+
|
160 |
+
if not self._flag_ar_attention:
|
161 |
+
q = rope_apply(q, grid_sizes, freqs)
|
162 |
+
k = rope_apply(k, grid_sizes, freqs)
|
163 |
+
x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
|
164 |
+
else:
|
165 |
+
q = rope_apply(q, grid_sizes, freqs)
|
166 |
+
k = rope_apply(k, grid_sizes, freqs)
|
167 |
+
q = q.to(torch.bfloat16)
|
168 |
+
k = k.to(torch.bfloat16)
|
169 |
+
v = v.to(torch.bfloat16)
|
170 |
+
|
171 |
+
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
172 |
+
x = (
|
173 |
+
torch.nn.functional.scaled_dot_product_attention(
|
174 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
175 |
+
)
|
176 |
+
.transpose(1, 2)
|
177 |
+
.contiguous()
|
178 |
+
)
|
179 |
+
|
180 |
+
# output
|
181 |
+
x = x.flatten(2)
|
182 |
+
x = self.o(x)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class WanT2VCrossAttention(WanSelfAttention):
|
187 |
+
def forward(self, x, context):
|
188 |
+
r"""
|
189 |
+
Args:
|
190 |
+
x(Tensor): Shape [B, L1, C]
|
191 |
+
context(Tensor): Shape [B, L2, C]
|
192 |
+
context_lens(Tensor): Shape [B]
|
193 |
+
"""
|
194 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
195 |
+
|
196 |
+
# compute query, key, value
|
197 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
198 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
199 |
+
v = self.v(context).view(b, -1, n, d)
|
200 |
+
|
201 |
+
# compute attention
|
202 |
+
x = flash_attention(q, k, v)
|
203 |
+
|
204 |
+
# output
|
205 |
+
x = x.flatten(2)
|
206 |
+
x = self.o(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class WanI2VCrossAttention(WanSelfAttention):
|
211 |
+
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
|
212 |
+
super().__init__(dim, num_heads, window_size, qk_norm, eps)
|
213 |
+
|
214 |
+
self.k_img = nn.Linear(dim, dim)
|
215 |
+
self.v_img = nn.Linear(dim, dim)
|
216 |
+
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
217 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
218 |
+
|
219 |
+
def forward(self, x, context):
|
220 |
+
r"""
|
221 |
+
Args:
|
222 |
+
x(Tensor): Shape [B, L1, C]
|
223 |
+
context(Tensor): Shape [B, L2, C]
|
224 |
+
context_lens(Tensor): Shape [B]
|
225 |
+
"""
|
226 |
+
context_img = context[:, :257]
|
227 |
+
context = context[:, 257:]
|
228 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
229 |
+
|
230 |
+
# compute query, key, value
|
231 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
232 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
233 |
+
v = self.v(context).view(b, -1, n, d)
|
234 |
+
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
235 |
+
v_img = self.v_img(context_img).view(b, -1, n, d)
|
236 |
+
img_x = flash_attention(q, k_img, v_img)
|
237 |
+
# compute attention
|
238 |
+
x = flash_attention(q, k, v)
|
239 |
+
|
240 |
+
# output
|
241 |
+
x = x.flatten(2)
|
242 |
+
img_x = img_x.flatten(2)
|
243 |
+
x = x + img_x
|
244 |
+
x = self.o(x)
|
245 |
+
return x
|
246 |
+
|
247 |
+
|
248 |
+
WAN_CROSSATTENTION_CLASSES = {
|
249 |
+
"t2v_cross_attn": WanT2VCrossAttention,
|
250 |
+
"i2v_cross_attn": WanI2VCrossAttention,
|
251 |
+
}
|
252 |
+
|
253 |
+
|
254 |
+
def mul_add(x, y, z):
|
255 |
+
return x.float() + y.float() * z.float()
|
256 |
+
|
257 |
+
|
258 |
+
def mul_add_add(x, y, z):
|
259 |
+
return x.float() * (1 + y) + z
|
260 |
+
|
261 |
+
|
262 |
+
mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE)
|
263 |
+
mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE)
|
264 |
+
|
265 |
+
|
266 |
+
class WanAttentionBlock(nn.Module):
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
cross_attn_type,
|
270 |
+
dim,
|
271 |
+
ffn_dim,
|
272 |
+
num_heads,
|
273 |
+
window_size=(-1, -1),
|
274 |
+
qk_norm=True,
|
275 |
+
cross_attn_norm=False,
|
276 |
+
eps=1e-6,
|
277 |
+
):
|
278 |
+
super().__init__()
|
279 |
+
self.dim = dim
|
280 |
+
self.ffn_dim = ffn_dim
|
281 |
+
self.num_heads = num_heads
|
282 |
+
self.window_size = window_size
|
283 |
+
self.qk_norm = qk_norm
|
284 |
+
self.cross_attn_norm = cross_attn_norm
|
285 |
+
self.eps = eps
|
286 |
+
|
287 |
+
# layers
|
288 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
289 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
290 |
+
self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
291 |
+
self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps)
|
292 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
293 |
+
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
|
294 |
+
|
295 |
+
# modulation
|
296 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
297 |
+
|
298 |
+
def set_ar_attention(self):
|
299 |
+
self.self_attn.set_ar_attention()
|
300 |
+
|
301 |
+
def forward(
|
302 |
+
self,
|
303 |
+
x,
|
304 |
+
e,
|
305 |
+
grid_sizes,
|
306 |
+
freqs,
|
307 |
+
context,
|
308 |
+
block_mask,
|
309 |
+
):
|
310 |
+
r"""
|
311 |
+
Args:
|
312 |
+
x(Tensor): Shape [B, L, C]
|
313 |
+
e(Tensor): Shape [B, 6, C]
|
314 |
+
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
315 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
316 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
317 |
+
"""
|
318 |
+
if e.dim() == 3:
|
319 |
+
modulation = self.modulation # 1, 6, dim
|
320 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
321 |
+
e = (modulation + e).chunk(6, dim=1)
|
322 |
+
elif e.dim() == 4:
|
323 |
+
modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
|
324 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
325 |
+
e = (modulation + e).chunk(6, dim=1)
|
326 |
+
e = [ei.squeeze(1) for ei in e]
|
327 |
+
|
328 |
+
# self-attention
|
329 |
+
out = mul_add_add_compile(self.norm1(x), e[1], e[0])
|
330 |
+
y = self.self_attn(out, grid_sizes, freqs, block_mask)
|
331 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
332 |
+
x = mul_add_compile(x, y, e[2])
|
333 |
+
|
334 |
+
# cross-attention & ffn function
|
335 |
+
def cross_attn_ffn(x, context, e):
|
336 |
+
dtype = context.dtype
|
337 |
+
x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
|
338 |
+
y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype))
|
339 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
340 |
+
x = mul_add_compile(x, y, e[5])
|
341 |
+
return x
|
342 |
+
|
343 |
+
x = cross_attn_ffn(x, context, e)
|
344 |
+
return x.to(torch.bfloat16)
|
345 |
+
|
346 |
+
|
347 |
+
class Head(nn.Module):
|
348 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
349 |
+
super().__init__()
|
350 |
+
self.dim = dim
|
351 |
+
self.out_dim = out_dim
|
352 |
+
self.patch_size = patch_size
|
353 |
+
self.eps = eps
|
354 |
+
|
355 |
+
# layers
|
356 |
+
out_dim = math.prod(patch_size) * out_dim
|
357 |
+
self.norm = WanLayerNorm(dim, eps)
|
358 |
+
self.head = nn.Linear(dim, out_dim)
|
359 |
+
|
360 |
+
# modulation
|
361 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
362 |
+
|
363 |
+
def forward(self, x, e):
|
364 |
+
r"""
|
365 |
+
Args:
|
366 |
+
x(Tensor): Shape [B, L1, C]
|
367 |
+
e(Tensor): Shape [B, C]
|
368 |
+
"""
|
369 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
370 |
+
if e.dim() == 2:
|
371 |
+
modulation = self.modulation # 1, 2, dim
|
372 |
+
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
373 |
+
|
374 |
+
elif e.dim() == 3:
|
375 |
+
modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim
|
376 |
+
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
377 |
+
e = [ei.squeeze(1) for ei in e]
|
378 |
+
x = self.head(self.norm(x) * (1 + e[1]) + e[0])
|
379 |
+
return x
|
380 |
+
|
381 |
+
|
382 |
+
class MLPProj(torch.nn.Module):
|
383 |
+
def __init__(self, in_dim, out_dim):
|
384 |
+
super().__init__()
|
385 |
+
|
386 |
+
self.proj = torch.nn.Sequential(
|
387 |
+
torch.nn.LayerNorm(in_dim),
|
388 |
+
torch.nn.Linear(in_dim, in_dim),
|
389 |
+
torch.nn.GELU(),
|
390 |
+
torch.nn.Linear(in_dim, out_dim),
|
391 |
+
torch.nn.LayerNorm(out_dim),
|
392 |
+
)
|
393 |
+
|
394 |
+
def forward(self, image_embeds):
|
395 |
+
clip_extra_context_tokens = self.proj(image_embeds)
|
396 |
+
return clip_extra_context_tokens
|
397 |
+
|
398 |
+
|
399 |
+
class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
400 |
+
r"""
|
401 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
402 |
+
"""
|
403 |
+
|
404 |
+
ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"]
|
405 |
+
_no_split_modules = ["WanAttentionBlock"]
|
406 |
+
|
407 |
+
_supports_gradient_checkpointing = True
|
408 |
+
|
409 |
+
@register_to_config
|
410 |
+
def __init__(
|
411 |
+
self,
|
412 |
+
model_type="t2v",
|
413 |
+
patch_size=(1, 2, 2),
|
414 |
+
text_len=512,
|
415 |
+
in_dim=16,
|
416 |
+
dim=2048,
|
417 |
+
ffn_dim=8192,
|
418 |
+
freq_dim=256,
|
419 |
+
text_dim=4096,
|
420 |
+
out_dim=16,
|
421 |
+
num_heads=16,
|
422 |
+
num_layers=32,
|
423 |
+
window_size=(-1, -1),
|
424 |
+
qk_norm=True,
|
425 |
+
cross_attn_norm=True,
|
426 |
+
inject_sample_info=False,
|
427 |
+
eps=1e-6,
|
428 |
+
):
|
429 |
+
r"""
|
430 |
+
Initialize the diffusion model backbone.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
434 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
435 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
436 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
437 |
+
text_len (`int`, *optional*, defaults to 512):
|
438 |
+
Fixed length for text embeddings
|
439 |
+
in_dim (`int`, *optional*, defaults to 16):
|
440 |
+
Input video channels (C_in)
|
441 |
+
dim (`int`, *optional*, defaults to 2048):
|
442 |
+
Hidden dimension of the transformer
|
443 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
444 |
+
Intermediate dimension in feed-forward network
|
445 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
446 |
+
Dimension for sinusoidal time embeddings
|
447 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
448 |
+
Input dimension for text embeddings
|
449 |
+
out_dim (`int`, *optional*, defaults to 16):
|
450 |
+
Output video channels (C_out)
|
451 |
+
num_heads (`int`, *optional*, defaults to 16):
|
452 |
+
Number of attention heads
|
453 |
+
num_layers (`int`, *optional*, defaults to 32):
|
454 |
+
Number of transformer blocks
|
455 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
456 |
+
Window size for local attention (-1 indicates global attention)
|
457 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
458 |
+
Enable query/key normalization
|
459 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
460 |
+
Enable cross-attention normalization
|
461 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
462 |
+
Epsilon value for normalization layers
|
463 |
+
"""
|
464 |
+
|
465 |
+
super().__init__()
|
466 |
+
|
467 |
+
assert model_type in ["t2v", "i2v"]
|
468 |
+
self.model_type = model_type
|
469 |
+
|
470 |
+
self.patch_size = patch_size
|
471 |
+
self.text_len = text_len
|
472 |
+
self.in_dim = in_dim
|
473 |
+
self.dim = dim
|
474 |
+
self.ffn_dim = ffn_dim
|
475 |
+
self.freq_dim = freq_dim
|
476 |
+
self.text_dim = text_dim
|
477 |
+
self.out_dim = out_dim
|
478 |
+
self.num_heads = num_heads
|
479 |
+
self.num_layers = num_layers
|
480 |
+
self.window_size = window_size
|
481 |
+
self.qk_norm = qk_norm
|
482 |
+
self.cross_attn_norm = cross_attn_norm
|
483 |
+
self.eps = eps
|
484 |
+
self.num_frame_per_block = 1
|
485 |
+
self.flag_causal_attention = False
|
486 |
+
self.block_mask = None
|
487 |
+
self.enable_teacache = False
|
488 |
+
|
489 |
+
# embeddings
|
490 |
+
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
491 |
+
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
|
492 |
+
|
493 |
+
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
494 |
+
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
495 |
+
|
496 |
+
if inject_sample_info:
|
497 |
+
self.fps_embedding = nn.Embedding(2, dim)
|
498 |
+
self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
|
499 |
+
|
500 |
+
# blocks
|
501 |
+
cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn"
|
502 |
+
self.blocks = nn.ModuleList(
|
503 |
+
[
|
504 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
505 |
+
for _ in range(num_layers)
|
506 |
+
]
|
507 |
+
)
|
508 |
+
|
509 |
+
# head
|
510 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
511 |
+
|
512 |
+
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
513 |
+
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
514 |
+
d = dim // num_heads
|
515 |
+
self.freqs = torch.cat(
|
516 |
+
[rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))],
|
517 |
+
dim=1,
|
518 |
+
)
|
519 |
+
|
520 |
+
if model_type == "i2v":
|
521 |
+
self.img_emb = MLPProj(1280, dim)
|
522 |
+
|
523 |
+
self.gradient_checkpointing = False
|
524 |
+
|
525 |
+
self.cpu_offloading = False
|
526 |
+
|
527 |
+
self.inject_sample_info = inject_sample_info
|
528 |
+
# initialize weights
|
529 |
+
self.init_weights()
|
530 |
+
|
531 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
532 |
+
self.gradient_checkpointing = value
|
533 |
+
|
534 |
+
def zero_init_i2v_cross_attn(self):
|
535 |
+
print("zero init i2v cross attn")
|
536 |
+
for i in range(self.num_layers):
|
537 |
+
self.blocks[i].cross_attn.v_img.weight.data.zero_()
|
538 |
+
self.blocks[i].cross_attn.v_img.bias.data.zero_()
|
539 |
+
|
540 |
+
@staticmethod
|
541 |
+
def _prepare_blockwise_causal_attn_mask(
|
542 |
+
device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1
|
543 |
+
) -> BlockMask:
|
544 |
+
"""
|
545 |
+
we will divide the token sequence into the following format
|
546 |
+
[1 latent frame] [1 latent frame] ... [1 latent frame]
|
547 |
+
We use flexattention to construct the attention mask
|
548 |
+
"""
|
549 |
+
total_length = num_frames * frame_seqlen
|
550 |
+
|
551 |
+
# we do right padding to get to a multiple of 128
|
552 |
+
padded_length = math.ceil(total_length / 128) * 128 - total_length
|
553 |
+
|
554 |
+
ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long)
|
555 |
+
|
556 |
+
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
|
557 |
+
frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device)
|
558 |
+
|
559 |
+
for tmp in frame_indices:
|
560 |
+
ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block
|
561 |
+
|
562 |
+
def attention_mask(b, h, q_idx, kv_idx):
|
563 |
+
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
|
564 |
+
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
|
565 |
+
|
566 |
+
block_mask = create_block_mask(
|
567 |
+
attention_mask,
|
568 |
+
B=None,
|
569 |
+
H=None,
|
570 |
+
Q_LEN=total_length + padded_length,
|
571 |
+
KV_LEN=total_length + padded_length,
|
572 |
+
_compile=False,
|
573 |
+
device=device,
|
574 |
+
)
|
575 |
+
|
576 |
+
return block_mask
|
577 |
+
|
578 |
+
def initialize_teacache(self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir=''):
|
579 |
+
self.enable_teacache = enable_teacache
|
580 |
+
print('using teacache')
|
581 |
+
self.cnt = 0
|
582 |
+
self.num_steps = num_steps
|
583 |
+
self.teacache_thresh = teacache_thresh
|
584 |
+
self.accumulated_rel_l1_distance_even = 0
|
585 |
+
self.accumulated_rel_l1_distance_odd = 0
|
586 |
+
self.previous_e0_even = None
|
587 |
+
self.previous_e0_odd = None
|
588 |
+
self.previous_residual_even = None
|
589 |
+
self.previous_residual_odd = None
|
590 |
+
self.use_ref_steps = use_ret_steps
|
591 |
+
if "I2V" in ckpt_dir:
|
592 |
+
if use_ret_steps:
|
593 |
+
if '540P' in ckpt_dir:
|
594 |
+
self.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
|
595 |
+
if '720P' in ckpt_dir:
|
596 |
+
self.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
|
597 |
+
self.ret_steps = 5*2
|
598 |
+
self.cutoff_steps = num_steps*2
|
599 |
+
else:
|
600 |
+
if '540P' in ckpt_dir:
|
601 |
+
self.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
602 |
+
if '720P' in ckpt_dir:
|
603 |
+
self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
604 |
+
self.ret_steps = 1*2
|
605 |
+
self.cutoff_steps = num_steps*2 - 2
|
606 |
+
else:
|
607 |
+
if use_ret_steps:
|
608 |
+
if '1.3B' in ckpt_dir:
|
609 |
+
self.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
|
610 |
+
if '14B' in ckpt_dir:
|
611 |
+
self.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
|
612 |
+
self.ret_steps = 5*2
|
613 |
+
self.cutoff_steps = num_steps*2
|
614 |
+
else:
|
615 |
+
if '1.3B' in ckpt_dir:
|
616 |
+
self.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
|
617 |
+
if '14B' in ckpt_dir:
|
618 |
+
self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
|
619 |
+
self.ret_steps = 1*2
|
620 |
+
self.cutoff_steps = num_steps*2 - 2
|
621 |
+
|
622 |
+
def forward(self, x, t, context, clip_fea=None, y=None, fps=None):
|
623 |
+
r"""
|
624 |
+
Forward pass through the diffusion model
|
625 |
+
|
626 |
+
Args:
|
627 |
+
x (List[Tensor]):
|
628 |
+
List of input video tensors, each with shape [C_in, F, H, W]
|
629 |
+
t (Tensor):
|
630 |
+
Diffusion timesteps tensor of shape [B]
|
631 |
+
context (List[Tensor]):
|
632 |
+
List of text embeddings each with shape [L, C]
|
633 |
+
seq_len (`int`):
|
634 |
+
Maximum sequence length for positional encoding
|
635 |
+
clip_fea (Tensor, *optional*):
|
636 |
+
CLIP image features for image-to-video mode
|
637 |
+
y (List[Tensor], *optional*):
|
638 |
+
Conditional video inputs for image-to-video mode, same shape as x
|
639 |
+
|
640 |
+
Returns:
|
641 |
+
List[Tensor]:
|
642 |
+
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
643 |
+
"""
|
644 |
+
if self.model_type == "i2v":
|
645 |
+
assert clip_fea is not None and y is not None
|
646 |
+
# params
|
647 |
+
device = self.patch_embedding.weight.device
|
648 |
+
if self.freqs.device != device:
|
649 |
+
self.freqs = self.freqs.to(device)
|
650 |
+
|
651 |
+
if y is not None:
|
652 |
+
x = torch.cat([x, y], dim=1)
|
653 |
+
|
654 |
+
# embeddings
|
655 |
+
x = self.patch_embedding(x)
|
656 |
+
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
|
657 |
+
x = x.flatten(2).transpose(1, 2)
|
658 |
+
|
659 |
+
if self.flag_causal_attention:
|
660 |
+
frame_num = grid_sizes[0]
|
661 |
+
height = grid_sizes[1]
|
662 |
+
width = grid_sizes[2]
|
663 |
+
block_num = frame_num // self.num_frame_per_block
|
664 |
+
range_tensor = torch.arange(block_num).view(-1, 1)
|
665 |
+
range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
|
666 |
+
casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
|
667 |
+
casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
|
668 |
+
casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
|
669 |
+
casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
|
670 |
+
self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
|
671 |
+
|
672 |
+
# time embeddings
|
673 |
+
with amp.autocast("cuda", dtype=torch.float32):
|
674 |
+
if t.dim() == 2:
|
675 |
+
b, f = t.shape
|
676 |
+
_flag_df = True
|
677 |
+
else:
|
678 |
+
_flag_df = False
|
679 |
+
|
680 |
+
e = self.time_embedding(
|
681 |
+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
|
682 |
+
) # b, dim
|
683 |
+
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
|
684 |
+
|
685 |
+
if self.inject_sample_info:
|
686 |
+
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
687 |
+
|
688 |
+
fps_emb = self.fps_embedding(fps).float()
|
689 |
+
if _flag_df:
|
690 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
|
691 |
+
else:
|
692 |
+
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
|
693 |
+
|
694 |
+
if _flag_df:
|
695 |
+
e = e.view(b, f, 1, 1, self.dim)
|
696 |
+
e0 = e0.view(b, f, 1, 1, 6, self.dim)
|
697 |
+
e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
|
698 |
+
e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
|
699 |
+
e0 = e0.transpose(1, 2).contiguous()
|
700 |
+
|
701 |
+
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
702 |
+
|
703 |
+
# context
|
704 |
+
context = self.text_embedding(context)
|
705 |
+
|
706 |
+
if clip_fea is not None:
|
707 |
+
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
708 |
+
context = torch.concat([context_clip, context], dim=1)
|
709 |
+
|
710 |
+
# arguments
|
711 |
+
kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
|
712 |
+
if self.enable_teacache:
|
713 |
+
modulated_inp = e0 if self.use_ref_steps else e
|
714 |
+
# teacache
|
715 |
+
if self.cnt%2==0: # even -> conditon
|
716 |
+
self.is_even = True
|
717 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
718 |
+
should_calc_even = True
|
719 |
+
self.accumulated_rel_l1_distance_even = 0
|
720 |
+
else:
|
721 |
+
rescale_func = np.poly1d(self.coefficients)
|
722 |
+
self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
|
723 |
+
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
|
724 |
+
should_calc_even = False
|
725 |
+
else:
|
726 |
+
should_calc_even = True
|
727 |
+
self.accumulated_rel_l1_distance_even = 0
|
728 |
+
self.previous_e0_even = modulated_inp.clone()
|
729 |
+
|
730 |
+
else: # odd -> unconditon
|
731 |
+
self.is_even = False
|
732 |
+
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
|
733 |
+
should_calc_odd = True
|
734 |
+
self.accumulated_rel_l1_distance_odd = 0
|
735 |
+
else:
|
736 |
+
rescale_func = np.poly1d(self.coefficients)
|
737 |
+
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item())
|
738 |
+
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
|
739 |
+
should_calc_odd = False
|
740 |
+
else:
|
741 |
+
should_calc_odd = True
|
742 |
+
self.accumulated_rel_l1_distance_odd = 0
|
743 |
+
self.previous_e0_odd = modulated_inp.clone()
|
744 |
+
|
745 |
+
if self.enable_teacache:
|
746 |
+
if self.is_even:
|
747 |
+
if not should_calc_even:
|
748 |
+
x += self.previous_residual_even
|
749 |
+
else:
|
750 |
+
ori_x = x.clone()
|
751 |
+
for block in self.blocks:
|
752 |
+
x = block(x, **kwargs)
|
753 |
+
self.previous_residual_even = x - ori_x
|
754 |
+
else:
|
755 |
+
if not should_calc_odd:
|
756 |
+
x += self.previous_residual_odd
|
757 |
+
else:
|
758 |
+
ori_x = x.clone()
|
759 |
+
for block in self.blocks:
|
760 |
+
x = block(x, **kwargs)
|
761 |
+
self.previous_residual_odd = x - ori_x
|
762 |
+
|
763 |
+
self.cnt += 1
|
764 |
+
if self.cnt >= self.num_steps:
|
765 |
+
self.cnt = 0
|
766 |
+
else:
|
767 |
+
for block in self.blocks:
|
768 |
+
x = block(x, **kwargs)
|
769 |
+
|
770 |
+
x = self.head(x, e)
|
771 |
+
|
772 |
+
# unpatchify
|
773 |
+
x = self.unpatchify(x, grid_sizes)
|
774 |
+
|
775 |
+
return x.float()
|
776 |
+
|
777 |
+
def unpatchify(self, x, grid_sizes):
|
778 |
+
r"""
|
779 |
+
Reconstruct video tensors from patch embeddings.
|
780 |
+
|
781 |
+
Args:
|
782 |
+
x (List[Tensor]):
|
783 |
+
List of patchified features, each with shape [L, C_out * prod(patch_size)]
|
784 |
+
grid_sizes (Tensor):
|
785 |
+
Original spatial-temporal grid dimensions before patching,
|
786 |
+
shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
|
787 |
+
|
788 |
+
Returns:
|
789 |
+
List[Tensor]:
|
790 |
+
Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
|
791 |
+
"""
|
792 |
+
|
793 |
+
c = self.out_dim
|
794 |
+
bs = x.shape[0]
|
795 |
+
x = x.view(bs, *grid_sizes, *self.patch_size, c)
|
796 |
+
x = torch.einsum("bfhwpqrc->bcfphqwr", x)
|
797 |
+
x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
|
798 |
+
|
799 |
+
return x
|
800 |
+
|
801 |
+
def set_ar_attention(self, causal_block_size):
|
802 |
+
self.num_frame_per_block = causal_block_size
|
803 |
+
self.flag_causal_attention = True
|
804 |
+
for block in self.blocks:
|
805 |
+
block.set_ar_attention()
|
806 |
+
|
807 |
+
def init_weights(self):
|
808 |
+
r"""
|
809 |
+
Initialize model parameters using Xavier initialization.
|
810 |
+
"""
|
811 |
+
|
812 |
+
# basic init
|
813 |
+
for m in self.modules():
|
814 |
+
if isinstance(m, nn.Linear):
|
815 |
+
nn.init.xavier_uniform_(m.weight)
|
816 |
+
if m.bias is not None:
|
817 |
+
nn.init.zeros_(m.bias)
|
818 |
+
|
819 |
+
# init embeddings
|
820 |
+
nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
|
821 |
+
for m in self.text_embedding.modules():
|
822 |
+
if isinstance(m, nn.Linear):
|
823 |
+
nn.init.normal_(m.weight, std=0.02)
|
824 |
+
for m in self.time_embedding.modules():
|
825 |
+
if isinstance(m, nn.Linear):
|
826 |
+
nn.init.normal_(m.weight, std=0.02)
|
827 |
+
|
828 |
+
if self.inject_sample_info:
|
829 |
+
nn.init.normal_(self.fps_embedding.weight, std=0.02)
|
830 |
+
|
831 |
+
for m in self.fps_projection.modules():
|
832 |
+
if isinstance(m, nn.Linear):
|
833 |
+
nn.init.normal_(m.weight, std=0.02)
|
834 |
+
|
835 |
+
nn.init.zeros_(self.fps_projection[-1].weight)
|
836 |
+
nn.init.zeros_(self.fps_projection[-1].bias)
|
837 |
+
|
838 |
+
# init output layer
|
839 |
+
nn.init.zeros_(self.head.head.weight)
|
skyreels_v2_infer/modules/vae.py
ADDED
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
"WanVAE",
|
12 |
+
]
|
13 |
+
|
14 |
+
CACHE_T = 2
|
15 |
+
|
16 |
+
|
17 |
+
class CausalConv3d(nn.Conv3d):
|
18 |
+
"""
|
19 |
+
Causal 3d convolusion.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, *args, **kwargs):
|
23 |
+
super().__init__(*args, **kwargs)
|
24 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
25 |
+
self.padding = (0, 0, 0)
|
26 |
+
|
27 |
+
def forward(self, x, cache_x=None):
|
28 |
+
padding = list(self._padding)
|
29 |
+
if cache_x is not None and self._padding[4] > 0:
|
30 |
+
cache_x = cache_x.to(x.device)
|
31 |
+
x = torch.cat([cache_x, x], dim=2)
|
32 |
+
padding[4] -= cache_x.shape[2]
|
33 |
+
x = F.pad(x, padding)
|
34 |
+
|
35 |
+
return super().forward(x)
|
36 |
+
|
37 |
+
|
38 |
+
class RMS_norm(nn.Module):
|
39 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
40 |
+
super().__init__()
|
41 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
42 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
43 |
+
|
44 |
+
self.channel_first = channel_first
|
45 |
+
self.scale = dim**0.5
|
46 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
47 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
51 |
+
|
52 |
+
|
53 |
+
class Upsample(nn.Upsample):
|
54 |
+
def forward(self, x):
|
55 |
+
"""
|
56 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
57 |
+
"""
|
58 |
+
return super().forward(x.float()).type_as(x)
|
59 |
+
|
60 |
+
|
61 |
+
class Resample(nn.Module):
|
62 |
+
def __init__(self, dim, mode):
|
63 |
+
assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
|
64 |
+
super().__init__()
|
65 |
+
self.dim = dim
|
66 |
+
self.mode = mode
|
67 |
+
|
68 |
+
# layers
|
69 |
+
if mode == "upsample2d":
|
70 |
+
self.resample = nn.Sequential(
|
71 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
72 |
+
)
|
73 |
+
elif mode == "upsample3d":
|
74 |
+
self.resample = nn.Sequential(
|
75 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
76 |
+
)
|
77 |
+
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
78 |
+
|
79 |
+
elif mode == "downsample2d":
|
80 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
81 |
+
elif mode == "downsample3d":
|
82 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
83 |
+
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
84 |
+
|
85 |
+
else:
|
86 |
+
self.resample = nn.Identity()
|
87 |
+
|
88 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
89 |
+
b, c, t, h, w = x.size()
|
90 |
+
if self.mode == "upsample3d":
|
91 |
+
if feat_cache is not None:
|
92 |
+
idx = feat_idx[0]
|
93 |
+
if feat_cache[idx] is None:
|
94 |
+
feat_cache[idx] = "Rep"
|
95 |
+
feat_idx[0] += 1
|
96 |
+
else:
|
97 |
+
|
98 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
99 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
100 |
+
# cache last frame of last two chunk
|
101 |
+
cache_x = torch.cat(
|
102 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
103 |
+
)
|
104 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
105 |
+
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
106 |
+
if feat_cache[idx] == "Rep":
|
107 |
+
x = self.time_conv(x)
|
108 |
+
else:
|
109 |
+
x = self.time_conv(x, feat_cache[idx])
|
110 |
+
feat_cache[idx] = cache_x
|
111 |
+
feat_idx[0] += 1
|
112 |
+
|
113 |
+
x = x.reshape(b, 2, c, t, h, w)
|
114 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
115 |
+
x = x.reshape(b, c, t * 2, h, w)
|
116 |
+
t = x.shape[2]
|
117 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
118 |
+
x = self.resample(x)
|
119 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
120 |
+
|
121 |
+
if self.mode == "downsample3d":
|
122 |
+
if feat_cache is not None:
|
123 |
+
idx = feat_idx[0]
|
124 |
+
if feat_cache[idx] is None:
|
125 |
+
feat_cache[idx] = x.clone()
|
126 |
+
feat_idx[0] += 1
|
127 |
+
else:
|
128 |
+
|
129 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
130 |
+
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
131 |
+
# # cache last frame of last two chunk
|
132 |
+
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
133 |
+
|
134 |
+
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
135 |
+
feat_cache[idx] = cache_x
|
136 |
+
feat_idx[0] += 1
|
137 |
+
return x
|
138 |
+
|
139 |
+
def init_weight(self, conv):
|
140 |
+
conv_weight = conv.weight
|
141 |
+
nn.init.zeros_(conv_weight)
|
142 |
+
c1, c2, t, h, w = conv_weight.size()
|
143 |
+
one_matrix = torch.eye(c1, c2)
|
144 |
+
init_matrix = one_matrix
|
145 |
+
nn.init.zeros_(conv_weight)
|
146 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
147 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
148 |
+
conv.weight.data.copy_(conv_weight)
|
149 |
+
nn.init.zeros_(conv.bias.data)
|
150 |
+
|
151 |
+
def init_weight2(self, conv):
|
152 |
+
conv_weight = conv.weight.data
|
153 |
+
nn.init.zeros_(conv_weight)
|
154 |
+
c1, c2, t, h, w = conv_weight.size()
|
155 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
156 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
157 |
+
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
|
158 |
+
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
|
159 |
+
conv.weight.data.copy_(conv_weight)
|
160 |
+
nn.init.zeros_(conv.bias.data)
|
161 |
+
|
162 |
+
|
163 |
+
class ResidualBlock(nn.Module):
|
164 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
165 |
+
super().__init__()
|
166 |
+
self.in_dim = in_dim
|
167 |
+
self.out_dim = out_dim
|
168 |
+
|
169 |
+
# layers
|
170 |
+
self.residual = nn.Sequential(
|
171 |
+
RMS_norm(in_dim, images=False),
|
172 |
+
nn.SiLU(),
|
173 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
174 |
+
RMS_norm(out_dim, images=False),
|
175 |
+
nn.SiLU(),
|
176 |
+
nn.Dropout(dropout),
|
177 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
178 |
+
)
|
179 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
180 |
+
|
181 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
182 |
+
h = self.shortcut(x)
|
183 |
+
for layer in self.residual:
|
184 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
185 |
+
idx = feat_idx[0]
|
186 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
187 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
188 |
+
# cache last frame of last two chunk
|
189 |
+
cache_x = torch.cat(
|
190 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
191 |
+
)
|
192 |
+
x = layer(x, feat_cache[idx])
|
193 |
+
feat_cache[idx] = cache_x
|
194 |
+
feat_idx[0] += 1
|
195 |
+
else:
|
196 |
+
x = layer(x)
|
197 |
+
return x + h
|
198 |
+
|
199 |
+
|
200 |
+
class AttentionBlock(nn.Module):
|
201 |
+
"""
|
202 |
+
Causal self-attention with a single head.
|
203 |
+
"""
|
204 |
+
|
205 |
+
def __init__(self, dim):
|
206 |
+
super().__init__()
|
207 |
+
self.dim = dim
|
208 |
+
|
209 |
+
# layers
|
210 |
+
self.norm = RMS_norm(dim)
|
211 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
212 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
213 |
+
|
214 |
+
# zero out the last layer params
|
215 |
+
nn.init.zeros_(self.proj.weight)
|
216 |
+
|
217 |
+
def forward(self, x):
|
218 |
+
identity = x
|
219 |
+
b, c, t, h, w = x.size()
|
220 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
221 |
+
x = self.norm(x)
|
222 |
+
# compute query, key, value
|
223 |
+
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
224 |
+
|
225 |
+
# apply attention
|
226 |
+
x = F.scaled_dot_product_attention(
|
227 |
+
q,
|
228 |
+
k,
|
229 |
+
v,
|
230 |
+
)
|
231 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
232 |
+
|
233 |
+
# output
|
234 |
+
x = self.proj(x)
|
235 |
+
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
|
236 |
+
return x + identity
|
237 |
+
|
238 |
+
|
239 |
+
class Encoder3d(nn.Module):
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
dim=128,
|
243 |
+
z_dim=4,
|
244 |
+
dim_mult=[1, 2, 4, 4],
|
245 |
+
num_res_blocks=2,
|
246 |
+
attn_scales=[],
|
247 |
+
temperal_downsample=[True, True, False],
|
248 |
+
dropout=0.0,
|
249 |
+
):
|
250 |
+
super().__init__()
|
251 |
+
self.dim = dim
|
252 |
+
self.z_dim = z_dim
|
253 |
+
self.dim_mult = dim_mult
|
254 |
+
self.num_res_blocks = num_res_blocks
|
255 |
+
self.attn_scales = attn_scales
|
256 |
+
self.temperal_downsample = temperal_downsample
|
257 |
+
|
258 |
+
# dimensions
|
259 |
+
dims = [dim * u for u in [1] + dim_mult]
|
260 |
+
scale = 1.0
|
261 |
+
|
262 |
+
# init block
|
263 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
264 |
+
|
265 |
+
# downsample blocks
|
266 |
+
downsamples = []
|
267 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
268 |
+
# residual (+attention) blocks
|
269 |
+
for _ in range(num_res_blocks):
|
270 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
271 |
+
if scale in attn_scales:
|
272 |
+
downsamples.append(AttentionBlock(out_dim))
|
273 |
+
in_dim = out_dim
|
274 |
+
|
275 |
+
# downsample block
|
276 |
+
if i != len(dim_mult) - 1:
|
277 |
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
278 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
279 |
+
scale /= 2.0
|
280 |
+
self.downsamples = nn.Sequential(*downsamples)
|
281 |
+
|
282 |
+
# middle blocks
|
283 |
+
self.middle = nn.Sequential(
|
284 |
+
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
|
285 |
+
)
|
286 |
+
|
287 |
+
# output blocks
|
288 |
+
self.head = nn.Sequential(
|
289 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
293 |
+
if feat_cache is not None:
|
294 |
+
idx = feat_idx[0]
|
295 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
296 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
297 |
+
# cache last frame of last two chunk
|
298 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
299 |
+
x = self.conv1(x, feat_cache[idx])
|
300 |
+
feat_cache[idx] = cache_x
|
301 |
+
feat_idx[0] += 1
|
302 |
+
else:
|
303 |
+
x = self.conv1(x)
|
304 |
+
|
305 |
+
## downsamples
|
306 |
+
for layer in self.downsamples:
|
307 |
+
if feat_cache is not None:
|
308 |
+
x = layer(x, feat_cache, feat_idx)
|
309 |
+
else:
|
310 |
+
x = layer(x)
|
311 |
+
|
312 |
+
## middle
|
313 |
+
for layer in self.middle:
|
314 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
315 |
+
x = layer(x, feat_cache, feat_idx)
|
316 |
+
else:
|
317 |
+
x = layer(x)
|
318 |
+
|
319 |
+
## head
|
320 |
+
for layer in self.head:
|
321 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
322 |
+
idx = feat_idx[0]
|
323 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
324 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
325 |
+
# cache last frame of last two chunk
|
326 |
+
cache_x = torch.cat(
|
327 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
328 |
+
)
|
329 |
+
x = layer(x, feat_cache[idx])
|
330 |
+
feat_cache[idx] = cache_x
|
331 |
+
feat_idx[0] += 1
|
332 |
+
else:
|
333 |
+
x = layer(x)
|
334 |
+
return x
|
335 |
+
|
336 |
+
|
337 |
+
class Decoder3d(nn.Module):
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
dim=128,
|
341 |
+
z_dim=4,
|
342 |
+
dim_mult=[1, 2, 4, 4],
|
343 |
+
num_res_blocks=2,
|
344 |
+
attn_scales=[],
|
345 |
+
temperal_upsample=[False, True, True],
|
346 |
+
dropout=0.0,
|
347 |
+
):
|
348 |
+
super().__init__()
|
349 |
+
self.dim = dim
|
350 |
+
self.z_dim = z_dim
|
351 |
+
self.dim_mult = dim_mult
|
352 |
+
self.num_res_blocks = num_res_blocks
|
353 |
+
self.attn_scales = attn_scales
|
354 |
+
self.temperal_upsample = temperal_upsample
|
355 |
+
|
356 |
+
# dimensions
|
357 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
358 |
+
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
359 |
+
|
360 |
+
# init block
|
361 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
362 |
+
|
363 |
+
# middle blocks
|
364 |
+
self.middle = nn.Sequential(
|
365 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
|
366 |
+
)
|
367 |
+
|
368 |
+
# upsample blocks
|
369 |
+
upsamples = []
|
370 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
371 |
+
# residual (+attention) blocks
|
372 |
+
if i == 1 or i == 2 or i == 3:
|
373 |
+
in_dim = in_dim // 2
|
374 |
+
for _ in range(num_res_blocks + 1):
|
375 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
376 |
+
if scale in attn_scales:
|
377 |
+
upsamples.append(AttentionBlock(out_dim))
|
378 |
+
in_dim = out_dim
|
379 |
+
|
380 |
+
# upsample block
|
381 |
+
if i != len(dim_mult) - 1:
|
382 |
+
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
383 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
384 |
+
scale *= 2.0
|
385 |
+
self.upsamples = nn.Sequential(*upsamples)
|
386 |
+
|
387 |
+
# output blocks
|
388 |
+
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
|
389 |
+
|
390 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
391 |
+
## conv1
|
392 |
+
if feat_cache is not None:
|
393 |
+
idx = feat_idx[0]
|
394 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
395 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
396 |
+
# cache last frame of last two chunk
|
397 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
398 |
+
x = self.conv1(x, feat_cache[idx])
|
399 |
+
feat_cache[idx] = cache_x
|
400 |
+
feat_idx[0] += 1
|
401 |
+
else:
|
402 |
+
x = self.conv1(x)
|
403 |
+
|
404 |
+
## middle
|
405 |
+
for layer in self.middle:
|
406 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
407 |
+
x = layer(x, feat_cache, feat_idx)
|
408 |
+
else:
|
409 |
+
x = layer(x)
|
410 |
+
|
411 |
+
## upsamples
|
412 |
+
for layer in self.upsamples:
|
413 |
+
if feat_cache is not None:
|
414 |
+
x = layer(x, feat_cache, feat_idx)
|
415 |
+
else:
|
416 |
+
x = layer(x)
|
417 |
+
|
418 |
+
## head
|
419 |
+
for layer in self.head:
|
420 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
421 |
+
idx = feat_idx[0]
|
422 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
423 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
424 |
+
# cache last frame of last two chunk
|
425 |
+
cache_x = torch.cat(
|
426 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
427 |
+
)
|
428 |
+
x = layer(x, feat_cache[idx])
|
429 |
+
feat_cache[idx] = cache_x
|
430 |
+
feat_idx[0] += 1
|
431 |
+
else:
|
432 |
+
x = layer(x)
|
433 |
+
return x
|
434 |
+
|
435 |
+
|
436 |
+
def count_conv3d(model):
|
437 |
+
count = 0
|
438 |
+
for m in model.modules():
|
439 |
+
if isinstance(m, CausalConv3d):
|
440 |
+
count += 1
|
441 |
+
return count
|
442 |
+
|
443 |
+
|
444 |
+
class WanVAE_(nn.Module):
|
445 |
+
def __init__(
|
446 |
+
self,
|
447 |
+
dim=128,
|
448 |
+
z_dim=4,
|
449 |
+
dim_mult=[1, 2, 4, 4],
|
450 |
+
num_res_blocks=2,
|
451 |
+
attn_scales=[],
|
452 |
+
temperal_downsample=[True, True, False],
|
453 |
+
dropout=0.0,
|
454 |
+
):
|
455 |
+
super().__init__()
|
456 |
+
self.dim = dim
|
457 |
+
self.z_dim = z_dim
|
458 |
+
self.dim_mult = dim_mult
|
459 |
+
self.num_res_blocks = num_res_blocks
|
460 |
+
self.attn_scales = attn_scales
|
461 |
+
self.temperal_downsample = temperal_downsample
|
462 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
463 |
+
|
464 |
+
# modules
|
465 |
+
self.encoder = Encoder3d(
|
466 |
+
dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
467 |
+
)
|
468 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
469 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
470 |
+
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
|
471 |
+
|
472 |
+
def forward(self, x):
|
473 |
+
mu, log_var = self.encode(x)
|
474 |
+
z = self.reparameterize(mu, log_var)
|
475 |
+
x_recon = self.decode(z)
|
476 |
+
return x_recon, mu, log_var
|
477 |
+
|
478 |
+
def encode(self, x, scale):
|
479 |
+
self.clear_cache()
|
480 |
+
## cache
|
481 |
+
t = x.shape[2]
|
482 |
+
iter_ = 1 + (t - 1) // 4
|
483 |
+
## 对encode输入的x,按时间拆分为1、4、4、4....
|
484 |
+
for i in range(iter_):
|
485 |
+
self._enc_conv_idx = [0]
|
486 |
+
if i == 0:
|
487 |
+
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
488 |
+
else:
|
489 |
+
out_ = self.encoder(
|
490 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
491 |
+
feat_cache=self._enc_feat_map,
|
492 |
+
feat_idx=self._enc_conv_idx,
|
493 |
+
)
|
494 |
+
out = torch.cat([out, out_], 2)
|
495 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
496 |
+
if isinstance(scale[0], torch.Tensor):
|
497 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
498 |
+
else:
|
499 |
+
mu = (mu - scale[0]) * scale[1]
|
500 |
+
self.clear_cache()
|
501 |
+
return mu
|
502 |
+
|
503 |
+
def decode(self, z, scale):
|
504 |
+
self.clear_cache()
|
505 |
+
# z: [b,c,t,h,w]
|
506 |
+
if isinstance(scale[0], torch.Tensor):
|
507 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
508 |
+
else:
|
509 |
+
z = z / scale[1] + scale[0]
|
510 |
+
iter_ = z.shape[2]
|
511 |
+
x = self.conv2(z)
|
512 |
+
for i in range(iter_):
|
513 |
+
self._conv_idx = [0]
|
514 |
+
if i == 0:
|
515 |
+
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
516 |
+
else:
|
517 |
+
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
518 |
+
out = torch.cat([out, out_], 2)
|
519 |
+
self.clear_cache()
|
520 |
+
return out
|
521 |
+
|
522 |
+
def reparameterize(self, mu, log_var):
|
523 |
+
std = torch.exp(0.5 * log_var)
|
524 |
+
eps = torch.randn_like(std)
|
525 |
+
return eps * std + mu
|
526 |
+
|
527 |
+
def sample(self, imgs, deterministic=False):
|
528 |
+
mu, log_var = self.encode(imgs)
|
529 |
+
if deterministic:
|
530 |
+
return mu
|
531 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
532 |
+
return mu + std * torch.randn_like(std)
|
533 |
+
|
534 |
+
def clear_cache(self):
|
535 |
+
self._conv_num = count_conv3d(self.decoder)
|
536 |
+
self._conv_idx = [0]
|
537 |
+
self._feat_map = [None] * self._conv_num
|
538 |
+
# cache encode
|
539 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
540 |
+
self._enc_conv_idx = [0]
|
541 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
542 |
+
|
543 |
+
|
544 |
+
def _video_vae(pretrained_path=None, z_dim=None, device="cpu", **kwargs):
|
545 |
+
"""
|
546 |
+
Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
|
547 |
+
"""
|
548 |
+
# params
|
549 |
+
cfg = dict(
|
550 |
+
dim=96,
|
551 |
+
z_dim=z_dim,
|
552 |
+
dim_mult=[1, 2, 4, 4],
|
553 |
+
num_res_blocks=2,
|
554 |
+
attn_scales=[],
|
555 |
+
temperal_downsample=[False, True, True],
|
556 |
+
dropout=0.0,
|
557 |
+
)
|
558 |
+
cfg.update(**kwargs)
|
559 |
+
|
560 |
+
# init model
|
561 |
+
with torch.device("meta"):
|
562 |
+
model = WanVAE_(**cfg)
|
563 |
+
|
564 |
+
# load checkpoint
|
565 |
+
logging.info(f"loading {pretrained_path}")
|
566 |
+
model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
|
567 |
+
|
568 |
+
return model
|
569 |
+
|
570 |
+
|
571 |
+
class WanVAE:
|
572 |
+
def __init__(self, vae_pth="cache/vae_step_411000.pth", z_dim=16):
|
573 |
+
|
574 |
+
mean = [
|
575 |
+
-0.7571,
|
576 |
+
-0.7089,
|
577 |
+
-0.9113,
|
578 |
+
0.1075,
|
579 |
+
-0.1745,
|
580 |
+
0.9653,
|
581 |
+
-0.1517,
|
582 |
+
1.5508,
|
583 |
+
0.4134,
|
584 |
+
-0.0715,
|
585 |
+
0.5517,
|
586 |
+
-0.3632,
|
587 |
+
-0.1922,
|
588 |
+
-0.9497,
|
589 |
+
0.2503,
|
590 |
+
-0.2921,
|
591 |
+
]
|
592 |
+
std = [
|
593 |
+
2.8184,
|
594 |
+
1.4541,
|
595 |
+
2.3275,
|
596 |
+
2.6558,
|
597 |
+
1.2196,
|
598 |
+
1.7708,
|
599 |
+
2.6052,
|
600 |
+
2.0743,
|
601 |
+
3.2687,
|
602 |
+
2.1526,
|
603 |
+
2.8652,
|
604 |
+
1.5579,
|
605 |
+
1.6382,
|
606 |
+
1.1253,
|
607 |
+
2.8251,
|
608 |
+
1.9160,
|
609 |
+
]
|
610 |
+
self.vae_stride = (4, 8, 8)
|
611 |
+
self.mean = torch.tensor(mean)
|
612 |
+
self.std = torch.tensor(std)
|
613 |
+
self.scale = [self.mean, 1.0 / self.std]
|
614 |
+
|
615 |
+
# init model
|
616 |
+
self.vae = (
|
617 |
+
_video_vae(
|
618 |
+
pretrained_path=vae_pth,
|
619 |
+
z_dim=z_dim,
|
620 |
+
)
|
621 |
+
.eval()
|
622 |
+
.requires_grad_(False)
|
623 |
+
)
|
624 |
+
|
625 |
+
def encode(self, video):
|
626 |
+
"""
|
627 |
+
videos: A list of videos each with shape [C, T, H, W].
|
628 |
+
"""
|
629 |
+
return self.vae.encode(video, self.scale).float()
|
630 |
+
|
631 |
+
def to(self, *args, **kwargs):
|
632 |
+
self.mean = self.mean.to(*args, **kwargs)
|
633 |
+
self.std = self.std.to(*args, **kwargs)
|
634 |
+
self.scale = [self.mean, 1.0 / self.std]
|
635 |
+
self.vae = self.vae.to(*args, **kwargs)
|
636 |
+
return self
|
637 |
+
|
638 |
+
def decode(self, z):
|
639 |
+
return self.vae.decode(z, self.scale).float().clamp_(-1, 1)
|
skyreels_v2_infer/modules/xlm_roberta.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
|
2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
__all__ = ["XLMRoberta", "xlm_roberta_large"]
|
8 |
+
|
9 |
+
|
10 |
+
class SelfAttention(nn.Module):
|
11 |
+
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
|
12 |
+
assert dim % num_heads == 0
|
13 |
+
super().__init__()
|
14 |
+
self.dim = dim
|
15 |
+
self.num_heads = num_heads
|
16 |
+
self.head_dim = dim // num_heads
|
17 |
+
self.eps = eps
|
18 |
+
|
19 |
+
# layers
|
20 |
+
self.q = nn.Linear(dim, dim)
|
21 |
+
self.k = nn.Linear(dim, dim)
|
22 |
+
self.v = nn.Linear(dim, dim)
|
23 |
+
self.o = nn.Linear(dim, dim)
|
24 |
+
self.dropout = nn.Dropout(dropout)
|
25 |
+
|
26 |
+
def forward(self, x, mask):
|
27 |
+
"""
|
28 |
+
x: [B, L, C].
|
29 |
+
"""
|
30 |
+
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
|
31 |
+
|
32 |
+
# compute query, key, value
|
33 |
+
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
34 |
+
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
35 |
+
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
|
36 |
+
|
37 |
+
# compute attention
|
38 |
+
p = self.dropout.p if self.training else 0.0
|
39 |
+
x = F.scaled_dot_product_attention(q, k, v, mask, p)
|
40 |
+
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
|
41 |
+
|
42 |
+
# output
|
43 |
+
x = self.o(x)
|
44 |
+
x = self.dropout(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class AttentionBlock(nn.Module):
|
49 |
+
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
|
50 |
+
super().__init__()
|
51 |
+
self.dim = dim
|
52 |
+
self.num_heads = num_heads
|
53 |
+
self.post_norm = post_norm
|
54 |
+
self.eps = eps
|
55 |
+
|
56 |
+
# layers
|
57 |
+
self.attn = SelfAttention(dim, num_heads, dropout, eps)
|
58 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps)
|
59 |
+
self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
|
60 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps)
|
61 |
+
|
62 |
+
def forward(self, x, mask):
|
63 |
+
if self.post_norm:
|
64 |
+
x = self.norm1(x + self.attn(x, mask))
|
65 |
+
x = self.norm2(x + self.ffn(x))
|
66 |
+
else:
|
67 |
+
x = x + self.attn(self.norm1(x), mask)
|
68 |
+
x = x + self.ffn(self.norm2(x))
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class XLMRoberta(nn.Module):
|
73 |
+
"""
|
74 |
+
XLMRobertaModel with no pooler and no LM head.
|
75 |
+
"""
|
76 |
+
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
vocab_size=250002,
|
80 |
+
max_seq_len=514,
|
81 |
+
type_size=1,
|
82 |
+
pad_id=1,
|
83 |
+
dim=1024,
|
84 |
+
num_heads=16,
|
85 |
+
num_layers=24,
|
86 |
+
post_norm=True,
|
87 |
+
dropout=0.1,
|
88 |
+
eps=1e-5,
|
89 |
+
):
|
90 |
+
super().__init__()
|
91 |
+
self.vocab_size = vocab_size
|
92 |
+
self.max_seq_len = max_seq_len
|
93 |
+
self.type_size = type_size
|
94 |
+
self.pad_id = pad_id
|
95 |
+
self.dim = dim
|
96 |
+
self.num_heads = num_heads
|
97 |
+
self.num_layers = num_layers
|
98 |
+
self.post_norm = post_norm
|
99 |
+
self.eps = eps
|
100 |
+
|
101 |
+
# embeddings
|
102 |
+
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
|
103 |
+
self.type_embedding = nn.Embedding(type_size, dim)
|
104 |
+
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
|
105 |
+
self.dropout = nn.Dropout(dropout)
|
106 |
+
|
107 |
+
# blocks
|
108 |
+
self.blocks = nn.ModuleList(
|
109 |
+
[AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)]
|
110 |
+
)
|
111 |
+
|
112 |
+
# norm layer
|
113 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
114 |
+
|
115 |
+
def forward(self, ids):
|
116 |
+
"""
|
117 |
+
ids: [B, L] of torch.LongTensor.
|
118 |
+
"""
|
119 |
+
b, s = ids.shape
|
120 |
+
mask = ids.ne(self.pad_id).long()
|
121 |
+
|
122 |
+
# embeddings
|
123 |
+
x = (
|
124 |
+
self.token_embedding(ids)
|
125 |
+
+ self.type_embedding(torch.zeros_like(ids))
|
126 |
+
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
|
127 |
+
)
|
128 |
+
if self.post_norm:
|
129 |
+
x = self.norm(x)
|
130 |
+
x = self.dropout(x)
|
131 |
+
|
132 |
+
# blocks
|
133 |
+
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
|
134 |
+
for block in self.blocks:
|
135 |
+
x = block(x, mask)
|
136 |
+
|
137 |
+
# output
|
138 |
+
if not self.post_norm:
|
139 |
+
x = self.norm(x)
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
|
144 |
+
"""
|
145 |
+
XLMRobertaLarge adapted from Huggingface.
|
146 |
+
"""
|
147 |
+
# params
|
148 |
+
cfg = dict(
|
149 |
+
vocab_size=250002,
|
150 |
+
max_seq_len=514,
|
151 |
+
type_size=1,
|
152 |
+
pad_id=1,
|
153 |
+
dim=1024,
|
154 |
+
num_heads=16,
|
155 |
+
num_layers=24,
|
156 |
+
post_norm=True,
|
157 |
+
dropout=0.1,
|
158 |
+
eps=1e-5,
|
159 |
+
)
|
160 |
+
cfg.update(**kwargs)
|
161 |
+
|
162 |
+
# init a model on device
|
163 |
+
with torch.device(device):
|
164 |
+
model = XLMRoberta(**cfg)
|
165 |
+
return model
|
skyreels_v2_infer/pipelines/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .diffusion_forcing_pipeline import DiffusionForcingPipeline
|
2 |
+
from .image2video_pipeline import Image2VideoPipeline
|
3 |
+
from .image2video_pipeline import resizecrop
|
4 |
+
from .prompt_enhancer import PromptEnhancer
|
5 |
+
from .text2video_pipeline import Text2VideoPipeline
|
skyreels_v2_infer/pipelines/diffusion_forcing_pipeline.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from typing import List
|
4 |
+
from typing import Optional
|
5 |
+
from typing import Tuple
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from diffusers.image_processor import PipelineImageInput
|
11 |
+
from diffusers.utils.torch_utils import randn_tensor
|
12 |
+
from diffusers.video_processor import VideoProcessor
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from ..modules import get_text_encoder
|
16 |
+
from ..modules import get_transformer
|
17 |
+
from ..modules import get_vae
|
18 |
+
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
19 |
+
|
20 |
+
|
21 |
+
class DiffusionForcingPipeline:
|
22 |
+
"""
|
23 |
+
A pipeline for diffusion-based video generation tasks.
|
24 |
+
|
25 |
+
This pipeline supports two main tasks:
|
26 |
+
- Image-to-Video (i2v): Generates a video sequence from a source image
|
27 |
+
- Text-to-Video (t2v): Generates a video sequence from a text description
|
28 |
+
|
29 |
+
The pipeline integrates multiple components including:
|
30 |
+
- A transformer model for diffusion
|
31 |
+
- A VAE for encoding/decoding
|
32 |
+
- A text encoder for processing text prompts
|
33 |
+
- An image encoder for processing image inputs (i2v mode only)
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model_path: str,
|
39 |
+
dit_path: str,
|
40 |
+
device: str = "cuda",
|
41 |
+
weight_dtype=torch.bfloat16,
|
42 |
+
use_usp=False,
|
43 |
+
offload=False,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Initialize the diffusion forcing pipeline class
|
47 |
+
|
48 |
+
Args:
|
49 |
+
model_path (str): Path to the model
|
50 |
+
dit_path (str): Path to the DIT model, containing model configuration file (config.json) and weight file (*.safetensor)
|
51 |
+
device (str): Device to run on, defaults to 'cuda'
|
52 |
+
weight_dtype: Weight data type, defaults to torch.bfloat16
|
53 |
+
"""
|
54 |
+
load_device = "cpu" if offload else device
|
55 |
+
self.transformer = get_transformer(dit_path, load_device, weight_dtype)
|
56 |
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
57 |
+
self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
|
58 |
+
self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
|
59 |
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
|
60 |
+
self.device = device
|
61 |
+
self.offload = offload
|
62 |
+
|
63 |
+
if use_usp:
|
64 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
65 |
+
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
66 |
+
import types
|
67 |
+
|
68 |
+
for block in self.transformer.blocks:
|
69 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
70 |
+
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
|
71 |
+
self.sp_size = get_sequence_parallel_world_size()
|
72 |
+
|
73 |
+
self.scheduler = FlowUniPCMultistepScheduler()
|
74 |
+
|
75 |
+
@property
|
76 |
+
def do_classifier_free_guidance(self) -> bool:
|
77 |
+
return self._guidance_scale > 1
|
78 |
+
|
79 |
+
def encode_image(
|
80 |
+
self, image: PipelineImageInput, height: int, width: int, num_frames: int
|
81 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
82 |
+
|
83 |
+
# prefix_video
|
84 |
+
prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
|
85 |
+
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
|
86 |
+
if prefix_video.dtype == torch.uint8:
|
87 |
+
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
|
88 |
+
prefix_video = prefix_video.to(self.device)
|
89 |
+
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
|
90 |
+
causal_block_size = self.transformer.num_frame_per_block
|
91 |
+
if prefix_video[0].shape[1] % causal_block_size != 0:
|
92 |
+
truncate_len = prefix_video[0].shape[1] % causal_block_size
|
93 |
+
print("the length of prefix video is truncated for the casual block size alignment.")
|
94 |
+
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
|
95 |
+
predix_video_latent_length = prefix_video[0].shape[1]
|
96 |
+
return prefix_video, predix_video_latent_length
|
97 |
+
|
98 |
+
def prepare_latents(
|
99 |
+
self,
|
100 |
+
shape: Tuple[int],
|
101 |
+
dtype: Optional[torch.dtype] = None,
|
102 |
+
device: Optional[torch.device] = None,
|
103 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
104 |
+
) -> torch.Tensor:
|
105 |
+
return randn_tensor(shape, generator, device=device, dtype=dtype)
|
106 |
+
|
107 |
+
def generate_timestep_matrix(
|
108 |
+
self,
|
109 |
+
num_frames,
|
110 |
+
step_template,
|
111 |
+
base_num_frames,
|
112 |
+
ar_step=5,
|
113 |
+
num_pre_ready=0,
|
114 |
+
casual_block_size=1,
|
115 |
+
shrink_interval_with_mask=False,
|
116 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
|
117 |
+
step_matrix, step_index = [], []
|
118 |
+
update_mask, valid_interval = [], []
|
119 |
+
num_iterations = len(step_template) + 1
|
120 |
+
num_frames_block = num_frames // casual_block_size
|
121 |
+
base_num_frames_block = base_num_frames // casual_block_size
|
122 |
+
if base_num_frames_block < num_frames_block:
|
123 |
+
infer_step_num = len(step_template)
|
124 |
+
gen_block = base_num_frames_block
|
125 |
+
min_ar_step = infer_step_num / gen_block
|
126 |
+
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
|
127 |
+
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
|
128 |
+
step_template = torch.cat(
|
129 |
+
[
|
130 |
+
torch.tensor([999], dtype=torch.int64, device=step_template.device),
|
131 |
+
step_template.long(),
|
132 |
+
torch.tensor([0], dtype=torch.int64, device=step_template.device),
|
133 |
+
]
|
134 |
+
) # to handle the counter in row works starting from 1
|
135 |
+
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
|
136 |
+
if num_pre_ready > 0:
|
137 |
+
pre_row[: num_pre_ready // casual_block_size] = num_iterations
|
138 |
+
|
139 |
+
while torch.all(pre_row >= (num_iterations - 1)) == False:
|
140 |
+
new_row = torch.zeros(num_frames_block, dtype=torch.long)
|
141 |
+
for i in range(num_frames_block):
|
142 |
+
if i == 0 or pre_row[i - 1] >= (
|
143 |
+
num_iterations - 1
|
144 |
+
): # the first frame or the last frame is completely denoised
|
145 |
+
new_row[i] = pre_row[i] + 1
|
146 |
+
else:
|
147 |
+
new_row[i] = new_row[i - 1] - ar_step
|
148 |
+
new_row = new_row.clamp(0, num_iterations)
|
149 |
+
|
150 |
+
update_mask.append(
|
151 |
+
(new_row != pre_row) & (new_row != num_iterations)
|
152 |
+
) # False: no need to update, True: need to update
|
153 |
+
step_index.append(new_row)
|
154 |
+
step_matrix.append(step_template[new_row])
|
155 |
+
pre_row = new_row
|
156 |
+
|
157 |
+
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
|
158 |
+
terminal_flag = base_num_frames_block
|
159 |
+
if shrink_interval_with_mask:
|
160 |
+
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
|
161 |
+
update_mask = update_mask[0]
|
162 |
+
update_mask_idx = idx_sequence[update_mask]
|
163 |
+
last_update_idx = update_mask_idx[-1].item()
|
164 |
+
terminal_flag = last_update_idx + 1
|
165 |
+
# for i in range(0, len(update_mask)):
|
166 |
+
for curr_mask in update_mask:
|
167 |
+
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
|
168 |
+
terminal_flag += 1
|
169 |
+
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
|
170 |
+
|
171 |
+
step_update_mask = torch.stack(update_mask, dim=0)
|
172 |
+
step_index = torch.stack(step_index, dim=0)
|
173 |
+
step_matrix = torch.stack(step_matrix, dim=0)
|
174 |
+
|
175 |
+
if casual_block_size > 1:
|
176 |
+
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
|
177 |
+
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
|
178 |
+
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
|
179 |
+
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
|
180 |
+
|
181 |
+
return step_matrix, step_index, step_update_mask, valid_interval
|
182 |
+
|
183 |
+
@torch.no_grad()
|
184 |
+
def __call__(
|
185 |
+
self,
|
186 |
+
prompt: Union[str, List[str]],
|
187 |
+
negative_prompt: Union[str, List[str]] = "",
|
188 |
+
image: PipelineImageInput = None,
|
189 |
+
height: int = 480,
|
190 |
+
width: int = 832,
|
191 |
+
num_frames: int = 97,
|
192 |
+
num_inference_steps: int = 50,
|
193 |
+
shift: float = 1.0,
|
194 |
+
guidance_scale: float = 5.0,
|
195 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
196 |
+
overlap_history: int = None,
|
197 |
+
addnoise_condition: int = 0,
|
198 |
+
base_num_frames: int = 97,
|
199 |
+
ar_step: int = 5,
|
200 |
+
causal_block_size: int = None,
|
201 |
+
fps: int = 24,
|
202 |
+
):
|
203 |
+
latent_height = height // 8
|
204 |
+
latent_width = width // 8
|
205 |
+
latent_length = (num_frames - 1) // 4 + 1
|
206 |
+
|
207 |
+
self._guidance_scale = guidance_scale
|
208 |
+
|
209 |
+
i2v_extra_kwrags = {}
|
210 |
+
prefix_video = None
|
211 |
+
predix_video_latent_length = 0
|
212 |
+
if image:
|
213 |
+
prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames)
|
214 |
+
|
215 |
+
self.text_encoder.to(self.device)
|
216 |
+
prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype)
|
217 |
+
if self.do_classifier_free_guidance:
|
218 |
+
negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype)
|
219 |
+
if self.offload:
|
220 |
+
self.text_encoder.cpu()
|
221 |
+
torch.cuda.empty_cache()
|
222 |
+
|
223 |
+
self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
224 |
+
init_timesteps = self.scheduler.timesteps
|
225 |
+
if causal_block_size is None:
|
226 |
+
causal_block_size = self.transformer.num_frame_per_block
|
227 |
+
fps_embeds = [fps] * prompt_embeds.shape[0]
|
228 |
+
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
|
229 |
+
transformer_dtype = self.transformer.dtype
|
230 |
+
# with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
231 |
+
if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
|
232 |
+
# short video generation
|
233 |
+
latent_shape = [16, latent_length, latent_height, latent_width]
|
234 |
+
latents = self.prepare_latents(
|
235 |
+
latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
|
236 |
+
)
|
237 |
+
latents = [latents]
|
238 |
+
if prefix_video is not None:
|
239 |
+
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
|
240 |
+
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
|
241 |
+
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
242 |
+
latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
|
243 |
+
)
|
244 |
+
sample_schedulers = []
|
245 |
+
for _ in range(latent_length):
|
246 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
247 |
+
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
|
248 |
+
)
|
249 |
+
sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
250 |
+
sample_schedulers.append(sample_scheduler)
|
251 |
+
sample_schedulers_counter = [0] * latent_length
|
252 |
+
self.transformer.to(self.device)
|
253 |
+
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
254 |
+
update_mask_i = step_update_mask[i]
|
255 |
+
valid_interval_i = valid_interval[i]
|
256 |
+
valid_interval_start, valid_interval_end = valid_interval_i
|
257 |
+
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
258 |
+
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
|
259 |
+
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
260 |
+
noise_factor = 0.001 * addnoise_condition
|
261 |
+
timestep_for_noised_condition = addnoise_condition
|
262 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
|
263 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
|
264 |
+
+ torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
|
265 |
+
* noise_factor
|
266 |
+
)
|
267 |
+
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
268 |
+
if not self.do_classifier_free_guidance:
|
269 |
+
noise_pred = self.transformer(
|
270 |
+
torch.stack([latent_model_input[0]]),
|
271 |
+
t=timestep,
|
272 |
+
context=prompt_embeds,
|
273 |
+
fps=fps_embeds,
|
274 |
+
**i2v_extra_kwrags,
|
275 |
+
)[0]
|
276 |
+
else:
|
277 |
+
noise_pred_cond = self.transformer(
|
278 |
+
torch.stack([latent_model_input[0]]),
|
279 |
+
t=timestep,
|
280 |
+
context=prompt_embeds,
|
281 |
+
fps=fps_embeds,
|
282 |
+
**i2v_extra_kwrags,
|
283 |
+
)[0]
|
284 |
+
noise_pred_uncond = self.transformer(
|
285 |
+
torch.stack([latent_model_input[0]]),
|
286 |
+
t=timestep,
|
287 |
+
context=negative_prompt_embeds,
|
288 |
+
fps=fps_embeds,
|
289 |
+
**i2v_extra_kwrags,
|
290 |
+
)[0]
|
291 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
292 |
+
for idx in range(valid_interval_start, valid_interval_end):
|
293 |
+
if update_mask_i[idx].item():
|
294 |
+
latents[0][:, idx] = sample_schedulers[idx].step(
|
295 |
+
noise_pred[:, idx - valid_interval_start],
|
296 |
+
timestep_i[idx],
|
297 |
+
latents[0][:, idx],
|
298 |
+
return_dict=False,
|
299 |
+
generator=generator,
|
300 |
+
)[0]
|
301 |
+
sample_schedulers_counter[idx] += 1
|
302 |
+
if self.offload:
|
303 |
+
self.transformer.cpu()
|
304 |
+
torch.cuda.empty_cache()
|
305 |
+
x0 = latents[0].unsqueeze(0)
|
306 |
+
videos = self.vae.decode(x0)
|
307 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
308 |
+
videos = [video for video in videos]
|
309 |
+
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
|
310 |
+
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
|
311 |
+
return videos
|
312 |
+
else:
|
313 |
+
# long video generation
|
314 |
+
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
|
315 |
+
overlap_history_frames = (overlap_history - 1) // 4 + 1
|
316 |
+
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
|
317 |
+
print(f"n_iter:{n_iter}")
|
318 |
+
output_video = None
|
319 |
+
for i in range(n_iter):
|
320 |
+
if output_video is not None: # i !=0
|
321 |
+
prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device)
|
322 |
+
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
|
323 |
+
if prefix_video[0].shape[1] % causal_block_size != 0:
|
324 |
+
truncate_len = prefix_video[0].shape[1] % causal_block_size
|
325 |
+
print("the length of prefix video is truncated for the casual block size alignment.")
|
326 |
+
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
|
327 |
+
predix_video_latent_length = prefix_video[0].shape[1]
|
328 |
+
finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
|
329 |
+
left_frame_num = latent_length - finished_frame_num
|
330 |
+
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
|
331 |
+
if ar_step > 0 and self.transformer.enable_teacache:
|
332 |
+
num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step
|
333 |
+
self.transformer.num_steps = num_steps
|
334 |
+
else: # i == 0
|
335 |
+
base_num_frames_iter = base_num_frames
|
336 |
+
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
|
337 |
+
latents = self.prepare_latents(
|
338 |
+
latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator
|
339 |
+
)
|
340 |
+
latents = [latents]
|
341 |
+
if prefix_video is not None:
|
342 |
+
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype)
|
343 |
+
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
|
344 |
+
base_num_frames_iter,
|
345 |
+
init_timesteps,
|
346 |
+
base_num_frames_iter,
|
347 |
+
ar_step,
|
348 |
+
predix_video_latent_length,
|
349 |
+
causal_block_size,
|
350 |
+
)
|
351 |
+
sample_schedulers = []
|
352 |
+
for _ in range(base_num_frames_iter):
|
353 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
354 |
+
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
|
355 |
+
)
|
356 |
+
sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift)
|
357 |
+
sample_schedulers.append(sample_scheduler)
|
358 |
+
sample_schedulers_counter = [0] * base_num_frames_iter
|
359 |
+
self.transformer.to(self.device)
|
360 |
+
for i, timestep_i in enumerate(tqdm(step_matrix)):
|
361 |
+
update_mask_i = step_update_mask[i]
|
362 |
+
valid_interval_i = valid_interval[i]
|
363 |
+
valid_interval_start, valid_interval_end = valid_interval_i
|
364 |
+
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
|
365 |
+
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
|
366 |
+
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
|
367 |
+
noise_factor = 0.001 * addnoise_condition
|
368 |
+
timestep_for_noised_condition = addnoise_condition
|
369 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
|
370 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
371 |
+
* (1.0 - noise_factor)
|
372 |
+
+ torch.randn_like(
|
373 |
+
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
|
374 |
+
)
|
375 |
+
* noise_factor
|
376 |
+
)
|
377 |
+
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
|
378 |
+
if not self.do_classifier_free_guidance:
|
379 |
+
noise_pred = self.transformer(
|
380 |
+
torch.stack([latent_model_input[0]]),
|
381 |
+
t=timestep,
|
382 |
+
context=prompt_embeds,
|
383 |
+
fps=fps_embeds,
|
384 |
+
**i2v_extra_kwrags,
|
385 |
+
)[0]
|
386 |
+
else:
|
387 |
+
noise_pred_cond = self.transformer(
|
388 |
+
torch.stack([latent_model_input[0]]),
|
389 |
+
t=timestep,
|
390 |
+
context=prompt_embeds,
|
391 |
+
fps=fps_embeds,
|
392 |
+
**i2v_extra_kwrags,
|
393 |
+
)[0]
|
394 |
+
noise_pred_uncond = self.transformer(
|
395 |
+
torch.stack([latent_model_input[0]]),
|
396 |
+
t=timestep,
|
397 |
+
context=negative_prompt_embeds,
|
398 |
+
fps=fps_embeds,
|
399 |
+
**i2v_extra_kwrags,
|
400 |
+
)[0]
|
401 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
402 |
+
for idx in range(valid_interval_start, valid_interval_end):
|
403 |
+
if update_mask_i[idx].item():
|
404 |
+
latents[0][:, idx] = sample_schedulers[idx].step(
|
405 |
+
noise_pred[:, idx - valid_interval_start],
|
406 |
+
timestep_i[idx],
|
407 |
+
latents[0][:, idx],
|
408 |
+
return_dict=False,
|
409 |
+
generator=generator,
|
410 |
+
)[0]
|
411 |
+
sample_schedulers_counter[idx] += 1
|
412 |
+
if self.offload:
|
413 |
+
self.transformer.cpu()
|
414 |
+
torch.cuda.empty_cache()
|
415 |
+
x0 = latents[0].unsqueeze(0)
|
416 |
+
videos = [self.vae.decode(x0)[0]]
|
417 |
+
if output_video is None:
|
418 |
+
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
|
419 |
+
else:
|
420 |
+
output_video = torch.cat(
|
421 |
+
[output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
|
422 |
+
) # c, f, h, w
|
423 |
+
output_video = [(output_video / 2 + 0.5).clamp(0, 1)]
|
424 |
+
output_video = [video for video in output_video]
|
425 |
+
output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video]
|
426 |
+
output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video]
|
427 |
+
return output_video
|
skyreels_v2_infer/pipelines/image2video_pipeline.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from typing import Optional
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.image_processor import PipelineImageInput
|
9 |
+
from diffusers.video_processor import VideoProcessor
|
10 |
+
from PIL import Image
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from ..modules import get_image_encoder
|
14 |
+
from ..modules import get_text_encoder
|
15 |
+
from ..modules import get_transformer
|
16 |
+
from ..modules import get_vae
|
17 |
+
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
18 |
+
|
19 |
+
|
20 |
+
def resizecrop(image: Image.Image, th, tw):
|
21 |
+
w, h = image.size
|
22 |
+
if w == tw and h == th:
|
23 |
+
return image
|
24 |
+
if h / w > th / tw:
|
25 |
+
new_w = int(w)
|
26 |
+
new_h = int(new_w * th / tw)
|
27 |
+
else:
|
28 |
+
new_h = int(h)
|
29 |
+
new_w = int(new_h * tw / th)
|
30 |
+
left = (w - new_w) / 2
|
31 |
+
top = (h - new_h) / 2
|
32 |
+
right = (w + new_w) / 2
|
33 |
+
bottom = (h + new_h) / 2
|
34 |
+
image = image.crop((left, top, right, bottom))
|
35 |
+
return image
|
36 |
+
|
37 |
+
|
38 |
+
class Image2VideoPipeline:
|
39 |
+
def __init__(
|
40 |
+
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
|
41 |
+
):
|
42 |
+
load_device = "cpu" if offload else device
|
43 |
+
self.transformer = get_transformer(dit_path, load_device, weight_dtype)
|
44 |
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
45 |
+
self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
|
46 |
+
self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
|
47 |
+
self.clip = get_image_encoder(model_path, load_device, weight_dtype)
|
48 |
+
self.sp_size = 1
|
49 |
+
self.device = device
|
50 |
+
self.offload = offload
|
51 |
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
|
52 |
+
if use_usp:
|
53 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
54 |
+
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
55 |
+
import types
|
56 |
+
|
57 |
+
for block in self.transformer.blocks:
|
58 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
59 |
+
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
|
60 |
+
self.sp_size = get_sequence_parallel_world_size()
|
61 |
+
|
62 |
+
self.scheduler = FlowUniPCMultistepScheduler()
|
63 |
+
self.vae_stride = (4, 8, 8)
|
64 |
+
self.patch_size = (1, 2, 2)
|
65 |
+
|
66 |
+
@torch.no_grad()
|
67 |
+
def __call__(
|
68 |
+
self,
|
69 |
+
image: PipelineImageInput,
|
70 |
+
prompt: Union[str, List[str]] = None,
|
71 |
+
negative_prompt: Union[str, List[str]] = None,
|
72 |
+
height: int = 544,
|
73 |
+
width: int = 960,
|
74 |
+
num_frames: int = 97,
|
75 |
+
num_inference_steps: int = 50,
|
76 |
+
guidance_scale: float = 5.0,
|
77 |
+
shift: float = 5.0,
|
78 |
+
generator: Optional[torch.Generator] = None,
|
79 |
+
):
|
80 |
+
F = num_frames
|
81 |
+
|
82 |
+
latent_height = height // 8 // 2 * 2
|
83 |
+
latent_width = width // 8 // 2 * 2
|
84 |
+
latent_length = (F - 1) // 4 + 1
|
85 |
+
|
86 |
+
h = latent_height * 8
|
87 |
+
w = latent_width * 8
|
88 |
+
|
89 |
+
img = self.video_processor.preprocess(image, height=h, width=w)
|
90 |
+
|
91 |
+
img = img.to(device=self.device, dtype=self.transformer.dtype)
|
92 |
+
|
93 |
+
padding_video = torch.zeros(img.shape[0], 3, F - 1, h, w, device=self.device)
|
94 |
+
|
95 |
+
img = img.unsqueeze(2)
|
96 |
+
img_cond = torch.concat([img, padding_video], dim=2)
|
97 |
+
img_cond = self.vae.encode(img_cond)
|
98 |
+
mask = torch.ones_like(img_cond)
|
99 |
+
mask[:, :, 1:] = 0
|
100 |
+
y = torch.cat([mask[:, :4], img_cond], dim=1)
|
101 |
+
self.clip.to(self.device)
|
102 |
+
clip_context = self.clip.encode_video(img)
|
103 |
+
if self.offload:
|
104 |
+
self.clip.cpu()
|
105 |
+
torch.cuda.empty_cache()
|
106 |
+
|
107 |
+
# preprocess
|
108 |
+
self.text_encoder.to(self.device)
|
109 |
+
context = self.text_encoder.encode(prompt).to(self.device)
|
110 |
+
context_null = self.text_encoder.encode(negative_prompt).to(self.device)
|
111 |
+
if self.offload:
|
112 |
+
self.text_encoder.cpu()
|
113 |
+
torch.cuda.empty_cache()
|
114 |
+
|
115 |
+
latent = torch.randn(
|
116 |
+
16, latent_length, latent_height, latent_width, dtype=torch.float32, generator=generator, device=self.device
|
117 |
+
)
|
118 |
+
|
119 |
+
self.transformer.to(self.device)
|
120 |
+
with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
121 |
+
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
|
122 |
+
timesteps = self.scheduler.timesteps
|
123 |
+
|
124 |
+
arg_c = {
|
125 |
+
"context": context,
|
126 |
+
"clip_fea": clip_context,
|
127 |
+
"y": y,
|
128 |
+
}
|
129 |
+
|
130 |
+
arg_null = {
|
131 |
+
"context": context_null,
|
132 |
+
"clip_fea": clip_context,
|
133 |
+
"y": y,
|
134 |
+
}
|
135 |
+
|
136 |
+
self.transformer.to(self.device)
|
137 |
+
for _, t in enumerate(tqdm(timesteps)):
|
138 |
+
latent_model_input = torch.stack([latent]).to(self.device)
|
139 |
+
timestep = torch.stack([t]).to(self.device)
|
140 |
+
noise_pred_cond = self.transformer(latent_model_input, t=timestep, **arg_c)[0].to(self.device)
|
141 |
+
noise_pred_uncond = self.transformer(latent_model_input, t=timestep, **arg_null)[0].to(self.device)
|
142 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
143 |
+
|
144 |
+
temp_x0 = self.scheduler.step(
|
145 |
+
noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=generator
|
146 |
+
)[0]
|
147 |
+
latent = temp_x0.squeeze(0)
|
148 |
+
if self.offload:
|
149 |
+
self.transformer.cpu()
|
150 |
+
torch.cuda.empty_cache()
|
151 |
+
videos = self.vae.decode(latent)
|
152 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
153 |
+
videos = [video for video in videos]
|
154 |
+
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
|
155 |
+
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
|
156 |
+
return videos
|
skyreels_v2_infer/pipelines/prompt_enhancer.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
|
4 |
+
sys_prompt = """
|
5 |
+
Transform the short prompt into a detailed video-generation caption using this structure:
|
6 |
+
Opening shot type (long/medium/close-up/extreme close-up/full shot)
|
7 |
+
Primary subject(s) with vivid attributes (colors, textures, actions, interactions)
|
8 |
+
Dynamic elements (movement, transitions, or changes over time, e.g., 'gradually lowers,' 'begins to climb,' 'camera moves toward...')
|
9 |
+
Scene composition (background, environment, spatial relationships)
|
10 |
+
Lighting/atmosphere (natural/artificial, time of day, mood)
|
11 |
+
Camera motion (zooms, pans, static/handheld shots) if applicable.
|
12 |
+
|
13 |
+
Pattern Summary from Examples:
|
14 |
+
[Shot Type] of [Subject+Action] + [Detailed Subject Description] + [Environmental Context] + [Lighting Conditions] + [Camera Movement]
|
15 |
+
|
16 |
+
One case:
|
17 |
+
Short prompt: a person is playing football
|
18 |
+
Long prompt: Medium shot of a young athlete in a red jersey sprinting across a muddy field, dribbling a soccer ball with precise footwork. The player glances toward the goalpost, adjusts their stance, and kicks the ball forcefully into the net. Raindrops fall lightly, creating reflections under stadium floodlights. The camera follows the ball’s trajectory in a smooth pan.
|
19 |
+
|
20 |
+
Note: If the subject is stationary, incorporate camera movement to ensure the generated video remains dynamic.
|
21 |
+
|
22 |
+
Now expand this short prompt: [{}]. Please only output the final long prompt in English.
|
23 |
+
"""
|
24 |
+
|
25 |
+
class PromptEnhancer:
|
26 |
+
def __init__(self, model_name="Qwen/Qwen2.5-32B-Instruct"):
|
27 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
28 |
+
model_name,
|
29 |
+
torch_dtype="auto",
|
30 |
+
device_map="cuda:0",
|
31 |
+
)
|
32 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
33 |
+
|
34 |
+
def __call__(self, prompt):
|
35 |
+
prompt = prompt.strip()
|
36 |
+
prompt = sys_prompt.format(prompt)
|
37 |
+
messages = [
|
38 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
39 |
+
{"role": "user", "content": prompt}
|
40 |
+
]
|
41 |
+
text = self.tokenizer.apply_chat_template(
|
42 |
+
messages,
|
43 |
+
tokenize=False,
|
44 |
+
add_generation_prompt=True
|
45 |
+
)
|
46 |
+
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
|
47 |
+
generated_ids = self.model.generate(
|
48 |
+
**model_inputs,
|
49 |
+
max_new_tokens=2048,
|
50 |
+
)
|
51 |
+
generated_ids = [
|
52 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
53 |
+
]
|
54 |
+
rewritten_prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
55 |
+
return rewritten_prompt
|
56 |
+
|
57 |
+
if __name__ == '__main__':
|
58 |
+
parser = argparse.ArgumentParser()
|
59 |
+
parser.add_argument("--prompt", type=str, default="In a still frame, a stop sign")
|
60 |
+
args = parser.parse_args()
|
61 |
+
|
62 |
+
prompt_enhancer = PromptEnhancer()
|
63 |
+
enhanced_prompt = prompt_enhancer(args.prompt)
|
64 |
+
print(f'Original prompt: {args.prompt}')
|
65 |
+
print(f'Enhanced prompt: {enhanced_prompt}')
|
skyreels_v2_infer/pipelines/text2video_pipeline.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from typing import Optional
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from diffusers.video_processor import VideoProcessor
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from ..modules import get_text_encoder
|
12 |
+
from ..modules import get_transformer
|
13 |
+
from ..modules import get_vae
|
14 |
+
from ..scheduler.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
15 |
+
|
16 |
+
|
17 |
+
class Text2VideoPipeline:
|
18 |
+
def __init__(
|
19 |
+
self, model_path, dit_path, device: str = "cuda", weight_dtype=torch.bfloat16, use_usp=False, offload=False
|
20 |
+
):
|
21 |
+
load_device = "cpu" if offload else device
|
22 |
+
self.transformer = get_transformer(dit_path, load_device, weight_dtype)
|
23 |
+
vae_model_path = os.path.join(model_path, "Wan2.1_VAE.pth")
|
24 |
+
self.vae = get_vae(vae_model_path, device, weight_dtype=torch.float32)
|
25 |
+
self.text_encoder = get_text_encoder(model_path, load_device, weight_dtype)
|
26 |
+
self.video_processor = VideoProcessor(vae_scale_factor=16)
|
27 |
+
self.sp_size = 1
|
28 |
+
self.device = device
|
29 |
+
self.offload = offload
|
30 |
+
if use_usp:
|
31 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
32 |
+
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
33 |
+
import types
|
34 |
+
|
35 |
+
for block in self.transformer.blocks:
|
36 |
+
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
37 |
+
self.transformer.forward = types.MethodType(usp_dit_forward, self.transformer)
|
38 |
+
self.sp_size = get_sequence_parallel_world_size()
|
39 |
+
|
40 |
+
self.scheduler = FlowUniPCMultistepScheduler()
|
41 |
+
self.vae_stride = (4, 8, 8)
|
42 |
+
self.patch_size = (1, 2, 2)
|
43 |
+
|
44 |
+
@torch.no_grad()
|
45 |
+
def __call__(
|
46 |
+
self,
|
47 |
+
prompt: Union[str, List[str]] = None,
|
48 |
+
negative_prompt: Union[str, List[str]] = None,
|
49 |
+
width: int = 544,
|
50 |
+
height: int = 960,
|
51 |
+
num_frames: int = 97,
|
52 |
+
num_inference_steps: int = 50,
|
53 |
+
guidance_scale: float = 5.0,
|
54 |
+
shift: float = 5.0,
|
55 |
+
generator: Optional[torch.Generator] = None,
|
56 |
+
):
|
57 |
+
# preprocess
|
58 |
+
F = num_frames
|
59 |
+
target_shape = (
|
60 |
+
self.vae.vae.z_dim,
|
61 |
+
(F - 1) // self.vae_stride[0] + 1,
|
62 |
+
height // self.vae_stride[1],
|
63 |
+
width // self.vae_stride[2],
|
64 |
+
)
|
65 |
+
self.text_encoder.to(self.device)
|
66 |
+
context = self.text_encoder.encode(prompt).to(self.device)
|
67 |
+
context_null = self.text_encoder.encode(negative_prompt).to(self.device)
|
68 |
+
if self.offload:
|
69 |
+
self.text_encoder.cpu()
|
70 |
+
torch.cuda.empty_cache()
|
71 |
+
|
72 |
+
latents = [
|
73 |
+
torch.randn(
|
74 |
+
target_shape[0],
|
75 |
+
target_shape[1],
|
76 |
+
target_shape[2],
|
77 |
+
target_shape[3],
|
78 |
+
dtype=torch.float32,
|
79 |
+
device=self.device,
|
80 |
+
generator=generator,
|
81 |
+
)
|
82 |
+
]
|
83 |
+
|
84 |
+
# evaluation mode
|
85 |
+
self.transformer.to(self.device)
|
86 |
+
with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad():
|
87 |
+
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
|
88 |
+
timesteps = self.scheduler.timesteps
|
89 |
+
|
90 |
+
for _, t in enumerate(tqdm(timesteps)):
|
91 |
+
latent_model_input = torch.stack(latents)
|
92 |
+
timestep = torch.stack([t])
|
93 |
+
noise_pred_cond = self.transformer(latent_model_input, t=timestep, context=context)[0]
|
94 |
+
noise_pred_uncond = self.transformer(latent_model_input, t=timestep, context=context_null)[0]
|
95 |
+
|
96 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
97 |
+
|
98 |
+
temp_x0 = self.scheduler.step(
|
99 |
+
noise_pred.unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, generator=generator
|
100 |
+
)[0]
|
101 |
+
latents = [temp_x0.squeeze(0)]
|
102 |
+
if self.offload:
|
103 |
+
self.transformer.cpu()
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
videos = self.vae.decode(latents[0])
|
106 |
+
videos = (videos / 2 + 0.5).clamp(0, 1)
|
107 |
+
videos = [video for video in videos]
|
108 |
+
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
|
109 |
+
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
|
110 |
+
return videos
|
skyreels_v2_infer/scheduler/__init__.py
ADDED
File without changes
|
skyreels_v2_infer/scheduler/fm_solvers_unipc.py
ADDED
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
|
2 |
+
# Convert unipc for flow matching
|
3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
4 |
+
import math
|
5 |
+
from typing import List
|
6 |
+
from typing import Optional
|
7 |
+
from typing import Tuple
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from diffusers.configuration_utils import ConfigMixin
|
13 |
+
from diffusers.configuration_utils import register_to_config
|
14 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers
|
15 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
16 |
+
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
17 |
+
from diffusers.utils import deprecate
|
18 |
+
|
19 |
+
|
20 |
+
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
21 |
+
"""
|
22 |
+
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
|
23 |
+
|
24 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
25 |
+
methods the library implements for all schedulers such as loading and saving.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
num_train_timesteps (`int`, defaults to 1000):
|
29 |
+
The number of diffusion steps to train the model.
|
30 |
+
solver_order (`int`, default `2`):
|
31 |
+
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
32 |
+
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
33 |
+
unconditional sampling.
|
34 |
+
prediction_type (`str`, defaults to "flow_prediction"):
|
35 |
+
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
36 |
+
the flow of the diffusion process.
|
37 |
+
thresholding (`bool`, defaults to `False`):
|
38 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
39 |
+
as Stable Diffusion.
|
40 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
41 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
42 |
+
sample_max_value (`float`, defaults to 1.0):
|
43 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
44 |
+
predict_x0 (`bool`, defaults to `True`):
|
45 |
+
Whether to use the updating algorithm on the predicted x0.
|
46 |
+
solver_type (`str`, default `bh2`):
|
47 |
+
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
48 |
+
otherwise.
|
49 |
+
lower_order_final (`bool`, default `True`):
|
50 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
51 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
52 |
+
disable_corrector (`list`, default `[]`):
|
53 |
+
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
54 |
+
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
55 |
+
usually disabled during the first few steps.
|
56 |
+
solver_p (`SchedulerMixin`, default `None`):
|
57 |
+
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
58 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
59 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
60 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
61 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
62 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
63 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
64 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
65 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
66 |
+
steps_offset (`int`, defaults to 0):
|
67 |
+
An offset added to the inference steps, as required by some model families.
|
68 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
69 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
70 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
71 |
+
"""
|
72 |
+
|
73 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
74 |
+
order = 1
|
75 |
+
|
76 |
+
@register_to_config
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
num_train_timesteps: int = 1000,
|
80 |
+
solver_order: int = 2,
|
81 |
+
prediction_type: str = "flow_prediction",
|
82 |
+
shift: Optional[float] = 1.0,
|
83 |
+
use_dynamic_shifting=False,
|
84 |
+
thresholding: bool = False,
|
85 |
+
dynamic_thresholding_ratio: float = 0.995,
|
86 |
+
sample_max_value: float = 1.0,
|
87 |
+
predict_x0: bool = True,
|
88 |
+
solver_type: str = "bh2",
|
89 |
+
lower_order_final: bool = True,
|
90 |
+
disable_corrector: List[int] = [],
|
91 |
+
solver_p: SchedulerMixin = None,
|
92 |
+
timestep_spacing: str = "linspace",
|
93 |
+
steps_offset: int = 0,
|
94 |
+
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
95 |
+
):
|
96 |
+
|
97 |
+
if solver_type not in ["bh1", "bh2"]:
|
98 |
+
if solver_type in ["midpoint", "heun", "logrho"]:
|
99 |
+
self.register_to_config(solver_type="bh2")
|
100 |
+
else:
|
101 |
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
102 |
+
|
103 |
+
self.predict_x0 = predict_x0
|
104 |
+
# setable values
|
105 |
+
self.num_inference_steps = None
|
106 |
+
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
|
107 |
+
sigmas = 1.0 - alphas
|
108 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
|
109 |
+
|
110 |
+
if not use_dynamic_shifting:
|
111 |
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
112 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
113 |
+
|
114 |
+
self.sigmas = sigmas
|
115 |
+
self.timesteps = sigmas * num_train_timesteps
|
116 |
+
|
117 |
+
self.model_outputs = [None] * solver_order
|
118 |
+
self.timestep_list = [None] * solver_order
|
119 |
+
self.lower_order_nums = 0
|
120 |
+
self.disable_corrector = disable_corrector
|
121 |
+
self.solver_p = solver_p
|
122 |
+
self.last_sample = None
|
123 |
+
self._step_index = None
|
124 |
+
self._begin_index = None
|
125 |
+
|
126 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
127 |
+
self.sigma_min = self.sigmas[-1].item()
|
128 |
+
self.sigma_max = self.sigmas[0].item()
|
129 |
+
|
130 |
+
@property
|
131 |
+
def step_index(self):
|
132 |
+
"""
|
133 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
134 |
+
"""
|
135 |
+
return self._step_index
|
136 |
+
|
137 |
+
@property
|
138 |
+
def begin_index(self):
|
139 |
+
"""
|
140 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
141 |
+
"""
|
142 |
+
return self._begin_index
|
143 |
+
|
144 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
145 |
+
def set_begin_index(self, begin_index: int = 0):
|
146 |
+
"""
|
147 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
begin_index (`int`):
|
151 |
+
The begin index for the scheduler.
|
152 |
+
"""
|
153 |
+
self._begin_index = begin_index
|
154 |
+
|
155 |
+
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
156 |
+
def set_timesteps(
|
157 |
+
self,
|
158 |
+
num_inference_steps: Union[int, None] = None,
|
159 |
+
device: Union[str, torch.device] = None,
|
160 |
+
sigmas: Optional[List[float]] = None,
|
161 |
+
mu: Optional[Union[float, None]] = None,
|
162 |
+
shift: Optional[Union[float, None]] = None,
|
163 |
+
):
|
164 |
+
"""
|
165 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
166 |
+
Args:
|
167 |
+
num_inference_steps (`int`):
|
168 |
+
Total number of the spacing of the time steps.
|
169 |
+
device (`str` or `torch.device`, *optional*):
|
170 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
171 |
+
"""
|
172 |
+
|
173 |
+
if self.config.use_dynamic_shifting and mu is None:
|
174 |
+
raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
175 |
+
|
176 |
+
if sigmas is None:
|
177 |
+
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore
|
178 |
+
|
179 |
+
if self.config.use_dynamic_shifting:
|
180 |
+
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
|
181 |
+
else:
|
182 |
+
if shift is None:
|
183 |
+
shift = self.config.shift
|
184 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
|
185 |
+
|
186 |
+
if self.config.final_sigmas_type == "sigma_min":
|
187 |
+
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
188 |
+
elif self.config.final_sigmas_type == "zero":
|
189 |
+
sigma_last = 0
|
190 |
+
else:
|
191 |
+
raise ValueError(
|
192 |
+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
193 |
+
)
|
194 |
+
|
195 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
196 |
+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
|
197 |
+
|
198 |
+
self.sigmas = torch.from_numpy(sigmas)
|
199 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
200 |
+
|
201 |
+
self.num_inference_steps = len(timesteps)
|
202 |
+
|
203 |
+
self.model_outputs = [
|
204 |
+
None,
|
205 |
+
] * self.config.solver_order
|
206 |
+
self.lower_order_nums = 0
|
207 |
+
self.last_sample = None
|
208 |
+
if self.solver_p:
|
209 |
+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
210 |
+
|
211 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
212 |
+
self._step_index = None
|
213 |
+
self._begin_index = None
|
214 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
215 |
+
|
216 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
217 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
218 |
+
"""
|
219 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
220 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
221 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
222 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
223 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
224 |
+
|
225 |
+
https://arxiv.org/abs/2205.11487
|
226 |
+
"""
|
227 |
+
dtype = sample.dtype
|
228 |
+
batch_size, channels, *remaining_dims = sample.shape
|
229 |
+
|
230 |
+
if dtype not in (torch.float32, torch.float64):
|
231 |
+
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
232 |
+
|
233 |
+
# Flatten sample for doing quantile calculation along each image
|
234 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
235 |
+
|
236 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
237 |
+
|
238 |
+
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
239 |
+
s = torch.clamp(
|
240 |
+
s, min=1, max=self.config.sample_max_value
|
241 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
242 |
+
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
243 |
+
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
244 |
+
|
245 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
246 |
+
sample = sample.to(dtype)
|
247 |
+
|
248 |
+
return sample
|
249 |
+
|
250 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
251 |
+
def _sigma_to_t(self, sigma):
|
252 |
+
return sigma * self.config.num_train_timesteps
|
253 |
+
|
254 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
255 |
+
return 1 - sigma, sigma
|
256 |
+
|
257 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
|
258 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
259 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
260 |
+
|
261 |
+
def convert_model_output(
|
262 |
+
self,
|
263 |
+
model_output: torch.Tensor,
|
264 |
+
*args,
|
265 |
+
sample: torch.Tensor = None,
|
266 |
+
**kwargs,
|
267 |
+
) -> torch.Tensor:
|
268 |
+
r"""
|
269 |
+
Convert the model output to the corresponding type the UniPC algorithm needs.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
model_output (`torch.Tensor`):
|
273 |
+
The direct output from the learned diffusion model.
|
274 |
+
timestep (`int`):
|
275 |
+
The current discrete timestep in the diffusion chain.
|
276 |
+
sample (`torch.Tensor`):
|
277 |
+
A current instance of a sample created by the diffusion process.
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
`torch.Tensor`:
|
281 |
+
The converted model output.
|
282 |
+
"""
|
283 |
+
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
284 |
+
if sample is None:
|
285 |
+
if len(args) > 1:
|
286 |
+
sample = args[1]
|
287 |
+
else:
|
288 |
+
raise ValueError("missing `sample` as a required keyward argument")
|
289 |
+
if timestep is not None:
|
290 |
+
deprecate(
|
291 |
+
"timesteps",
|
292 |
+
"1.0.0",
|
293 |
+
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
294 |
+
)
|
295 |
+
|
296 |
+
sigma = self.sigmas[self.step_index]
|
297 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
298 |
+
|
299 |
+
if self.predict_x0:
|
300 |
+
if self.config.prediction_type == "flow_prediction":
|
301 |
+
sigma_t = self.sigmas[self.step_index]
|
302 |
+
x0_pred = sample - sigma_t * model_output
|
303 |
+
else:
|
304 |
+
raise ValueError(
|
305 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
306 |
+
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
307 |
+
)
|
308 |
+
|
309 |
+
if self.config.thresholding:
|
310 |
+
x0_pred = self._threshold_sample(x0_pred)
|
311 |
+
|
312 |
+
return x0_pred
|
313 |
+
else:
|
314 |
+
if self.config.prediction_type == "flow_prediction":
|
315 |
+
sigma_t = self.sigmas[self.step_index]
|
316 |
+
epsilon = sample - (1 - sigma_t) * model_output
|
317 |
+
else:
|
318 |
+
raise ValueError(
|
319 |
+
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
|
320 |
+
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
|
321 |
+
)
|
322 |
+
|
323 |
+
if self.config.thresholding:
|
324 |
+
sigma_t = self.sigmas[self.step_index]
|
325 |
+
x0_pred = sample - sigma_t * model_output
|
326 |
+
x0_pred = self._threshold_sample(x0_pred)
|
327 |
+
epsilon = model_output + x0_pred
|
328 |
+
|
329 |
+
return epsilon
|
330 |
+
|
331 |
+
def multistep_uni_p_bh_update(
|
332 |
+
self,
|
333 |
+
model_output: torch.Tensor,
|
334 |
+
*args,
|
335 |
+
sample: torch.Tensor = None,
|
336 |
+
order: int = None, # pyright: ignore
|
337 |
+
**kwargs,
|
338 |
+
) -> torch.Tensor:
|
339 |
+
"""
|
340 |
+
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
|
341 |
+
|
342 |
+
Args:
|
343 |
+
model_output (`torch.Tensor`):
|
344 |
+
The direct output from the learned diffusion model at the current timestep.
|
345 |
+
prev_timestep (`int`):
|
346 |
+
The previous discrete timestep in the diffusion chain.
|
347 |
+
sample (`torch.Tensor`):
|
348 |
+
A current instance of a sample created by the diffusion process.
|
349 |
+
order (`int`):
|
350 |
+
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
`torch.Tensor`:
|
354 |
+
The sample tensor at the previous timestep.
|
355 |
+
"""
|
356 |
+
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
|
357 |
+
if sample is None:
|
358 |
+
if len(args) > 1:
|
359 |
+
sample = args[1]
|
360 |
+
else:
|
361 |
+
raise ValueError(" missing `sample` as a required keyward argument")
|
362 |
+
if order is None:
|
363 |
+
if len(args) > 2:
|
364 |
+
order = args[2]
|
365 |
+
else:
|
366 |
+
raise ValueError(" missing `order` as a required keyward argument")
|
367 |
+
if prev_timestep is not None:
|
368 |
+
deprecate(
|
369 |
+
"prev_timestep",
|
370 |
+
"1.0.0",
|
371 |
+
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
372 |
+
)
|
373 |
+
model_output_list = self.model_outputs
|
374 |
+
|
375 |
+
s0 = self.timestep_list[-1]
|
376 |
+
m0 = model_output_list[-1]
|
377 |
+
x = sample
|
378 |
+
|
379 |
+
if self.solver_p:
|
380 |
+
x_t = self.solver_p.step(model_output, s0, x).prev_sample
|
381 |
+
return x_t
|
382 |
+
|
383 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore
|
384 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
385 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
386 |
+
|
387 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
388 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
389 |
+
|
390 |
+
h = lambda_t - lambda_s0
|
391 |
+
device = sample.device
|
392 |
+
|
393 |
+
rks = []
|
394 |
+
D1s = []
|
395 |
+
for i in range(1, order):
|
396 |
+
si = self.step_index - i # pyright: ignore
|
397 |
+
mi = model_output_list[-(i + 1)]
|
398 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
399 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
400 |
+
rk = (lambda_si - lambda_s0) / h
|
401 |
+
rks.append(rk)
|
402 |
+
D1s.append((mi - m0) / rk) # pyright: ignore
|
403 |
+
|
404 |
+
rks.append(1.0)
|
405 |
+
rks = torch.tensor(rks, device=device)
|
406 |
+
|
407 |
+
R = []
|
408 |
+
b = []
|
409 |
+
|
410 |
+
hh = -h if self.predict_x0 else h
|
411 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
412 |
+
h_phi_k = h_phi_1 / hh - 1
|
413 |
+
|
414 |
+
factorial_i = 1
|
415 |
+
|
416 |
+
if self.config.solver_type == "bh1":
|
417 |
+
B_h = hh
|
418 |
+
elif self.config.solver_type == "bh2":
|
419 |
+
B_h = torch.expm1(hh)
|
420 |
+
else:
|
421 |
+
raise NotImplementedError()
|
422 |
+
|
423 |
+
for i in range(1, order + 1):
|
424 |
+
R.append(torch.pow(rks, i - 1))
|
425 |
+
b.append(h_phi_k * factorial_i / B_h)
|
426 |
+
factorial_i *= i + 1
|
427 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
428 |
+
|
429 |
+
R = torch.stack(R)
|
430 |
+
b = torch.tensor(b, device=device)
|
431 |
+
|
432 |
+
if len(D1s) > 0:
|
433 |
+
D1s = torch.stack(D1s, dim=1) # (B, K)
|
434 |
+
# for order 2, we use a simplified version
|
435 |
+
if order == 2:
|
436 |
+
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
437 |
+
else:
|
438 |
+
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
439 |
+
else:
|
440 |
+
D1s = None
|
441 |
+
|
442 |
+
if self.predict_x0:
|
443 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
444 |
+
if D1s is not None:
|
445 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
446 |
+
else:
|
447 |
+
pred_res = 0
|
448 |
+
x_t = x_t_ - alpha_t * B_h * pred_res
|
449 |
+
else:
|
450 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
451 |
+
if D1s is not None:
|
452 |
+
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore
|
453 |
+
else:
|
454 |
+
pred_res = 0
|
455 |
+
x_t = x_t_ - sigma_t * B_h * pred_res
|
456 |
+
|
457 |
+
x_t = x_t.to(x.dtype)
|
458 |
+
return x_t
|
459 |
+
|
460 |
+
def multistep_uni_c_bh_update(
|
461 |
+
self,
|
462 |
+
this_model_output: torch.Tensor,
|
463 |
+
*args,
|
464 |
+
last_sample: torch.Tensor = None,
|
465 |
+
this_sample: torch.Tensor = None,
|
466 |
+
order: int = None, # pyright: ignore
|
467 |
+
**kwargs,
|
468 |
+
) -> torch.Tensor:
|
469 |
+
"""
|
470 |
+
One step for the UniC (B(h) version).
|
471 |
+
|
472 |
+
Args:
|
473 |
+
this_model_output (`torch.Tensor`):
|
474 |
+
The model outputs at `x_t`.
|
475 |
+
this_timestep (`int`):
|
476 |
+
The current timestep `t`.
|
477 |
+
last_sample (`torch.Tensor`):
|
478 |
+
The generated sample before the last predictor `x_{t-1}`.
|
479 |
+
this_sample (`torch.Tensor`):
|
480 |
+
The generated sample after the last predictor `x_{t}`.
|
481 |
+
order (`int`):
|
482 |
+
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
|
483 |
+
|
484 |
+
Returns:
|
485 |
+
`torch.Tensor`:
|
486 |
+
The corrected sample tensor at the current timestep.
|
487 |
+
"""
|
488 |
+
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
|
489 |
+
if last_sample is None:
|
490 |
+
if len(args) > 1:
|
491 |
+
last_sample = args[1]
|
492 |
+
else:
|
493 |
+
raise ValueError(" missing`last_sample` as a required keyward argument")
|
494 |
+
if this_sample is None:
|
495 |
+
if len(args) > 2:
|
496 |
+
this_sample = args[2]
|
497 |
+
else:
|
498 |
+
raise ValueError(" missing`this_sample` as a required keyward argument")
|
499 |
+
if order is None:
|
500 |
+
if len(args) > 3:
|
501 |
+
order = args[3]
|
502 |
+
else:
|
503 |
+
raise ValueError(" missing`order` as a required keyward argument")
|
504 |
+
if this_timestep is not None:
|
505 |
+
deprecate(
|
506 |
+
"this_timestep",
|
507 |
+
"1.0.0",
|
508 |
+
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
509 |
+
)
|
510 |
+
|
511 |
+
model_output_list = self.model_outputs
|
512 |
+
|
513 |
+
m0 = model_output_list[-1]
|
514 |
+
x = last_sample
|
515 |
+
x_t = this_sample
|
516 |
+
model_t = this_model_output
|
517 |
+
|
518 |
+
sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore
|
519 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
520 |
+
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
|
521 |
+
|
522 |
+
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
523 |
+
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
|
524 |
+
|
525 |
+
h = lambda_t - lambda_s0
|
526 |
+
device = this_sample.device
|
527 |
+
|
528 |
+
rks = []
|
529 |
+
D1s = []
|
530 |
+
for i in range(1, order):
|
531 |
+
si = self.step_index - (i + 1) # pyright: ignore
|
532 |
+
mi = model_output_list[-(i + 1)]
|
533 |
+
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
|
534 |
+
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
|
535 |
+
rk = (lambda_si - lambda_s0) / h
|
536 |
+
rks.append(rk)
|
537 |
+
D1s.append((mi - m0) / rk) # pyright: ignore
|
538 |
+
|
539 |
+
rks.append(1.0)
|
540 |
+
rks = torch.tensor(rks, device=device)
|
541 |
+
|
542 |
+
R = []
|
543 |
+
b = []
|
544 |
+
|
545 |
+
hh = -h if self.predict_x0 else h
|
546 |
+
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
|
547 |
+
h_phi_k = h_phi_1 / hh - 1
|
548 |
+
|
549 |
+
factorial_i = 1
|
550 |
+
|
551 |
+
if self.config.solver_type == "bh1":
|
552 |
+
B_h = hh
|
553 |
+
elif self.config.solver_type == "bh2":
|
554 |
+
B_h = torch.expm1(hh)
|
555 |
+
else:
|
556 |
+
raise NotImplementedError()
|
557 |
+
|
558 |
+
for i in range(1, order + 1):
|
559 |
+
R.append(torch.pow(rks, i - 1))
|
560 |
+
b.append(h_phi_k * factorial_i / B_h)
|
561 |
+
factorial_i *= i + 1
|
562 |
+
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
563 |
+
|
564 |
+
R = torch.stack(R)
|
565 |
+
b = torch.tensor(b, device=device)
|
566 |
+
|
567 |
+
if len(D1s) > 0:
|
568 |
+
D1s = torch.stack(D1s, dim=1)
|
569 |
+
else:
|
570 |
+
D1s = None
|
571 |
+
|
572 |
+
# for order 1, we use a simplified version
|
573 |
+
if order == 1:
|
574 |
+
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
575 |
+
else:
|
576 |
+
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
577 |
+
|
578 |
+
if self.predict_x0:
|
579 |
+
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
580 |
+
if D1s is not None:
|
581 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
582 |
+
else:
|
583 |
+
corr_res = 0
|
584 |
+
D1_t = model_t - m0
|
585 |
+
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
586 |
+
else:
|
587 |
+
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
588 |
+
if D1s is not None:
|
589 |
+
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
590 |
+
else:
|
591 |
+
corr_res = 0
|
592 |
+
D1_t = model_t - m0
|
593 |
+
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
|
594 |
+
x_t = x_t.to(x.dtype)
|
595 |
+
return x_t
|
596 |
+
|
597 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
598 |
+
if schedule_timesteps is None:
|
599 |
+
schedule_timesteps = self.timesteps
|
600 |
+
|
601 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
602 |
+
|
603 |
+
# The sigma index that is taken for the **very** first `step`
|
604 |
+
# is always the second index (or the last index if there is only 1)
|
605 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
606 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
607 |
+
pos = 1 if len(indices) > 1 else 0
|
608 |
+
|
609 |
+
return indices[pos].item()
|
610 |
+
|
611 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
612 |
+
def _init_step_index(self, timestep):
|
613 |
+
"""
|
614 |
+
Initialize the step_index counter for the scheduler.
|
615 |
+
"""
|
616 |
+
|
617 |
+
if self.begin_index is None:
|
618 |
+
if isinstance(timestep, torch.Tensor):
|
619 |
+
timestep = timestep.to(self.timesteps.device)
|
620 |
+
self._step_index = self.index_for_timestep(timestep)
|
621 |
+
else:
|
622 |
+
self._step_index = self._begin_index
|
623 |
+
|
624 |
+
def step(
|
625 |
+
self,
|
626 |
+
model_output: torch.Tensor,
|
627 |
+
timestep: Union[int, torch.Tensor],
|
628 |
+
sample: torch.Tensor,
|
629 |
+
return_dict: bool = True,
|
630 |
+
generator=None,
|
631 |
+
) -> Union[SchedulerOutput, Tuple]:
|
632 |
+
"""
|
633 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
|
634 |
+
the multistep UniPC.
|
635 |
+
|
636 |
+
Args:
|
637 |
+
model_output (`torch.Tensor`):
|
638 |
+
The direct output from learned diffusion model.
|
639 |
+
timestep (`int`):
|
640 |
+
The current discrete timestep in the diffusion chain.
|
641 |
+
sample (`torch.Tensor`):
|
642 |
+
A current instance of a sample created by the diffusion process.
|
643 |
+
return_dict (`bool`):
|
644 |
+
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
645 |
+
|
646 |
+
Returns:
|
647 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
648 |
+
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
649 |
+
tuple is returned where the first element is the sample tensor.
|
650 |
+
|
651 |
+
"""
|
652 |
+
if self.num_inference_steps is None:
|
653 |
+
raise ValueError(
|
654 |
+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
655 |
+
)
|
656 |
+
|
657 |
+
if self.step_index is None:
|
658 |
+
self._init_step_index(timestep)
|
659 |
+
|
660 |
+
use_corrector = (
|
661 |
+
self.step_index > 0
|
662 |
+
and self.step_index - 1 not in self.disable_corrector
|
663 |
+
and self.last_sample is not None # pyright: ignore
|
664 |
+
)
|
665 |
+
|
666 |
+
model_output_convert = self.convert_model_output(model_output, sample=sample)
|
667 |
+
if use_corrector:
|
668 |
+
sample = self.multistep_uni_c_bh_update(
|
669 |
+
this_model_output=model_output_convert,
|
670 |
+
last_sample=self.last_sample,
|
671 |
+
this_sample=sample,
|
672 |
+
order=self.this_order,
|
673 |
+
)
|
674 |
+
|
675 |
+
for i in range(self.config.solver_order - 1):
|
676 |
+
self.model_outputs[i] = self.model_outputs[i + 1]
|
677 |
+
self.timestep_list[i] = self.timestep_list[i + 1]
|
678 |
+
|
679 |
+
self.model_outputs[-1] = model_output_convert
|
680 |
+
self.timestep_list[-1] = timestep # pyright: ignore
|
681 |
+
|
682 |
+
if self.config.lower_order_final:
|
683 |
+
this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore
|
684 |
+
else:
|
685 |
+
this_order = self.config.solver_order
|
686 |
+
|
687 |
+
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
|
688 |
+
assert self.this_order > 0
|
689 |
+
|
690 |
+
self.last_sample = sample
|
691 |
+
prev_sample = self.multistep_uni_p_bh_update(
|
692 |
+
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
|
693 |
+
sample=sample,
|
694 |
+
order=self.this_order,
|
695 |
+
)
|
696 |
+
|
697 |
+
if self.lower_order_nums < self.config.solver_order:
|
698 |
+
self.lower_order_nums += 1
|
699 |
+
|
700 |
+
# upon completion increase step index by one
|
701 |
+
self._step_index += 1 # pyright: ignore
|
702 |
+
|
703 |
+
if not return_dict:
|
704 |
+
return (prev_sample,)
|
705 |
+
|
706 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
707 |
+
|
708 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
709 |
+
"""
|
710 |
+
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
711 |
+
current timestep.
|
712 |
+
|
713 |
+
Args:
|
714 |
+
sample (`torch.Tensor`):
|
715 |
+
The input sample.
|
716 |
+
|
717 |
+
Returns:
|
718 |
+
`torch.Tensor`:
|
719 |
+
A scaled input sample.
|
720 |
+
"""
|
721 |
+
return sample
|
722 |
+
|
723 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
724 |
+
def add_noise(
|
725 |
+
self,
|
726 |
+
original_samples: torch.Tensor,
|
727 |
+
noise: torch.Tensor,
|
728 |
+
timesteps: torch.IntTensor,
|
729 |
+
) -> torch.Tensor:
|
730 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
731 |
+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
732 |
+
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
733 |
+
# mps does not support float64
|
734 |
+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
735 |
+
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
736 |
+
else:
|
737 |
+
schedule_timesteps = self.timesteps.to(original_samples.device)
|
738 |
+
timesteps = timesteps.to(original_samples.device)
|
739 |
+
|
740 |
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
741 |
+
if self.begin_index is None:
|
742 |
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
743 |
+
elif self.step_index is not None:
|
744 |
+
# add_noise is called after first denoising step (for inpainting)
|
745 |
+
step_indices = [self.step_index] * timesteps.shape[0]
|
746 |
+
else:
|
747 |
+
# add noise is called before first denoising step to create initial latent(img2img)
|
748 |
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
749 |
+
|
750 |
+
sigma = sigmas[step_indices].flatten()
|
751 |
+
while len(sigma.shape) < len(original_samples.shape):
|
752 |
+
sigma = sigma.unsqueeze(-1)
|
753 |
+
|
754 |
+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
|
755 |
+
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
756 |
+
return noisy_samples
|
757 |
+
|
758 |
+
def __len__(self):
|
759 |
+
return self.config.num_train_timesteps
|