# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from einops import rearrange

def mask_generation(
    crossmap_2d_list, selfmap_2d_list=None,
    target_token=None, mask_scope=None,
    mask_target_h=64, mask_target_w=64,
    mask_mode=("binary",),
):
    """Aggregate cross-attention (and optionally self-attention) maps into a
    spatial mask of size (mask_target_h, mask_target_w), normalized and then
    thresholded or top-k filtered according to mask_mode.

    target_token is required: a [b, num_tokens] indicator of which text tokens
    belong to the target phrase. mask_scope controls the threshold / top-k ratio.
    """
    # Optionally prepare self-attention maps to refine the cross-attention mask.
    if selfmap_2d_list is not None and len(selfmap_2d_list) > 0:
        target_hw_selfmap = mask_target_h * mask_target_w
        selfmap_2ds = []
        for i in range(len(selfmap_2d_list)):
            selfmap_ = selfmap_2d_list[i]
            # Self-attention maps are token-to-token, so both axes are resized to h*w.
            selfmap_ = F.interpolate(selfmap_, size=(target_hw_selfmap, target_hw_selfmap), mode='bilinear')
            selfmap_2ds.append(selfmap_)
        selfmap_2ds = torch.cat(selfmap_2ds, dim=1)
        if "selfmap_min_max_per_channel" in mask_mode:
            # Min-max normalize each self-attention channel independently.
            selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
            channel_max_self = torch.max(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
            channel_min_self = torch.min(selfmap_1ds, dim=-1, keepdim=True)[0].unsqueeze(-1)
            selfmap_2ds = (selfmap_2ds - channel_min_self) / (channel_max_self - channel_min_self + 1e-6)
        elif "selfmap_max_norm" in mask_mode:
            # Scale by the per-sample maximum.
            selfmap_1ds = rearrange(selfmap_2ds, "b c h w -> b c (h w)")
            b = selfmap_1ds.size(0)
            batch_max = torch.max(selfmap_1ds.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(-1).unsqueeze(-1)
            selfmap_2ds = selfmap_2ds / (batch_max + 1e-10)
        selfmap_2d = selfmap_2ds.mean(dim=1, keepdim=True)  # average over collected maps
    else:
        selfmap_2d = None
    # Aggregate cross-attention maps over heads and target tokens.
    crossmap_2ds = []
    for i in range(len(crossmap_2d_list)):
        crossmap = crossmap_2d_list[i]
        crossmap = crossmap.mean(dim=1)  # average over attention heads
        crossmap = crossmap * target_token.unsqueeze(-1).unsqueeze(-1)  # keep only valid target tokens
        crossmap = crossmap.sum(dim=1, keepdim=True)
        crossmap = F.interpolate(crossmap, size=(mask_target_h, mask_target_w), mode='bilinear')
        crossmap_2ds.append(crossmap)
    crossmap_2ds = torch.cat(crossmap_2ds, dim=1)
    crossmap_1ds = rearrange(crossmap_2ds, "b c h w -> b c (h w)")
if "max_norm" in mask_mode: | |
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)] | |
if selfmap_2d is not None: | |
crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1) | |
b, c, n = crossmap_1ds.shape | |
batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) | |
crossmap_1d_avg = crossmap_1d_avg / (batch_max + 1e-6) | |
elif "min_max_norm" in mask_mode: | |
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)] | |
if selfmap_2d is not None: | |
crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1) | |
b, c, n = crossmap_1ds.shape | |
batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) # NOTE unsqueeze | |
batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) # NOTE unsqueeze | |
crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6) | |
elif "min_max_per_channel" in mask_mode: | |
channel_max = torch.max(crossmap_1ds, dim=-1, keepdim=True)[0] | |
channel_min = torch.min(crossmap_1ds, dim=-1, keepdim=True)[0] | |
crossmap_1ds = (crossmap_1ds - channel_min) / (channel_max - channel_min + 1e-6) | |
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)] | |
if selfmap_2d is not None: | |
crossmap_1d_avg = torch.matmul(selfmap_2d, crossmap_1d_avg.unsqueeze(-1)).squeeze(-1) | |
# renormalize to 0-1 | |
b, c, n = crossmap_1d_avg.shape | |
batch_max = torch.max(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) | |
batch_min = torch.min(crossmap_1d_avg.view(b, -1), dim=-1, keepdim=True)[0].unsqueeze(1) | |
crossmap_1d_avg = (crossmap_1d_avg - batch_min) / (batch_max - batch_min + 1e-6) | |
else: | |
crossmap_1d_avg = torch.mean(crossmap_1ds, dim=1, keepdim=True) # [b, 1, (h w)] | |
if "threshold" in mask_mode: | |
threshold = 1 - mask_scope | |
crossmap_1d_avg[crossmap_1d_avg < threshold] = 0.0 | |
if "binary" in mask_mode: | |
crossmap_1d_avg[crossmap_1d_avg > threshold] = 1.0 | |
else: | |
# topk | |
topk_num = int(crossmap_1d_avg.size(-1) * mask_scope) | |
sort_score, sort_order = crossmap_1d_avg.sort(descending=True, dim=-1) | |
sort_topk = sort_order[:, :, :topk_num] | |
sort_topk_remain = sort_order[:, :, topk_num:] | |
crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk_remain, 0.) | |
if "binary" in mask_mode: | |
crossmap_1d_avg = crossmap_1d_avg.scatter(2, sort_topk, 1.0) | |
    crossmap_2d_avg = rearrange(crossmap_1d_avg, "b c (h w) -> b c h w", h=mask_target_h, w=mask_target_w)
    # e.g. torch.Size([4, 1, 60, 64, 64]): dim 1 indexes reference images, dim 2 indexes layers.
    output = crossmap_2d_avg.unsqueeze(1)
    if output.size(2) == 1:
        # A singleton layer dimension means all layers share the same mask.
        output = output.squeeze(2)
    return output
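

# Minimal usage sketch (not part of the original file): the tensor shapes below
# are assumptions chosen only to exercise mask_generation with random data;
# real cross-/self-attention maps would come from a diffusion model's attention
# layers, and the mask_mode combination here is just one plausible setting.
if __name__ == "__main__":
    b, heads, num_tokens, h, w = 2, 8, 16, 32, 32
    # Cross-attention maps collected from several layers: [b, heads, tokens, h, w].
    crossmaps = [torch.rand(b, heads, num_tokens, h, w) for _ in range(3)]
    # Indicator of which text tokens belong to the target phrase (first 4 here).
    target_token = torch.zeros(b, num_tokens)
    target_token[:, :4] = 1.0
    mask = mask_generation(
        crossmaps,
        selfmap_2d_list=None,
        target_token=target_token,
        mask_scope=0.3,
        mask_target_h=64,
        mask_target_w=64,
        mask_mode=("min_max_norm", "threshold", "binary"),
    )
    print(mask.shape)  # expected: [b, 1, 64, 64] once the layer dim is squeezed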