# RealCustom/inference/mask_generation.py
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from einops import rearrange
def _prepare_selfmap(selfmap_2d_list, target_hw, mask_mode):
    """Resize, optionally normalize, and head-average the self-attention maps.

    Each entry of ``selfmap_2d_list`` is resized to ``(target_hw, target_hw)``
    so that it can later be matmul-ed against a flattened ``[b, 1, hw]``
    cross-attention map. Returns a ``[b, 1, target_hw, target_hw]`` tensor.
    """
    resized = [
        F.interpolate(selfmap, size=(target_hw, target_hw), mode='bilinear')
        for selfmap in selfmap_2d_list
    ]
    selfmap_2ds = torch.cat(resized, dim=1)
    if "selfmap_min_max_per_channel" in mask_mode:
        # Min-max normalize each channel independently over its spatial extent.
        flat = selfmap_2ds.flatten(2)  # b c h w -> b c (h w)
        channel_max = torch.max(flat, dim=-1, keepdim=True)[0].unsqueeze(-1)
        channel_min = torch.min(flat, dim=-1, keepdim=True)[0].unsqueeze(-1)
        selfmap_2ds = (selfmap_2ds - channel_min) / (channel_max - channel_min + 1e-6)
    elif "selfmap_max_norm" in mask_mode:
        # Divide by the per-sample global maximum (no shift).
        b = selfmap_2ds.size(0)
        batch_max = torch.max(selfmap_2ds.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
        selfmap_2ds = selfmap_2ds / (batch_max + 1e-10)
    # Average the (possibly normalized) maps over the layer/channel dimension.
    return selfmap_2ds.mean(dim=1, keepdim=True)


def _apply_selfmap(selfmap_2d, crossmap_1d_avg):
    """Propagate a ``[b, 1, hw]`` cross-attention map through the averaged
    ``[b, 1, hw, hw]`` self-attention map (identity when ``selfmap_2d`` is None)."""
    if selfmap_2d is None:
        return crossmap_1d_avg
    return torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1)


def _batch_min_max(crossmap_1d_avg):
    """Per-sample min and max of a ``[b, c, n]`` tensor, shaped for broadcasting."""
    b = crossmap_1d_avg.size(0)
    flat = crossmap_1d_avg.view(b, -1)
    batch_max = torch.max(flat, dim=-1, keepdim=True)[0].unsqueeze(1)
    batch_min = torch.min(flat, dim=-1, keepdim=True)[0].unsqueeze(1)
    return batch_min, batch_max


def mask_generation(
    crossmap_2d_list, selfmap_2d_list=None,
    target_token=None, mask_scope=None,
    mask_target_h=64, mask_target_w=64,
    mask_mode=("binary",),
):
    """Build a spatial mask from cross-attention maps (optionally refined by
    self-attention maps).

    Args:
        crossmap_2d_list: list of cross-attention tensors; each is reduced by
            averaging over dim 1 (head dim), weighted by ``target_token`` over
            the token dim, summed, then resized to ``(mask_target_h, mask_target_w)``.
        selfmap_2d_list: optional list of self-attention maps, resized to
            ``(hw, hw)`` and used to propagate the cross map (``hw = h * w``).
        target_token: per-token weights selecting the subject tokens;
            broadcast as ``[b, tokens, 1, 1]`` — required.
        mask_scope: in "threshold" mode, keeps scores above ``1 - mask_scope``;
            otherwise keeps the top ``mask_scope`` fraction of locations.
        mask_target_h, mask_target_w: output mask resolution.
        mask_mode: collection of mode flags ("binary", "threshold",
            "max_norm", "min_max_norm", "min_max_per_channel",
            "selfmap_min_max_per_channel", "selfmap_max_norm").

    Returns:
        Tensor of shape ``[b, n_ref, layers, h, w]`` (the second dim is
        presumably the number of reference images — one per crossmap list
        entry after concatenation); the layer dim is squeezed away when it is 1
        so all layers share one mask.
    """
    # Fix: the original called len(selfmap_2d_list) and crashed with the
    # default selfmap_2d_list=None; a truthiness check covers None and [].
    if selfmap_2d_list:
        selfmap_2d = _prepare_selfmap(
            selfmap_2d_list, mask_target_h * mask_target_w, mask_mode)
    else:
        selfmap_2d = None

    crossmap_2ds = []
    for crossmap in crossmap_2d_list:
        crossmap = crossmap.mean(dim=1)  # average on head dim
        # Keep only the target tokens' contributions.
        crossmap = crossmap * target_token.unsqueeze(-1).unsqueeze(-1)
        crossmap = crossmap.sum(dim=1, keepdim=True)
        crossmap = F.interpolate(
            crossmap, size=(mask_target_h, mask_target_w), mode='bilinear')
        crossmap_2ds.append(crossmap)
    crossmap_2ds = torch.cat(crossmap_2ds, dim=1)
    crossmap_1ds = crossmap_2ds.flatten(2)  # b c h w -> b c (h w)

    if "max_norm" in mask_mode:
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]
        crossmap_1d_avg = _apply_selfmap(selfmap_2d, crossmap_1d_avg)
        _, batch_max = _batch_min_max(crossmap_1d_avg)
        crossmap_1d_avg = crossmap_1d_avg / (batch_max + 1e-6)
    elif "min_max_norm" in mask_mode:
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]
        crossmap_1d_avg = _apply_selfmap(selfmap_2d, crossmap_1d_avg)
        batch_min, batch_max = _batch_min_max(crossmap_1d_avg)
        crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
    elif "min_max_per_channel" in mask_mode:
        # Normalize each channel before averaging so no single map dominates.
        channel_max = torch.max(crossmap_1ds, dim=-1, keepdim=True)[0]
        channel_min = torch.min(crossmap_1ds, dim=-1, keepdim=True)[0]
        crossmap_1ds = (crossmap_1ds - channel_min) / (channel_max - channel_min + 1e-6)
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]
        crossmap_1d_avg = _apply_selfmap(selfmap_2d, crossmap_1d_avg)
        # Renormalize to 0-1 after the selfmap propagation.
        batch_min, batch_max = _batch_min_max(crossmap_1d_avg)
        crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6)
    else:
        crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True)  # [b, 1, (h w)]

    if "threshold" in mask_mode:
        # Zero everything below 1 - mask_scope; optionally binarize the rest.
        threshold = 1 - mask_scope
        crossmap_1d_avg[crossmap_1d_avg < threshold] = 0.0
        if "binary" in mask_mode:
            crossmap_1d_avg[crossmap_1d_avg > threshold] = 1.0
    else:
        # Top-k selection: keep the mask_scope fraction of highest scores.
        topk_num = int(crossmap_1d_avg.size(-1) * mask_scope)
        _, sort_order = crossmap_1d_avg.sort(descending=True, dim=-1)
        crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_order[:, :, topk_num:], 0.)
        if "binary" in mask_mode:
            crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_order[:, :, :topk_num], 1.0)

    crossmap_2d_avg = crossmap_1d_avg.reshape(
        crossmap_1d_avg.size(0), crossmap_1d_avg.size(1),
        mask_target_h, mask_target_w)  # b c (h w) -> b c h w
    output = crossmap_2d_avg.unsqueeze(1)  # e.g. [b, n_ref, layers, h, w]
    if output.size(2) == 1:  # The layer dimension.
        # A single layer dim means all layers share the same mask.
        output = output.squeeze(2)
    return output