"""CLIP-style model with a learnable text-prompt block, plus a zero-shot demo.

The ViT image transformer has no causal mask: every position can attend to
every other position, because there is no causal relation between them (or
only a very weak one). Text generation still takes a causal mask into account.
"""
import torch.nn.functional as F
from VIT import model as VIT
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP


class Prompt_block(nn.Module):
    """Learnable prompt vectors inserted into an embedded text sequence.

    Holds ``config.prompt_num`` vectors of size ``config.hidden_size`` and
    splices them in right after the first token of every sequence.
    """

    def __init__(self, config):
        super(Prompt_block, self).__init__()
        self.prompt_embedding = nn.Embedding(
            config.prompt_num,
            config.hidden_size,
            dtype=config.dtype,
            device=config.device,
        )

    def forward(self, text_embeddings):
        """Insert the prompt vectors between token 0 and token 1.

        Args:
            text_embeddings: 3-D tensor (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len + prompt_num, hidden_size).
        """
        b, _, _ = text_embeddings.size()
        n, dim = self.prompt_embedding.weight.size()
        # The same prompt vectors are shared (broadcast, not copied) across the batch.
        prompts = self.prompt_embedding.weight.expand(b, n, dim)
        return torch.cat(
            (text_embeddings[:, 0:1, :], prompts, text_embeddings[:, 1:, :]),
            1,
        )


class CLIP(nn.Module):
    """Dual encoder: ViT image tower and text transformer with learnable prompts.

    Attribute names below are the keys of the pretrained state dict loaded by
    this script with ``strict=True``; they must not be renamed.
    """

    def __init__(self, config):
        super().__init__()
        self.visual = VIT
        self.device = config.device
        self.dtype = config.dtype
        self.token_embedding = nn.Embedding(
            config.vocab_size, config.hidden_size,
            dtype=config.dtype, device=config.device,
        )
        self.max_position_embeddings = config.max_position_embeddings
        self.prompt_num = config.prompt_num
        self.transformer = transformer
        # Additional learnable prompt block for the text branch.
        self.prompt_block = Prompt_block(config)
        # Left uninitialized (and in default dtype); they are cast to self.dtype
        # at use and are expected to be filled by load_state_dict.
        self.positional_embedding = nn.Parameter(
            torch.empty(config.max_position_embeddings, config.hidden_size, device=config.device)
        )
        self.ln_final = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps,
            dtype=config.dtype, device=config.device,
        )
        self.text_projection = nn.Parameter(
            torch.empty(config.hidden_size, config.hidden_size, device=config.device)
        )
        # Non-trainable log logit scale. torch.ones (not torch.empty) so the
        # initial value really equals logit_scale_init instead of garbage memory.
        self.logit_scale = nn.Parameter(
            torch.ones([], dtype=config.dtype, device=config.device) * config.logit_scale_init,
            requires_grad=False,
        )

    def encode_image(self, img, use_emotion=True):
        """Return the image embedding produced by the ViT tower.

        NOTE(review): an earlier comment described the output as
        [batch, 1, 512]. forward() transposes the similarity matrix with .t(),
        which requires 2-D features, so it is presumably [batch, hidden];
        confirm against VIT.
        """
        cls_embedding = self.visual(img, use_emotion)
        return cls_embedding

    def encode_text(self, text, use_emotion=True):
        """Encode token ids into projected text features.

        If ``text`` has exactly ``max_position_embeddings - prompt_num`` tokens
        (i.e. room is reserved for the prompts), the learnable prompts are
        inserted after the first token. Other lengths are encoded without prompts.

        Args:
            text: integer token ids of shape (batch, n).
            use_emotion: forwarded to the text transformer.

        Returns:
            Tensor (batch, hidden_size): the final-layer feature at the
            position of the largest token id in each row, projected by
            ``text_projection``.
        """
        _, n = text.size()
        # Largest token id per row; presumably the end-of-text token of the
        # CLIP tokenizer, whose feature summarizes the sentence.
        index = text.argmax(dim=-1)
        text_embedding = self.token_embedding(text)
        if n == self.max_position_embeddings - self.prompt_num:
            text_embedding = self.prompt_block(text_embedding)
            # Prompts sit after token 0, so the pooled position moves right by
            # prompt_num (assumes the pooled token is not at position 0).
            index = index + self.prompt_num
        seq_len = text_embedding.shape[1]
        position_embedding = self.positional_embedding[None, :seq_len, :].to(self.dtype)
        text_embedding = position_embedding + text_embedding
        text_embedding = self.transformer(text_embedding, use_emotion=use_emotion)
        text_embedding = self.ln_final(text_embedding)
        batch_idx = torch.arange(text.shape[0], device=text_embedding.device)
        text_embedding = text_embedding[batch_idx, index]
        return text_embedding @ self.text_projection.to(self.dtype)

    def forward(self, image, text, use_emotion=True):
        """Return (logits_per_image, logits_per_text).

        logits_per_image has shape (num_images, num_texts) and holds
        exp(logit_scale) times cosine similarity; logits_per_text is its transpose.
        """
        image_features = self.encode_image(image, use_emotion)
        text_features = self.encode_text(text, use_emotion)
        # L2-normalize so that the dot product is the cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text


class Config:
    """Hyper-parameters matching the EmotionCLIP-V2 checkpoint loaded below."""

    def __init__(self):
        self.vocab_size = 49408
        self.image_dim = 768
        self.num_patches = 49
        self.patch_size = 32
        self.hidden_size = 512
        # Number of learnable prompt vectors inserted by Prompt_block.
        self.prompt_num = 20
        self.max_position_embeddings = 77
        self.num_hidden_layers = 12
        self.num_attention_heads = 8
        self.head_size = 64
        self.layer_norm_eps = 1e-5
        self.activation_function = "Quickgelu"
        self.dtype = torch.float16
        self.device = torch.device("cuda:0")
        # 4.6052 ~= ln(100), so exp(logit_scale) ~= 100.
        self.logit_scale_init = 4.6052
        self.num_virtual_tokens = 20
        self.token_dim = self.hidden_size
        self.encoder_hidden_size = self.hidden_size


config = Config()
model = CLIP(config)
# Load the pretrained weights; strict=True requires every key to match.
model.load_state_dict(
    torch.load(r'./EmotionCLIP-V2.pth', weights_only=True, map_location='cpu'),
    strict=True,
)
# Inference only: switch submodules (e.g. any dropout) to evaluation behavior.
model.eval()

# Optional fine-tuning recipe: freeze every parameter except prompt/prefix/LayerNorm ones.
# for name, param in model.named_parameters():
#     if 'prefix' not in name and 'prompt' not in name and 'ln' not in name:
#         print(name, "'s requires_grad turn off.")
#         param.requires_grad = False  # freeze this parameter
#     else:
#         print(name, "'s requires_grad turn on.")
#         param.requires_grad = True  # keep this parameter trainable

# Optional: compile the model.
# model = torch.compile(model)

import pickle
from PIL import Image
import numpy as np
# Presumably imported so the pickled preprocess/tokenizer objects can be
# resolved when unpickled; verify before removing.
import clip

# NOTE(review): pickle.load can execute arbitrary code; only load these .pkl
# files from a trusted source.
with open('./preprocess.pkl', 'rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

device = config.device
# Token length that reserves room for the prompts (77 - 20 = 57); only this
# exact length makes CLIP.encode_text insert the learnable prompts.
context_length = config.max_position_embeddings - config.prompt_num
with Image.open("Dog sad.jpg") as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

# Emotion recognition
labels = [
    'amusement',
    'anger',
    'awe',
    'contentment',
    'disgust',
    'excitement',
    'fear',
    'sadness',
    'neutral',
]
text_list = [f"This picture conveys a sense of {label}" for label in labels]
tokens = tokenizer(text_list, context_length=context_length).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# Predicted label for the single input image
predicted_index = np.argmax(probs, axis=1)
predicted_label = labels[predicted_index[0]]
print("情感识别:", probs)
print("预测的情感标签:", predicted_label)

# Generalization check: plain object labels with the emotion path disabled
labels = [
    'spider',
    'dog',
    'cat',
    'fish',
]
text_list = [f"This is a {label}" for label in labels]
tokens = tokenizer(text_list, context_length=context_length).to(device)
with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
# Predicted label for the single input image
predicted_index = np.argmax(probs, axis=1)
predicted_label = labels[predicted_index[0]]
print("泛化识别:", probs)
print("预测的泛化标签:", predicted_label)