"""Register an attention controller into the UNet of a (video) Stable Diffusion model.

* Builds a customized attention function ``_attention`` whose pre-softmax
  attention scores are routed through the controller.
* Replaces the original attention ``forward`` of each matched module with
  ``forward``, ``spatial_temporal_forward`` or ``fully_frame_forward``
  (selected in ``attention_controlled_forward``).
* Most of ``spatial_temporal_forward`` is copied from
  ``video_diffusion/models/attention.py``.

TODO FIXME: merge redundant code with attention.py
"""
from einops import rearrange
import torch
import torch.nn.functional as F
import math
from diffusers.utils.import_utils import is_xformers_available
import numpy as np

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def register_attention_control(model, controller, text_cond, clip_length, height, width, ddim_inversion):
    """Connect a model with a controller.

    Walks the children of ``model.unet`` whose names contain "down", "up" or
    "mid" and replaces ``forward`` on every ``CrossAttention``,
    ``FullyFrameAttention`` and ``SparseCausalAttention`` module found
    recursively (children named ``attn_temporal`` are not descended into)
    with a closure that routes attention through ``controller``.

    Args:
        model: object exposing ``.unet`` (a module with ``named_children``).
        controller: callable ``controller(attn, is_cross, place_in_unet)``
            whose return value replaces ``attn``. ``None`` installs an
            identity controller. Its ``num_att_layers`` attribute is set to
            the number of patched modules.
        text_cond: text-embedding tensor. It is repeated ``clip_length``
            times along dim 0 and used as the key/value source in patched
            ``CrossAttention`` modules.
        clip_length: number of frames per video clip.
        height, width: only used for the xformers size threshold
            ``(height // 2) * (width // 2)``. Presumably the latent
            resolution — TODO confirm against the caller.
        ddim_inversion: when truthy, ``_sliced_attention`` skips the
            per-slice controller edit. Instead it hands the controller a
            per-frame attention store after the slice loop.
    """

    def attention_controlled_forward(self, place_in_unet, attention_type='cross'):
        """Build a replacement ``forward`` for attention module ``self``.

        ``place_in_unet`` is "down", "mid" or "up" and is forwarded to the
        controller. ``attention_type`` is the module's class name. Any
        value other than the three supported class names yields ``None``.
        """

        def _attention(query, key, value, is_cross, attention_mask=None):
            """Dense attention whose pre-softmax scores are edited by the controller."""
            if self.upcast_attention:
                query = query.float()
                key = key.float()
            attention_scores = torch.baddbmm(
                torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
                query,
                key.transpose(-1, -2),
                beta=0,
                alpha=self.scale,
            )
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            if self.upcast_softmax:
                attention_scores = attention_scores.float()

            # START OF CORE FUNCTION
            # The controller sees the raw scores as (batch, heads, q_len, k_len)
            # and its output is softmaxed below. It runs in inversion mode too.
            attention_probs = controller(reshape_batch_dim_to_temporal_heads(attention_scores), is_cross, place_in_unet)
            attention_probs = reshape_temporal_heads_to_batch_dim(attention_probs)
            # END OF CORE FUNCTION

            attention_probs = attention_probs.softmax(dim=-1)
            # cast back to the original dtype
            attention_probs = attention_probs.to(value.dtype)
            # compute attention output
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = reshape_batch_dim_to_heads(hidden_states)
            return hidden_states

        def reshape_temporal_heads_to_batch_dim(tensor):
            """(b, heads, s, t) -> (b * heads, s, t)."""
            return rearrange(tensor, " b h s t -> (b h) s t ", h=self.heads)

        def reshape_batch_dim_to_temporal_heads(tensor):
            """(b * heads, s, t) -> (b, heads, s, t)."""
            return rearrange(tensor, "(b h) s t -> b h s t", h=self.heads)

        def reshape_heads_to_batch_dim3(tensor):
            """(b1, b2, seq, dim) -> (b1, heads, b2, seq, dim // heads)."""
            batch_size1, batch_size2, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size)
            tensor = tensor.permute(0, 3, 1, 2, 4)
            return tensor

        def reshape_heads_to_batch_dim(tensor):
            """(batch, seq, dim) -> (batch * heads, seq, dim // heads)."""
            batch_size, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
            return tensor

        def reshape_batch_dim_to_heads(tensor):
            """(batch * heads, seq, dim) -> (batch, seq, dim * heads)."""
            batch_size, seq_len, dim = tensor.shape
            head_size = self.heads
            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
            return tensor

        def _memory_efficient_attention_xformers(query, key, value, attention_mask):
            """xformers attention. Bypasses the controller entirely."""
            # TODO attention_mask
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = reshape_batch_dim_to_heads(hidden_states)
            return hidden_states

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
            """Patched forward for ``CrossAttention`` modules.

            The incoming ``encoder_hidden_states`` only decides ``is_cross``.
            Keys/values are always projected from ``text_cond`` repeated per
            frame.
            """
            # e.g. hidden_states: (16, 4096, 320); encoder_hidden_states: (16, 77, 768)
            is_cross = encoder_hidden_states is not None
            # repeat_interleave keeps each prompt's frames contiguous, matching
            # the "(b f)" batch layout. text_cond.repeat(clip_length, 1, 1)
            # would tile the whole batch instead, which is the wrong order.
            text_cond_frames = text_cond.repeat_interleave(clip_length, 0)
            # NOTE(review): this overrides the caller's encoder_hidden_states
            # unconditionally, so a CrossAttention used as self-attention would
            # also attend to text.
            encoder_hidden_states = text_cond_frames

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = self.to_q(hidden_states)
            query = reshape_heads_to_batch_dim(query)

            if self.added_kv_proj_dim is not None:
                key = self.to_k(hidden_states)
                value = self.to_v(hidden_states)
                encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
                encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)

                key = reshape_heads_to_batch_dim(key)
                value = reshape_heads_to_batch_dim(value)
                encoder_hidden_states_key_proj = reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
                encoder_hidden_states_value_proj = reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)

                key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
                value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
            else:
                key = self.to_k(encoder_hidden_states)
                value = self.to_v(encoder_hidden_states)
                key = reshape_heads_to_batch_dim(key)
                value = reshape_heads_to_batch_dim(value)

            if attention_mask is not None:
                if attention_mask.shape[-1] != query.shape[1]:
                    target_length = query.shape[1]
                    attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                    attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

            if self._use_memory_efficient_attention_xformers and query.shape[-2] > ((height // 2) * (width // 2)):
                # for large attention maps (originally: the 64x64 one), use xformers to save memory
                hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask)
                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
                hidden_states = hidden_states.to(query.dtype)
            else:
                hidden_states = _attention(query, key, value, is_cross=is_cross, attention_mask=attention_mask)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)
            return hidden_states

        def spatial_temporal_forward(
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            clip_length: int = None,
            SparseCausalAttention_index=(-1, 'first'),
        ):
            """Patched forward for ``SparseCausalAttention`` modules.

            Most of this is copied from
            ``video_diffusion.models.attention.SparseCausalAttention``, with
            two modifications:

            1. It uses the self-defined ``_attention``, which is controlled
               by the controller.
            2. It removes the dropout to reduce randomness.

            Each frame's query attends to the keys/values of the frames
            selected by ``SparseCausalAttention_index``:

            * an int ``k`` is a relative offset (frame ``i`` uses ``i + k``,
              clipped to the clip);
            * ``'first'``, ``'last'`` and ``'mid'``/``'middle'`` pick a fixed
              frame.

            The selections are concatenated along the token axis.

            Raises:
                NotImplementedError: for added KV projections, an encoder
                    state, or an attention mask.
                ValueError: for an unrecognised string index.
                RuntimeError: if the first query element is NaN.

            FIXME: merge redundant code with attention.py
            """
            if (
                self.added_kv_proj_dim is not None
                or encoder_hidden_states is not None
                or attention_mask is not None
            ):
                raise NotImplementedError

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = self.to_q(hidden_states)
            query = reshape_heads_to_batch_dim(query)

            key = self.to_k(hidden_states)
            value = self.to_v(hidden_states)

            if clip_length is not None:
                key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
                value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)

                # *********************** Start of Spatial-temporal attention **********
                frame_index_list = []
                if len(SparseCausalAttention_index) > 0:
                    for index in SparseCausalAttention_index:
                        if isinstance(index, str):
                            if index == 'first':
                                frame_index = [0] * clip_length
                            elif index == 'last':
                                frame_index = [clip_length - 1] * clip_length
                            elif index in ('mid', 'middle'):
                                frame_index = [(clip_length - 1) // 2] * clip_length
                            else:
                                # previously fell through to an unbound / stale frame_index
                                raise ValueError(
                                    f"unsupported SparseCausalAttention_index {index!r}; "
                                    "expected 'first', 'last', 'mid', 'middle' or an int offset"
                                )
                        else:
                            assert isinstance(index, int), 'relative index must be int'
                            frame_index = torch.arange(clip_length) + index
                            frame_index = frame_index.clip(0, clip_length - 1)
                        frame_index_list.append(frame_index)
                    # (b, f, d, c) per selection -> concatenated along the token axis d
                    key = torch.cat([key[:, frame_index] for frame_index in frame_index_list], dim=2)
                    value = torch.cat([value[:, frame_index] for frame_index in frame_index_list], dim=2)
                # *********************** End of Spatial-temporal attention **********

                key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
                value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)

            key = reshape_heads_to_batch_dim(key)
            value = reshape_heads_to_batch_dim(value)

            # Cheap NaN sentinel on the first element only. Raise instead of
            # exit(): exit() terminated the whole process with status 0.
            if torch.isnan(query.reshape(-1)[0]):
                raise RuntimeError(
                    f"NaN in SparseCausalAttention query ({place_in_unet}): "
                    f"query[:10]={query.reshape(-1)[:10]}, key[:10]={key.reshape(-1)[:10]}"
                )

            # FIXME there should be only one variable to control whether use xformers
            if self._use_memory_efficient_attention_xformers and query.shape[-2] > ((height // 2) * (width // 2)):
                # for large attention maps (originally: the 64x64 one), use xformers to save memory
                hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask)
                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
                hidden_states = hidden_states.to(query.dtype)
            else:
                hidden_states = _attention(query, key, value, attention_mask=attention_mask, is_cross=False)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)
            return hidden_states

        def _sliced_attention(query, key, value, sequence_length, dim, attention_mask):
            """Self-attention computed in slices of ``self._slice_size`` rows of the batch*heads axis.

            When not inverting, the pre-softmax scores of slices with index
            ``i < self.heads`` are passed through the controller as
            ``(slice, 1, q_len, k_len)``. In ``ddim_inversion`` mode, the
            same-frame (block-diagonal) parts of each softmaxed slice are
            gathered into ``attention_store`` of shape
            ``(batch*heads, clip_length, per_frame_len, per_frame_len)``. That
            store is handed to the controller once after the loop, and the
            controller's return value is discarded.

            NOTE(review): rows beyond the last full slice are left as zeros
            when ``batch*heads`` is not divisible by the slice size.
            """
            # query: (bz * heads, t x h x w, org_dim // heads)
            is_cross = False
            batch_size_attention = query.shape[0]  # bz * heads
            hidden_states = torch.zeros(
                (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
            )
            slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
            if ddim_inversion:
                per_frame_len = sequence_length // clip_length
                attention_store = torch.zeros(
                    (batch_size_attention, clip_length, per_frame_len, per_frame_len),
                    device=query.device,
                    dtype=query.dtype,
                )

            for i in range(hidden_states.shape[0] // slice_size):
                start_idx = i * slice_size
                end_idx = (i + 1) * slice_size

                query_slice = query[start_idx:end_idx]
                key_slice = key[start_idx:end_idx]

                if self.upcast_attention:
                    query_slice = query_slice.float()
                    key_slice = key_slice.float()

                attn_slice = torch.baddbmm(
                    torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
                    query_slice,
                    key_slice.transpose(-1, -2),
                    beta=0,
                    alpha=self.scale,
                )

                if attention_mask is not None:
                    attn_slice = attn_slice + attention_mask[start_idx:end_idx]

                if self.upcast_softmax:
                    attn_slice = attn_slice.float()

                # Only the first self.heads slices are edited by the controller
                if i < self.heads and not ddim_inversion:
                    attention_probs = controller(attn_slice.unsqueeze(1), is_cross, place_in_unet)
                    attn_slice = attention_probs.squeeze(1)

                attn_slice = attn_slice.softmax(dim=-1)
                # cast back to the original dtype
                attn_slice = attn_slice.to(value.dtype)

                if ddim_inversion:
                    # attn_slice: (bz, thw, thw). Keep each frame's own (hw x hw)
                    # diagonal block -> (bz, t, hw, hw). This matches
                    # attention_store[start_idx:end_idx] for any slice size.
                    bz, _, thw = attn_slice.shape
                    hw = thw // clip_length
                    per_frame_attention = torch.stack(
                        [attn_slice[:, k * hw:(k + 1) * hw, k * hw:(k + 1) * hw] for k in range(clip_length)],
                        dim=1,
                    )
                    attention_store[start_idx:end_idx] = per_frame_attention

                attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
                hidden_states[start_idx:end_idx] = attn_slice

            if ddim_inversion:
                # attention_store: (bz * heads, t, res, res); only recorded, result unused
                _ = controller(attention_store, is_cross, place_in_unet)

            # reshape hidden_states
            hidden_states = reshape_batch_dim_to_heads(hidden_states)
            return hidden_states

        def fully_frame_forward(hidden_states, encoder_hidden_states=None, attention_mask=None,
                                clip_length=None, inter_frame=False, **kwargs):
            """Patched forward for ``FullyFrameAttention`` modules.

            Step 1 runs joint attention over the tokens of all frames in a
            clip. With ``inter_frame``, only the keys/values of the first and
            last frame are used. This goes through ``_sliced_attention`` or
            xformers.

            Step 2 applies only when ``[height, width]`` is in
            ``kwargs['flatten_res']``. It adds an attention along the
            trajectories in ``kwargs['traj']``, masked by ``kwargs['mask']``.

            The query/key projections are stored on ``self.q`` / ``self.k``
            and replaced by ``self.inject_q`` / ``self.inject_k`` when those
            are set. ``self._slice_size`` is forced to 1.

            Required kwargs: 'height', 'width', 'flatten_res'.
            The flatten branch additionally requires 'old_qk', 'traj' and
            'mask'.
            """
            batch_size, sequence_length, _ = hidden_states.shape
            num_videos = batch_size // clip_length  # hidden_states is laid out as (b f)
            h = kwargs['height']
            w = kwargs['width']

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = self.to_q(hidden_states)  # (bf) x d(hw) x c
            self.q = query
            if self.inject_q is not None:
                query = self.inject_q
            dim = query.shape[-1]
            query_old = query.clone()

            # All frames: (b f) d c -> b (f d) c, then (b * heads, f*d, c // heads)
            query = rearrange(query, "(b f) d c -> b (f d) c", f=clip_length)
            query = reshape_heads_to_batch_dim(query)

            if self.added_kv_proj_dim is not None:
                raise NotImplementedError

            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
            key = self.to_k(encoder_hidden_states)
            self.k = key
            if self.inject_k is not None:
                key = self.inject_k
            key_old = key.clone()
            value = self.to_v(encoder_hidden_states)

            if inter_frame:
                # keep only the first and last frame as key/value sources
                key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]]
                value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]]
                key = rearrange(key, "b f d c -> b (f d) c")
                value = rearrange(value, "b f d c -> b (f d) c")
            else:
                # All frames
                key = rearrange(key, "(b f) d c -> b (f d) c", f=clip_length)
                value = rearrange(value, "(b f) d c -> b (f d) c", f=clip_length)

            key = reshape_heads_to_batch_dim(key)
            value = reshape_heads_to_batch_dim(value)

            if attention_mask is not None:
                if attention_mask.shape[-1] != query.shape[1]:
                    target_length = query.shape[1]
                    attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                    attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

            # NOTE(review): mutates module state on every call; _sliced_attention
            # reads it, so each slice is a single (batch * head) row.
            self._slice_size = 1
            sequence_length_full_frame = query.shape[1]

            # attention, what we cannot get enough of
            if self._use_memory_efficient_attention_xformers and query.shape[-2] > clip_length * (32 ** 2):
                hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask)
                # Some versions of xformers return output in fp32, cast it back to the dtype of the input
                hidden_states = hidden_states.to(query.dtype)
            else:
                hidden_states = _sliced_attention(query, key, value, sequence_length_full_frame, dim, attention_mask)

            # presumably flatten_res is a list of [h, w] pairs; verify against caller
            if [h, w] in kwargs['flatten_res']:
                hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length)
                if self.group_norm is not None:
                    hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                if kwargs["old_qk"] == 1:
                    query = query_old
                    key = key_old
                else:
                    query = hidden_states
                    key = hidden_states
                value = hidden_states

                # traj: (f*n, l, >=3). The last axis holds (frame, row, col)
                # indices into the (f, h, w) token grid.
                traj = rearrange(kwargs["traj"], '(f n) l d -> f n l d', f=clip_length, n=sequence_length)
                mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=clip_length, n=sequence_length)
                # keep trajectory entry 0 plus the last (clip_length - 1) entries
                mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -clip_length + 1:]], dim=-1)

                traj_key_sequence_inds = torch.cat(
                    [traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -clip_length + 1:, :]], dim=-2
                )
                t_inds = traj_key_sequence_inds[:, :, :, 0]
                x_inds = traj_key_sequence_inds[:, :, :, 1]
                y_inds = traj_key_sequence_inds[:, :, :, 2]

                query_tempo = query.unsqueeze(-2)
                _key = rearrange(key, '(b f) (h w) d -> b f h w d', b=num_videos, f=clip_length, h=h, w=w)
                _value = rearrange(value, '(b f) (h w) d -> b f h w d', b=num_videos, f=clip_length, h=h, w=w)
                key_tempo = _key[:, t_inds, x_inds, y_inds]
                value_tempo = _value[:, t_inds, x_inds, y_inds]
                key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d')
                value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d')

                # Tile the mask once per video so it lines up with the (b f)
                # axis of key_tempo. This is identical to the former
                # stack([mask, mask]) when there are 2 videos.
                mask = mask.repeat(num_videos, 1, 1)
                mask = mask[:, None].repeat(1, self.heads, 1, 1).unsqueeze(-2)
                attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype)  # regular zeros_like
                attn_bias[~mask] = -torch.inf

                # flow attention
                query_tempo = reshape_heads_to_batch_dim3(query_tempo)
                key_tempo = reshape_heads_to_batch_dim3(key_tempo)
                value_tempo = reshape_heads_to_batch_dim3(value_tempo)

                attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias
                attn_matrix2 = F.softmax(attn_matrix2, dim=-1)
                out = (attn_matrix2 @ value_tempo).squeeze(-2)

                hidden_states = rearrange(
                    out, '(b f) k (h w) d -> b (f h w) (k d)', b=num_videos, f=clip_length, h=h, w=w
                )

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)
            # All frames
            hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length)
            return hidden_states

        if attention_type == 'CrossAttention':
            return forward
        elif attention_type == "SparseCausalAttention":
            return spatial_temporal_forward
        elif attention_type == "FullyFrameAttention":
            return fully_frame_forward

    class DummyController:
        """Identity controller: returns the attention tensor unchanged."""

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        """Patch matching attention modules under the ``(name, module)`` pair ``net_``.

        Returns ``count`` plus the number of modules patched.
        """
        if net_[1].__class__.__name__ in ('CrossAttention', 'FullyFrameAttention', 'SparseCausalAttention'):
            net_[1].forward = attention_controlled_forward(
                net_[1], place_in_unet, attention_type=net_[1].__class__.__name__
            )
            return count + 1
        elif hasattr(net_[1], 'children'):
            for net in net_[1].named_children():
                # temporal attention layers are deliberately left unpatched
                if net[0] != 'attn_temporal':
                    count = register_recr(net, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net, 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net, 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net, 0, "mid")
    controller.num_att_layers = cross_att_count