openfree commited on
Commit
22a1a78
·
verified ·
1 Parent(s): b491ce3

Upload 9 files

Browse files
diffusers_helper/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # diffusers_helper package
diffusers_helper/bucket_tools.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ bucket_options = {
2
+ 640: [
3
+ (416, 960),
4
+ (448, 864),
5
+ (480, 832),
6
+ (512, 768),
7
+ (544, 704),
8
+ (576, 672),
9
+ (608, 640),
10
+ (640, 608),
11
+ (672, 576),
12
+ (704, 544),
13
+ (768, 512),
14
+ (832, 480),
15
+ (864, 448),
16
+ (960, 416),
17
+ ],
18
+ }
19
+
20
+
21
+ def find_nearest_bucket(h, w, resolution=640):
22
+ min_metric = float('inf')
23
+ best_bucket = None
24
+ for (bucket_h, bucket_w) in bucket_options[resolution]:
25
+ metric = abs(h * bucket_w - w * bucket_h)
26
+ if metric <= min_metric:
27
+ min_metric = metric
28
+ best_bucket = (bucket_h, bucket_w)
29
+ return best_bucket
30
+
diffusers_helper/clip_vision.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
5
+ assert isinstance(image, np.ndarray)
6
+ assert image.ndim == 3 and image.shape[2] == 3
7
+ assert image.dtype == np.uint8
8
+
9
+ preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
10
+ image_encoder_output = image_encoder(**preprocessed)
11
+
12
+ return image_encoder_output
diffusers_helper/dit_common.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import accelerate.accelerator
3
+
4
+ from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
5
+
6
+
7
+ accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
8
+
9
+
10
+ def LayerNorm_forward(self, x):
11
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
12
+
13
+
14
+ LayerNorm.forward = LayerNorm_forward
15
+ torch.nn.LayerNorm.forward = LayerNorm_forward
16
+
17
+
18
+ def FP32LayerNorm_forward(self, x):
19
+ origin_dtype = x.dtype
20
+ return torch.nn.functional.layer_norm(
21
+ x.float(),
22
+ self.normalized_shape,
23
+ self.weight.float() if self.weight is not None else None,
24
+ self.bias.float() if self.bias is not None else None,
25
+ self.eps,
26
+ ).to(origin_dtype)
27
+
28
+
29
+ FP32LayerNorm.forward = FP32LayerNorm_forward
30
+
31
+
32
+ def RMSNorm_forward(self, hidden_states):
33
+ input_dtype = hidden_states.dtype
34
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
35
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
36
+
37
+ if self.weight is None:
38
+ return hidden_states.to(input_dtype)
39
+
40
+ return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
41
+
42
+
43
+ RMSNorm.forward = RMSNorm_forward
44
+
45
+
46
+ def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
47
+ emb = self.linear(self.silu(conditioning_embedding))
48
+ scale, shift = emb.chunk(2, dim=1)
49
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
50
+ return x
51
+
52
+
53
+ AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
diffusers_helper/hf_login.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import login
3
+
4
+ def login():
5
+ # 如果是在Hugging Face Space环境中运行,使用环境变量中的token
6
+ if os.environ.get('SPACE_ID') is not None:
7
+ print("Running in Hugging Face Space, using environment HF_TOKEN")
8
+ # Space自带访问权限,无需额外登录
9
+ return
10
+
11
+ # 如果本地环境有token,则使用它登录
12
+ hf_token = os.environ.get('HF_TOKEN')
13
+ if hf_token:
14
+ print("Logging in with HF_TOKEN from environment")
15
+ login(token=hf_token)
16
+ return
17
+
18
+ # 检查缓存的token
19
+ cache_file = os.path.expanduser('~/.huggingface/token')
20
+ if os.path.exists(cache_file):
21
+ print("Found cached Hugging Face token")
22
+ return
23
+
24
+ print("No Hugging Face token found. Using public access.")
25
+ # 无token时使用公共访问,速度可能较慢且有限制
diffusers_helper/hunyuan.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
4
+ from diffusers_helper.utils import crop_or_pad_yield_mask
5
+
6
+
7
+ @torch.no_grad()
8
+ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
9
+ assert isinstance(prompt, str)
10
+
11
+ prompt = [prompt]
12
+
13
+ # LLAMA
14
+
15
+ prompt_llama = [DEFAULT_PROMPT_TEMPLATE["template"].format(p) for p in prompt]
16
+ crop_start = DEFAULT_PROMPT_TEMPLATE["crop_start"]
17
+
18
+ llama_inputs = tokenizer(
19
+ prompt_llama,
20
+ padding="max_length",
21
+ max_length=max_length + crop_start,
22
+ truncation=True,
23
+ return_tensors="pt",
24
+ return_length=False,
25
+ return_overflowing_tokens=False,
26
+ return_attention_mask=True,
27
+ )
28
+
29
+ llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
30
+ llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
31
+ llama_attention_length = int(llama_attention_mask.sum())
32
+
33
+ llama_outputs = text_encoder(
34
+ input_ids=llama_input_ids,
35
+ attention_mask=llama_attention_mask,
36
+ output_hidden_states=True,
37
+ )
38
+
39
+ llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
40
+ # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
41
+ llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
42
+
43
+ assert torch.all(llama_attention_mask.bool())
44
+
45
+ # CLIP
46
+
47
+ clip_l_input_ids = tokenizer_2(
48
+ prompt,
49
+ padding="max_length",
50
+ max_length=77,
51
+ truncation=True,
52
+ return_overflowing_tokens=False,
53
+ return_length=False,
54
+ return_tensors="pt",
55
+ ).input_ids
56
+ clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
57
+
58
+ return llama_vec, clip_l_pooler
59
+
60
+
61
+ @torch.no_grad()
62
+ def vae_decode_fake(latents):
63
+ latent_rgb_factors = [
64
+ [-0.0395, -0.0331, 0.0445],
65
+ [0.0696, 0.0795, 0.0518],
66
+ [0.0135, -0.0945, -0.0282],
67
+ [0.0108, -0.0250, -0.0765],
68
+ [-0.0209, 0.0032, 0.0224],
69
+ [-0.0804, -0.0254, -0.0639],
70
+ [-0.0991, 0.0271, -0.0669],
71
+ [-0.0646, -0.0422, -0.0400],
72
+ [-0.0696, -0.0595, -0.0894],
73
+ [-0.0799, -0.0208, -0.0375],
74
+ [0.1166, 0.1627, 0.0962],
75
+ [0.1165, 0.0432, 0.0407],
76
+ [-0.2315, -0.1920, -0.1355],
77
+ [-0.0270, 0.0401, -0.0821],
78
+ [-0.0616, -0.0997, -0.0727],
79
+ [0.0249, -0.0469, -0.1703]
80
+ ] # From comfyui
81
+
82
+ latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
83
+
84
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
85
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
86
+
87
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
88
+ images = images.clamp(0.0, 1.0)
89
+
90
+ return images
91
+
92
+
93
+ @torch.no_grad()
94
+ def vae_decode(latents, vae, image_mode=False):
95
+ latents = latents / vae.config.scaling_factor
96
+
97
+ if not image_mode:
98
+ image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
99
+ else:
100
+ latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
101
+ image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
102
+ image = torch.cat(image, dim=2)
103
+
104
+ return image
105
+
106
+
107
+ @torch.no_grad()
108
+ def vae_encode(image, vae):
109
+ latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
110
+ latents = latents * vae.config.scaling_factor
111
+ return latents
diffusers_helper/memory.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # By lllyasviel
2
+
3
+
4
+ import torch
5
+ import os
6
+
7
+ # 检查是否在Hugging Face Space环境中
8
+ IN_HF_SPACE = os.environ.get('SPACE_ID') is not None
9
+
10
+ # 设置CPU设备
11
+ cpu = torch.device('cpu')
12
+
13
+ # 在Stateless GPU环境中,不要在主进程初始化CUDA
14
+ def get_gpu_device():
15
+ if IN_HF_SPACE:
16
+ # 在Spaces中将延迟初始化GPU设备
17
+ return 'cuda' # 返回字符串,而不是实际初始化设备
18
+
19
+ # 非Spaces环境正常初始化
20
+ try:
21
+ if torch.cuda.is_available():
22
+ return torch.device(f'cuda:{torch.cuda.current_device()}')
23
+ else:
24
+ print("CUDA不可用,使用CPU作为默认设备")
25
+ return torch.device('cpu')
26
+ except Exception as e:
27
+ print(f"初始化CUDA设备时出错: {e}")
28
+ print("回退到CPU设备")
29
+ return torch.device('cpu')
30
+
31
+ # 保存一个字符串表示,而不是实际的设备对象
32
+ gpu = get_gpu_device()
33
+
34
+ gpu_complete_modules = []
35
+
36
+
37
+ class DynamicSwapInstaller:
38
+ @staticmethod
39
+ def _install_module(module: torch.nn.Module, **kwargs):
40
+ original_class = module.__class__
41
+ module.__dict__['forge_backup_original_class'] = original_class
42
+
43
+ def hacked_get_attr(self, name: str):
44
+ if '_parameters' in self.__dict__:
45
+ _parameters = self.__dict__['_parameters']
46
+ if name in _parameters:
47
+ p = _parameters[name]
48
+ if p is None:
49
+ return None
50
+ if p.__class__ == torch.nn.Parameter:
51
+ return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
52
+ else:
53
+ return p.to(**kwargs)
54
+ if '_buffers' in self.__dict__:
55
+ _buffers = self.__dict__['_buffers']
56
+ if name in _buffers:
57
+ return _buffers[name].to(**kwargs)
58
+ return super(original_class, self).__getattr__(name)
59
+
60
+ module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
61
+ '__getattr__': hacked_get_attr,
62
+ })
63
+
64
+ return
65
+
66
+ @staticmethod
67
+ def _uninstall_module(module: torch.nn.Module):
68
+ if 'forge_backup_original_class' in module.__dict__:
69
+ module.__class__ = module.__dict__.pop('forge_backup_original_class')
70
+ return
71
+
72
+ @staticmethod
73
+ def install_model(model: torch.nn.Module, **kwargs):
74
+ for m in model.modules():
75
+ DynamicSwapInstaller._install_module(m, **kwargs)
76
+ return
77
+
78
+ @staticmethod
79
+ def uninstall_model(model: torch.nn.Module):
80
+ for m in model.modules():
81
+ DynamicSwapInstaller._uninstall_module(m)
82
+ return
83
+
84
+
85
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device):
86
+ # 转换字符串设备为torch.device
87
+ if isinstance(target_device, str):
88
+ target_device = torch.device(target_device)
89
+
90
+ if hasattr(model, 'scale_shift_table'):
91
+ model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
92
+ return
93
+
94
+ for k, p in model.named_modules():
95
+ if hasattr(p, 'weight'):
96
+ p.to(target_device)
97
+ return
98
+
99
+
100
+ def get_cuda_free_memory_gb(device=None):
101
+ if device is None:
102
+ device = gpu
103
+
104
+ # 如果是字符串,转换为设备
105
+ if isinstance(device, str):
106
+ device = torch.device(device)
107
+
108
+ # 如果不是CUDA设备,返回默认值
109
+ if device.type != 'cuda':
110
+ print("无法获取非CUDA设备的内存信息,返回默认值")
111
+ return 6.0 # 返回一个默认值
112
+
113
+ try:
114
+ memory_stats = torch.cuda.memory_stats(device)
115
+ bytes_active = memory_stats['active_bytes.all.current']
116
+ bytes_reserved = memory_stats['reserved_bytes.all.current']
117
+ bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
118
+ bytes_inactive_reserved = bytes_reserved - bytes_active
119
+ bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
120
+ return bytes_total_available / (1024 ** 3)
121
+ except Exception as e:
122
+ print(f"获取CUDA内存信息时出错: {e}")
123
+ return 6.0 # 返回一个默认值
124
+
125
+
126
+ def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
127
+ print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
128
+
129
+ # 如果是字符串,转换为设备
130
+ if isinstance(target_device, str):
131
+ target_device = torch.device(target_device)
132
+
133
+ # 如果gpu是字符串,转换为设备
134
+ gpu_device = gpu
135
+ if isinstance(gpu_device, str):
136
+ gpu_device = torch.device(gpu_device)
137
+
138
+ # 如果目标设备是CPU或当前在CPU上,直接移动
139
+ if target_device.type == 'cpu' or gpu_device.type == 'cpu':
140
+ model.to(device=target_device)
141
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
142
+ return
143
+
144
+ for m in model.modules():
145
+ if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
146
+ torch.cuda.empty_cache()
147
+ return
148
+
149
+ if hasattr(m, 'weight'):
150
+ m.to(device=target_device)
151
+
152
+ model.to(device=target_device)
153
+ torch.cuda.empty_cache()
154
+ return
155
+
156
+
157
+ def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
158
+ print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
159
+
160
+ # 如果是字符串,转换为设备
161
+ if isinstance(target_device, str):
162
+ target_device = torch.device(target_device)
163
+
164
+ # 如果gpu是字符串,转换为设备
165
+ gpu_device = gpu
166
+ if isinstance(gpu_device, str):
167
+ gpu_device = torch.device(gpu_device)
168
+
169
+ # 如果目标设备是CPU或当前在CPU上,直接处理
170
+ if target_device.type == 'cpu' or gpu_device.type == 'cpu':
171
+ model.to(device=cpu)
172
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
173
+ return
174
+
175
+ for m in model.modules():
176
+ if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
177
+ torch.cuda.empty_cache()
178
+ return
179
+
180
+ if hasattr(m, 'weight'):
181
+ m.to(device=cpu)
182
+
183
+ model.to(device=cpu)
184
+ torch.cuda.empty_cache()
185
+ return
186
+
187
+
188
+ def unload_complete_models(*args):
189
+ for m in gpu_complete_modules + list(args):
190
+ m.to(device=cpu)
191
+ print(f'Unloaded {m.__class__.__name__} as complete.')
192
+
193
+ gpu_complete_modules.clear()
194
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
195
+ return
196
+
197
+
198
+ def load_model_as_complete(model, target_device, unload=True):
199
+ # 如果是字符串,转换为设备
200
+ if isinstance(target_device, str):
201
+ target_device = torch.device(target_device)
202
+
203
+ if unload:
204
+ unload_complete_models()
205
+
206
+ model.to(device=target_device)
207
+ print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
208
+
209
+ gpu_complete_modules.append(model)
210
+ return
diffusers_helper/thread_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from threading import Thread, Lock
4
+
5
+
6
+ class Listener:
7
+ task_queue = []
8
+ lock = Lock()
9
+ thread = None
10
+
11
+ @classmethod
12
+ def _process_tasks(cls):
13
+ while True:
14
+ task = None
15
+ with cls.lock:
16
+ if cls.task_queue:
17
+ task = cls.task_queue.pop(0)
18
+
19
+ if task is None:
20
+ time.sleep(0.001)
21
+ continue
22
+
23
+ func, args, kwargs = task
24
+ try:
25
+ func(*args, **kwargs)
26
+ except Exception as e:
27
+ print(f"Error in listener thread: {e}")
28
+
29
+ @classmethod
30
+ def add_task(cls, func, *args, **kwargs):
31
+ with cls.lock:
32
+ cls.task_queue.append((func, args, kwargs))
33
+
34
+ if cls.thread is None:
35
+ cls.thread = Thread(target=cls._process_tasks, daemon=True)
36
+ cls.thread.start()
37
+
38
+
39
+ def async_run(func, *args, **kwargs):
40
+ Listener.add_task(func, *args, **kwargs)
41
+
42
+
43
+ class FIFOQueue:
44
+ def __init__(self):
45
+ self.queue = []
46
+ self.lock = Lock()
47
+ print("【调试】创建新的FIFOQueue")
48
+
49
+ def push(self, item):
50
+ print(f"【调试】FIFOQueue.push: 准备添加项目: {item}")
51
+ with self.lock:
52
+ self.queue.append(item)
53
+ print(f"【调试】FIFOQueue.push: 成功添加项目: {item}, 当前队列长度: {len(self.queue)}")
54
+
55
+ def pop(self):
56
+ print("【调试】FIFOQueue.pop: 准备弹出队列首项")
57
+ with self.lock:
58
+ if self.queue:
59
+ item = self.queue.pop(0)
60
+ print(f"【调试】FIFOQueue.pop: 成功弹出项目: {item}, 剩余队列长度: {len(self.queue)}")
61
+ return item
62
+ print("【调试】FIFOQueue.pop: 队列为空,返回None")
63
+ return None
64
+
65
+ def top(self):
66
+ print("【调试】FIFOQueue.top: 准备查看队列首项")
67
+ with self.lock:
68
+ if self.queue:
69
+ item = self.queue[0]
70
+ print(f"【调试】FIFOQueue.top: 队列首项为: {item}, 当前队列长度: {len(self.queue)}")
71
+ return item
72
+ print("【调试】FIFOQueue.top: 队列为空,返回None")
73
+ return None
74
+
75
+ def next(self):
76
+ print("【调试】FIFOQueue.next: 等待弹出队列首项")
77
+ while True:
78
+ with self.lock:
79
+ if self.queue:
80
+ item = self.queue.pop(0)
81
+ print(f"【调试】FIFOQueue.next: 成功弹出项目: {item}, 剩余队列长度: {len(self.queue)}")
82
+ return item
83
+
84
+ time.sleep(0.001)
85
+
86
+
87
+ class AsyncStream:
88
+ def __init__(self):
89
+ self.input_queue = FIFOQueue()
90
+ self.output_queue = FIFOQueue()
91
+
92
+
93
+ class InterruptibleStreamData:
94
+ def __init__(self):
95
+ self.input_queue = FIFOQueue()
96
+ self.output_queue = FIFOQueue()
97
+ print("【调试】创建新的InterruptibleStreamData,初始化输入输出队列")
98
+
99
+ # 推送数据至输出队列
100
+ def push_output(self, item):
101
+ print(f"【调试】InterruptibleStreamData.push_output: 准备推送输出: {type(item)}")
102
+ self.output_queue.push(item)
103
+ print(f"【调试】InterruptibleStreamData.push_output: 成功推送输出")
104
+
105
+ # 获取下一个输出数据
106
+ def get_output(self):
107
+ print("【调试】InterruptibleStreamData.get_output: 准备获取下一个输出数据")
108
+ item = self.output_queue.next()
109
+ print(f"【调试】InterruptibleStreamData.get_output: 获取到输出数据: {type(item)}")
110
+ return item
111
+
112
+ # 推送数据至输入队列
113
+ def push_input(self, item):
114
+ print(f"【调试】InterruptibleStreamData.push_input: 准备推送输入: {type(item)}")
115
+ self.input_queue.push(item)
116
+ print(f"【调试】InterruptibleStreamData.push_input: 成功推送输入")
117
+
118
+ # 获取下一个输入数据
119
+ def get_input(self):
120
+ print("【调试】InterruptibleStreamData.get_input: 准备获取下一个输入数据")
121
+ item = self.input_queue.next()
122
+ print(f"【调试】InterruptibleStreamData.get_input: 获取到输入数据: {type(item)}")
123
+ return item
diffusers_helper/utils.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
+ def min_resize(x, m):
17
+ if x.shape[0] < x.shape[1]:
18
+ s0 = m
19
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
+ else:
21
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
+ s1 = m
23
+ new_max = max(s1, s0)
24
+ raw_max = max(x.shape[0], x.shape[1])
25
+ if new_max < raw_max:
26
+ interpolation = cv2.INTER_AREA
27
+ else:
28
+ interpolation = cv2.INTER_LANCZOS4
29
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
+ return y
31
+
32
+
33
+ def d_resize(x, y):
34
+ H, W, C = y.shape
35
+ new_min = min(H, W)
36
+ raw_min = min(x.shape[0], x.shape[1])
37
+ if new_min < raw_min:
38
+ interpolation = cv2.INTER_AREA
39
+ else:
40
+ interpolation = cv2.INTER_LANCZOS4
41
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
42
+ return y
43
+
44
+
45
+ def resize_and_center_crop(image, target_width, target_height):
46
+ if target_height == image.shape[0] and target_width == image.shape[1]:
47
+ return image
48
+
49
+ pil_image = Image.fromarray(image)
50
+ original_width, original_height = pil_image.size
51
+ scale_factor = max(target_width / original_width, target_height / original_height)
52
+ resized_width = int(round(original_width * scale_factor))
53
+ resized_height = int(round(original_height * scale_factor))
54
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
+ left = (resized_width - target_width) / 2
56
+ top = (resized_height - target_height) / 2
57
+ right = (resized_width + target_width) / 2
58
+ bottom = (resized_height + target_height) / 2
59
+ cropped_image = resized_image.crop((left, top, right, bottom))
60
+ return np.array(cropped_image)
61
+
62
+
63
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
64
+ B, C, H, W = image.shape
65
+
66
+ if H == target_height and W == target_width:
67
+ return image
68
+
69
+ scale_factor = max(target_width / W, target_height / H)
70
+ resized_width = int(round(W * scale_factor))
71
+ resized_height = int(round(H * scale_factor))
72
+
73
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74
+
75
+ top = (resized_height - target_height) // 2
76
+ left = (resized_width - target_width) // 2
77
+ cropped = resized[:, :, top:top + target_height, left:left + target_width]
78
+
79
+ return cropped
80
+
81
+
82
+ def resize_without_crop(image, target_width, target_height):
83
+ if target_height == image.shape[0] and target_width == image.shape[1]:
84
+ return image
85
+
86
+ pil_image = Image.fromarray(image)
87
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
+ return np.array(resized_image)
89
+
90
+
91
+ def just_crop(image, w, h):
92
+ if h == image.shape[0] and w == image.shape[1]:
93
+ return image
94
+
95
+ original_height, original_width = image.shape[:2]
96
+ k = min(original_height / h, original_width / w)
97
+ new_width = int(round(w * k))
98
+ new_height = int(round(h * k))
99
+ x_start = (original_width - new_width) // 2
100
+ y_start = (original_height - new_height) // 2
101
+ cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102
+ return cropped_image
103
+
104
+
105
+ def write_to_json(data, file_path):
106
+ temp_file_path = file_path + ".tmp"
107
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108
+ json.dump(data, temp_file, indent=4)
109
+ os.replace(temp_file_path, file_path)
110
+ return
111
+
112
+
113
+ def read_from_json(file_path):
114
+ with open(file_path, 'rt', encoding='utf-8') as file:
115
+ data = json.load(file)
116
+ return data
117
+
118
+
119
+ def get_active_parameters(m):
120
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
+
122
+
123
+ def cast_training_params(m, dtype=torch.float32):
124
+ result = {}
125
+ for n, param in m.named_parameters():
126
+ if param.requires_grad:
127
+ param.data = param.to(dtype)
128
+ result[n] = param
129
+ return result
130
+
131
+
132
+ def separate_lora_AB(parameters, B_patterns=None):
133
+ parameters_normal = {}
134
+ parameters_B = {}
135
+
136
+ if B_patterns is None:
137
+ B_patterns = ['.lora_B.', '__zero__']
138
+
139
+ for k, v in parameters.items():
140
+ if any(B_pattern in k for B_pattern in B_patterns):
141
+ parameters_B[k] = v
142
+ else:
143
+ parameters_normal[k] = v
144
+
145
+ return parameters_normal, parameters_B
146
+
147
+
148
+ def set_attr_recursive(obj, attr, value):
149
+ attrs = attr.split(".")
150
+ for name in attrs[:-1]:
151
+ obj = getattr(obj, name)
152
+ setattr(obj, attrs[-1], value)
153
+ return
154
+
155
+
156
+ def print_tensor_list_size(tensors):
157
+ total_size = 0
158
+ total_elements = 0
159
+
160
+ if isinstance(tensors, dict):
161
+ tensors = tensors.values()
162
+
163
+ for tensor in tensors:
164
+ total_size += tensor.nelement() * tensor.element_size()
165
+ total_elements += tensor.nelement()
166
+
167
+ total_size_MB = total_size / (1024 ** 2)
168
+ total_elements_B = total_elements / 1e9
169
+
170
+ print(f"Total number of tensors: {len(tensors)}")
171
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
+ return
174
+
175
+
176
+ @torch.no_grad()
177
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
+ batch_size = a.size(0)
179
+
180
+ if b is None:
181
+ b = torch.zeros_like(a)
182
+
183
+ if mask_a is None:
184
+ mask_a = torch.rand(batch_size) < probability_a
185
+
186
+ mask_a = mask_a.to(a.device)
187
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
+ result = torch.where(mask_a, a, b)
189
+ return result
190
+
191
+
192
+ @torch.no_grad()
193
+ def zero_module(module):
194
+ for p in module.parameters():
195
+ p.detach().zero_()
196
+ return module
197
+
198
+
199
+ @torch.no_grad()
200
+ def supress_lower_channels(m, k, alpha=0.01):
201
+ data = m.weight.data.clone()
202
+
203
+ assert int(data.shape[1]) >= k
204
+
205
+ data[:, :k] = data[:, :k] * alpha
206
+ m.weight.data = data.contiguous().clone()
207
+ return m
208
+
209
+
210
+ def freeze_module(m):
211
+ if not hasattr(m, '_forward_inside_frozen_module'):
212
+ m._forward_inside_frozen_module = m.forward
213
+ m.requires_grad_(False)
214
+ m.forward = torch.no_grad()(m.forward)
215
+ return m
216
+
217
+
218
+ def get_latest_safetensors(folder_path):
219
+ safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220
+
221
+ if not safetensors_files:
222
+ raise ValueError('No file to resume!')
223
+
224
+ latest_file = max(safetensors_files, key=os.path.getmtime)
225
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
226
+ return latest_file
227
+
228
+
229
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
+ tags = tags_str.split(', ')
231
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
+ prompt = ', '.join(tags)
233
+ return prompt
234
+
235
+
236
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
+ if round_to_int:
239
+ numbers = np.round(numbers).astype(int)
240
+ return numbers.tolist()
241
+
242
+
243
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
+ edges = np.linspace(0, 1, n + 1)
245
+ points = np.random.uniform(edges[:-1], edges[1:])
246
+ numbers = inclusive + (exclusive - inclusive) * points
247
+ if round_to_int:
248
+ numbers = np.round(numbers).astype(int)
249
+ return numbers.tolist()
250
+
251
+
252
+ def soft_append_bcthw(history, current, overlap=0):
253
+ if overlap <= 0:
254
+ return torch.cat([history, current], dim=2)
255
+
256
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
+
259
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
+
263
+ return output.to(history)
264
+
265
+
266
+ def save_bcthw_as_mp4(x, output_filename, fps=10):
267
+ b, c, t, h, w = x.shape
268
+
269
+ per_row = b
270
+ for p in [6, 5, 4, 3, 2]:
271
+ if b % p == 0:
272
+ per_row = p
273
+ break
274
+
275
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277
+ x = x.detach().cpu().to(torch.uint8)
278
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': '0'})
280
+ return x
281
+
282
+
283
+ def save_bcthw_as_png(x, output_filename):
284
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286
+ x = x.detach().cpu().to(torch.uint8)
287
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288
+ torchvision.io.write_png(x, output_filename)
289
+ return output_filename
290
+
291
+
292
+ def save_bchw_as_png(x, output_filename):
293
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295
+ x = x.detach().cpu().to(torch.uint8)
296
+ x = einops.rearrange(x, 'b c h w -> c h (b w)')
297
+ torchvision.io.write_png(x, output_filename)
298
+ return output_filename
299
+
300
+
301
+ def add_tensors_with_padding(tensor1, tensor2):
302
+ if tensor1.shape == tensor2.shape:
303
+ return tensor1 + tensor2
304
+
305
+ shape1 = tensor1.shape
306
+ shape2 = tensor2.shape
307
+
308
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309
+
310
+ padded_tensor1 = torch.zeros(new_shape)
311
+ padded_tensor2 = torch.zeros(new_shape)
312
+
313
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315
+
316
+ result = padded_tensor1 + padded_tensor2
317
+ return result
318
+
319
+
320
+ def print_free_mem():
321
+ torch.cuda.empty_cache()
322
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
323
+ free_mem_mb = free_mem / (1024 ** 2)
324
+ total_mem_mb = total_mem / (1024 ** 2)
325
+ print(f"Free memory: {free_mem_mb:.2f} MB")
326
+ print(f"Total memory: {total_mem_mb:.2f} MB")
327
+ return
328
+
329
+
330
+ def print_gpu_parameters(device, state_dict, log_count=1):
331
+ summary = {"device": device, "keys_count": len(state_dict)}
332
+
333
+ logged_params = {}
334
+ for i, (key, tensor) in enumerate(state_dict.items()):
335
+ if i >= log_count:
336
+ break
337
+ logged_params[key] = tensor.flatten()[:3].tolist()
338
+
339
+ summary["params"] = logged_params
340
+
341
+ print(str(summary))
342
+ return
343
+
344
+
345
+ def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346
+ from PIL import Image, ImageDraw, ImageFont
347
+
348
+ txt = Image.new("RGB", (width, height), color="white")
349
+ draw = ImageDraw.Draw(txt)
350
+ font = ImageFont.truetype(font_path, size=size)
351
+
352
+ if text == '':
353
+ return np.array(txt)
354
+
355
+ # Split text into lines that fit within the image width
356
+ lines = []
357
+ words = text.split()
358
+ current_line = words[0]
359
+
360
+ for word in words[1:]:
361
+ line_with_word = f"{current_line} {word}"
362
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363
+ current_line = line_with_word
364
+ else:
365
+ lines.append(current_line)
366
+ current_line = word
367
+
368
+ lines.append(current_line)
369
+
370
+ # Draw the text line by line
371
+ y = 0
372
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
373
+
374
+ for line in lines:
375
+ if y + line_height > height:
376
+ break # stop drawing if the next line will be outside the image
377
+ draw.text((0, y), line, fill="black", font=font)
378
+ y += line_height
379
+
380
+ return np.array(txt)
381
+
382
+
383
+ def blue_mark(x):
384
+ x = x.copy()
385
+ c = x[:, :, 2]
386
+ b = cv2.blur(c, (9, 9))
387
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388
+ return x
389
+
390
+
391
+ def green_mark(x):
392
+ x = x.copy()
393
+ x[:, :, 2] = -1
394
+ x[:, :, 0] = -1
395
+ return x
396
+
397
+
398
+ def frame_mark(x):
399
+ x = x.copy()
400
+ x[:64] = -1
401
+ x[-64:] = -1
402
+ x[:, :8] = 1
403
+ x[:, -8:] = 1
404
+ return x
405
+
406
+
407
+ @torch.inference_mode()
408
+ def pytorch2numpy(imgs):
409
+ results = []
410
+ for x in imgs:
411
+ y = x.movedim(0, -1)
412
+ y = y * 127.5 + 127.5
413
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414
+ results.append(y)
415
+ return results
416
+
417
+
418
+ @torch.inference_mode()
419
+ def numpy2pytorch(imgs):
420
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421
+ h = h.movedim(-1, 1)
422
+ return h
423
+
424
+
425
+ @torch.no_grad()
426
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
427
+ if zero_out:
428
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429
+ else:
430
+ return torch.cat([x, x[:count]], dim=0)
431
+
432
+
433
+ def weighted_mse(a, b, weight):
434
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435
+
436
+
437
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438
+ x = (x - x_min) / (x_max - x_min)
439
+ x = max(0.0, min(x, 1.0))
440
+ x = x ** sigma
441
+ return y_min + x * (y_max - y_min)
442
+
443
+
444
+ def expand_to_dims(x, target_dims):
445
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446
+
447
+
448
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449
+ if tensor is None:
450
+ return None
451
+
452
+ first_dim = tensor.shape[0]
453
+
454
+ if first_dim == batch_size:
455
+ return tensor
456
+
457
+ if batch_size % first_dim != 0:
458
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459
+
460
+ repeat_times = batch_size // first_dim
461
+
462
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463
+
464
+
465
+ def dim5(x):
466
+ return expand_to_dims(x, 5)
467
+
468
+
469
+ def dim4(x):
470
+ return expand_to_dims(x, 4)
471
+
472
+
473
+ def dim3(x):
474
+ return expand_to_dims(x, 3)
475
+
476
+
477
+ def crop_or_pad_yield_mask(x, length):
478
+ B, F, C = x.shape
479
+ device = x.device
480
+ dtype = x.dtype
481
+
482
+ if F < length:
483
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
484
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485
+ y[:, :F, :] = x
486
+ mask[:, :F] = True
487
+ return y, mask
488
+
489
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490
+
491
+
492
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
493
+ original_length = int(x.shape[dim])
494
+
495
+ if original_length >= minimal_length:
496
+ return x
497
+
498
+ if zero_pad:
499
+ padding_shape = list(x.shape)
500
+ padding_shape[dim] = minimal_length - original_length
501
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502
+ else:
503
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504
+ last_element = x[idx]
505
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506
+
507
+ return torch.cat([x, padding], dim=dim)
508
+
509
+
510
+ def lazy_positional_encoding(t, repeats=None):
511
+ if not isinstance(t, list):
512
+ t = [t]
513
+
514
+ from diffusers.models.embeddings import get_timestep_embedding
515
+
516
+ te = torch.tensor(t)
517
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518
+
519
+ if repeats is None:
520
+ return te
521
+
522
+ te = te[:, None, :].expand(-1, repeats, -1)
523
+
524
+ return te
525
+
526
+
527
+ def state_dict_offset_merge(A, B, C=None):
528
+ result = {}
529
+ keys = A.keys()
530
+
531
+ for key in keys:
532
+ A_value = A[key]
533
+ B_value = B[key].to(A_value)
534
+
535
+ if C is None:
536
+ result[key] = A_value + B_value
537
+ else:
538
+ C_value = C[key].to(A_value)
539
+ result[key] = A_value + B_value - C_value
540
+
541
+ return result
542
+
543
+
544
+ def state_dict_weighted_merge(state_dicts, weights):
545
+ if len(state_dicts) != len(weights):
546
+ raise ValueError("Number of state dictionaries must match number of weights")
547
+
548
+ if not state_dicts:
549
+ return {}
550
+
551
+ total_weight = sum(weights)
552
+
553
+ if total_weight == 0:
554
+ raise ValueError("Sum of weights cannot be zero")
555
+
556
+ normalized_weights = [w / total_weight for w in weights]
557
+
558
+ keys = state_dicts[0].keys()
559
+ result = {}
560
+
561
+ for key in keys:
562
+ result[key] = state_dicts[0][key] * normalized_weights[0]
563
+
564
+ for i in range(1, len(state_dicts)):
565
+ state_dict_value = state_dicts[i][key].to(result[key])
566
+ result[key] += state_dict_value * normalized_weights[i]
567
+
568
+ return result
569
+
570
+
571
+ def group_files_by_folder(all_files):
572
+ grouped_files = {}
573
+
574
+ for file in all_files:
575
+ folder_name = os.path.basename(os.path.dirname(file))
576
+ if folder_name not in grouped_files:
577
+ grouped_files[folder_name] = []
578
+ grouped_files[folder_name].append(file)
579
+
580
+ list_of_lists = list(grouped_files.values())
581
+ return list_of_lists
582
+
583
+
584
+ def generate_timestamp():
585
+ now = datetime.datetime.now()
586
+ timestamp = now.strftime('%y%m%d_%H%M%S')
587
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
588
+ random_number = random.randint(0, 9999)
589
+ return f"{timestamp}_{milliseconds}_{random_number}"
590
+
591
+
592
+ def write_PIL_image_with_png_info(image, metadata, path):
593
+ from PIL.PngImagePlugin import PngInfo
594
+
595
+ png_info = PngInfo()
596
+ for key, value in metadata.items():
597
+ png_info.add_text(key, value)
598
+
599
+ image.save(path, "PNG", pnginfo=png_info)
600
+ return image
601
+
602
+
603
+ def torch_safe_save(content, path):
604
+ torch.save(content, path + '_tmp')
605
+ os.replace(path + '_tmp', path)
606
+ return path
607
+
608
+
609
+ def move_optimizer_to_device(optimizer, device):
610
+ for state in optimizer.state.values():
611
+ for k, v in state.items():
612
+ if isinstance(v, torch.Tensor):
613
+ state[k] = v.to(device)