"""
VIT的transformer结构没有因果掩码,因为任意一个位置都能访问其它位置,它们之间没有因果关系,或者说关系很弱
文本生成仍然考虑因果掩码。
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
from Text_Encoder import MLP
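# The docstring above contrasts bidirectional (ViT) and causal (autoregressive
# text) attention. Below is a minimal illustrative sketch of a causal mask,
# assumed standard practice and not used by this model (the text encoder builds
# whatever mask it needs internally):
def _causal_mask_sketch(n):
    # Each position may attend only to itself and earlier positions; the upper
    # triangle is set to -inf so the attention softmax assigns it zero weight.
    mask = torch.full((n, n), float("-inf"))
    return torch.triu(mask, diagonal=1)
# A ViT-style encoder simply omits this mask (equivalently, an all-zeros mask),
# so every patch token can attend to every other patch token.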
class Prompt_block(nn.Module):
    """Learnable prompt embeddings spliced in right after the BOS token of each caption."""
    def __init__(self, config):
        super(Prompt_block, self).__init__()
        self.prompt_embedding = nn.Embedding(config.prompt_num, config.hidden_size, dtype=config.dtype, device=config.device)
    def forward(self, text_embeddings):
        b, _, _ = text_embeddings.size()
        n, dim = self.prompt_embedding.weight.size()
        # An earlier per-sample loop inserted the prompts at each caption's EOT index;
        # it was replaced by this vectorized concatenation, which inserts all
        # prompt_num prompt vectors right after the BOS embedding (position 0).
        text_embeddings = torch.cat((text_embeddings[:, 0:1, :], self.prompt_embedding.weight.expand(b, n, dim), text_embeddings[:, 1:, :]), 1)
        return text_embeddings
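# A hedged shape sketch for Prompt_block (values taken from Config below): with
# prompt_num=20 and hidden_size=512, a batch of embedded captions of shape
# [B, 57, 512] comes out as [B, 77, 512], the prompts sitting right after BOS:
#   block = Prompt_block(config)
#   out = block(torch.zeros(2, 57, 512, dtype=config.dtype, device=config.device))
#   assert out.shape == (2, 77, 512)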
class CLIP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.visual = VIT
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size, dtype=config.dtype, device=config.device)
        self.max_position_embeddings = config.max_position_embeddings
        self.prompt_num = config.prompt_num
        self.transformer = transformer
        # add a prompt block for the learnable prompt tokens
        self.prompt_block = Prompt_block(config)
        self.positional_embedding = nn.Parameter(torch.empty(config.max_position_embeddings, config.hidden_size, device=config.device))
        self.ln_final = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, dtype=config.dtype, device=config.device)
        self.text_projection = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size, device=config.device))
        # torch.empty([]) * init was a bug (it scales uninitialized memory);
        # torch.ones gives a deterministic start, and the checkpoint loaded
        # below overwrites the value anyway.
        self.logit_scale = nn.Parameter(torch.ones([], dtype=config.dtype, device=config.device) * config.logit_scale_init, requires_grad=False)
    def encode_image(self, img, use_emotion=True):
        # cls_embedding: [batch_size, 1, 512]; image_embedding: [batch_size, 7, 512]
        cls_embedding = self.visual(img, use_emotion)
        return cls_embedding
    def encode_text(self, text, use_emotion=True):
        b, n = text.size()
        # The EOT token has the highest id in CLIP's vocabulary, so argmax over
        # the token ids finds its position in each sequence.
        index = text.argmax(dim=-1)
        text_embedding = self.token_embedding(text)
        # Insert the learnable prompts only when the caption leaves room for them
        # (context_length == max_position_embeddings - prompt_num), then shift the
        # EOT index by the number of inserted prompt tokens.
        if n == self.max_position_embeddings - self.prompt_num:
            text_embedding = self.prompt_block(text_embedding)
            index = index + self.prompt_num
        position_embedding = self.positional_embedding[None, :text_embedding.shape[1], :].to(self.dtype)
        text_embedding = position_embedding + text_embedding
        text_embedding = self.transformer(text_embedding, use_emotion=use_emotion)
        text_embedding = self.ln_final(text_embedding)
        # Take each sequence's EOT embedding as the text feature.
        text_embedding = text_embedding[torch.arange(text.shape[0]), index]
        text_embedding = text_embedding @ self.text_projection.to(self.dtype)
        return text_embedding
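    # Worked example of the index shift above: a caption tokenized as
    # [BOS, t1, ..., tk, EOT, pad, ...] has its EOT at position k+1; inserting the
    # 20 prompt vectors after BOS moves EOT to position k+21, so `index` must be
    # advanced by self.prompt_num before the EOT embedding is gathered.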
    def forward(self, image, text, use_emotion=True):
        image_features = self.encode_image(image, use_emotion)
        text_features = self.encode_text(text, use_emotion)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
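# For reference, a minimal sketch (assumed standard CLIP contrastive objective,
# not part of this inference script) of how the paired logits would be trained:
#   labels = torch.arange(images.shape[0], device=device)
#   loss = (F.cross_entropy(logits_per_image, labels)
#           + F.cross_entropy(logits_per_text, labels)) / 2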
class Config:
    def __init__(self):
        self.vocab_size = 49408
        self.image_dim = 768
        self.num_patches = 49
        self.patch_size = 32
        self.hidden_size = 512
        self.prompt_num = 20
        # 77 positions = 57 caption tokens (tokenizer context_length) + 20 prompt tokens
        self.max_position_embeddings = 77
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.head_size = 64
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")
        self.logit_scale_init = 4.6052  # log(100), the standard CLIP value
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size
config = Config()
model = CLIP(config)
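# Sanity check: the tokenizer context_length used below (57) plus the prompt
# tokens must fit the positional-embedding table (57 + 20 == 77).
assert 57 + config.prompt_num == config.max_position_embeddings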
# Load pretrained weights; strict=True means the checkpoint must also contain
# the prompt_block parameters.
model.load_state_dict(torch.load(r'./EmotionCLIP-V2.pth', weights_only=True, map_location='cpu'), strict=True)
"""
for name, param in model.named_parameters():
if 'prefix' not in name and 'prompt' not in name and 'ln' not in name: # 如果参数名中不包含'prefix'
print(name,"'s requires_grad turn off.")
param.requires_grad = False # 冻结该参数
else:
print(name,"'s requires_grad turn on.")
param.requires_grad = True # 允许该参数进行训练
"""
#编译模型
#model=torch.compile(model)
import pickle
from PIL import Image
import numpy as np
import clip  # needed so the pickled objects below can resolve their module
# preprocess and tokenize are presumably pickled from the official clip package
# (an image-transform pipeline and clip.tokenize); they are loaded opaquely here.
with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)
device = config.device
image = preprocess(Image.open("Dog sad.jpg")).unsqueeze(0).to(device)
# Emotion recognition
labels = [
    'amusement',
    'anger',
    'awe',
    'contentment',
    'disgust',
    'excitement',
    'fear',
    'sadness',
    'neutral'
]
text_list = [f"This picture conveys a sense of {label}" for label in labels]
# context_length=57 leaves exactly prompt_num (20) free positions, which
# triggers the prompt insertion inside encode_text (57 + 20 == 77).
tokens = tokenizer(text_list, context_length=57).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# get the predicted label
predicted_index = np.argmax(probs, axis=1)
predicted_label = labels[predicted_index[0]]
print("Emotion probabilities:", probs)
print("Predicted emotion label:", predicted_label)
# Generalization test: ordinary object labels, with the emotion branch disabled
labels = [
    'spider',
    'dog',
    'cat',
    'fish'
]
text_list = [f"This is a {label}" for label in labels]
tokens = tokenizer(text_list, context_length=57).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# get the predicted label
predicted_index = np.argmax(probs, axis=1)
predicted_label = labels[predicted_index[0]]
print("Generalization probabilities:", probs)
print("Predicted object label:", predicted_label)