""" 
VIT的transformer结构没有因果掩码,因为任意一个位置都能访问其它位置,它们之间没有因果关系,或者说关系很弱

文本生成仍然考虑因果掩码。
"""
import torch.nn.functional as F
from VIT import model as VIT 
from Text_Encoder import text_encoder as transformer
import torch.nn as nn
import torch
from Text_Encoder import MLP
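
# Minimal sketch (illustration only, not used by this model) of the causal mask
# mentioned in the module docstring: a text decoder adds this matrix to the
# attention scores so position i can only attend to positions <= i, whereas the
# ViT attention applies no mask at all.
def _causal_mask_sketch(seq_len: int) -> torch.Tensor:
    # Entries above the main diagonal are -inf (masked out), the rest are 0 (kept)
    return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)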

class Prompt_block(nn.Module):
    def __init__(self,config):
        super(Prompt_block,self).__init__()
        self.prompt_embedding=nn.Embedding(config.prompt_num,config.hidden_size,dtype=config.dtype,device=config.device)
    def forward(self,text_embeddings):
        b,_,_=text_embeddings.size()
        n,dim=self.prompt_embedding.weight.size()
        """
        new_embeddings=[]
        for batch,index_ in enumerate(index):
            text_embedding=text_embeddings[0]
            text_embedding=torch.cat((text_embedding[:index_,:],self.prompt_embedding.weight,text_embedding[index_:,:]),0)
            new_embeddings.append(text_embedding)
        stacked_embedding= torch.stack(new_embeddings, dim=0) 
        return stacked_embedding
        """
        # Splice the n learned prompt tokens in right after the first (BOS) token:
        # [b, seq, dim] -> [b, seq + n, dim]
        text_embeddings=torch.cat((text_embeddings[:,0:1,:],self.prompt_embedding.weight.expand(b,n,dim),text_embeddings[:,1:,:]),1)
        return text_embeddings
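
# Illustrative only (hypothetical config values that mirror Config below): shows how
# Prompt_block changes the sequence length. With prompt_num=20 and hidden_size=512,
# a [2, 57, 512] batch of text embeddings becomes [2, 77, 512], the prompts being
# spliced in right after the first (BOS) position.
def _prompt_block_demo():
    class _DemoConfig:
        prompt_num = 20
        hidden_size = 512
        dtype = torch.float32
        device = torch.device("cpu")
    block = Prompt_block(_DemoConfig())
    dummy = torch.zeros(2, 57, 512)     # [batch, seq_len, hidden]
    return block(dummy).shape           # torch.Size([2, 77, 512])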
        
        
        
        

class CLIP(nn.Module):
    def __init__(self,config):  
        super().__init__()
        self.visual=VIT
        self.device=config.device
        self.dtype=config.dtype
        self.token_embedding=nn.Embedding(config.vocab_size,config.hidden_size,dtype=config.dtype,device=config.device)
        self.max_position_embeddings=config.max_position_embeddings
        self.prompt_num=config.prompt_num
        self.transformer=transformer
        # Additional prompt block holding the learnable prompt tokens
        self.prompt_block=Prompt_block(config)
        self.positional_embedding=nn.Parameter(torch.empty(config.max_position_embeddings,config.hidden_size,device=config.device))
        self.ln_final=nn.LayerNorm(config.hidden_size,eps=config.layer_norm_eps,dtype=config.dtype,device=config.device)
        self.text_projection=nn.Parameter(torch.empty(config.hidden_size,config.hidden_size,device=config.device))
        # torch.full actually initializes the scalar to logit_scale_init (torch.empty()*x left it
        # uninitialized); the value is overwritten anyway when the pretrained state dict is loaded
        self.logit_scale=nn.Parameter(torch.full([],config.logit_scale_init,dtype=config.dtype,device=config.device),requires_grad=False)
    def encode_image(self,img,use_emotion=True):
        cls_embedding=self.visual(img,use_emotion)
        #cls_embedding:[batch_size,1,512],image_embedding:[batch_size,7,512]
        return cls_embedding
    def encode_text(self,text,use_emotion=True):
        # The caller leaves prompt_num (=20) positions free so the prompt tokens can be inserted here
        b,n=text.size()
        # The end-of-text token has the largest id in the CLIP vocabulary, so argmax
        # over the token ids gives each sample's EOT position
        index=text.argmax(dim=-1)
        text_embedding=self.token_embedding(text)
        #text_embedding=self.prompt_block(index,text_embedding)
        if n==self.max_position_embeddings-self.prompt_num:
            text_embedding=self.prompt_block(text_embedding)
            # The EOT position shifts right by the number of inserted prompt tokens
            index=index+torch.tensor(self.prompt_num,device=index.device,dtype=index.dtype)
        position_embedding=self.positional_embedding[None,:text_embedding.shape[1],:].to(self.dtype)
        text_embedding=position_embedding+text_embedding
        text_embedding=self.transformer(text_embedding,use_emotion=use_emotion)
        text_embedding=self.ln_final(text_embedding)
        # Pool the feature at each sample's EOT position as the sentence embedding
        #print(index[0],index_new[0],text_embedding.shape)
        text_embedding=text_embedding[torch.arange(text.shape[0]),index]
        # Project into the joint image-text embedding space
        text_embedding=text_embedding@self.text_projection.to(self.dtype)

        return text_embedding

    def forward(self,image,text,use_emotion=True):
        image_features=self.encode_image(image,use_emotion)
        text_features=self.encode_text(text,use_emotion)
        # normalized features
        image_features=image_features/image_features.norm(dim=-1,keepdim=True)
        text_features=text_features/text_features.norm(dim=-1,keepdim=True)
        # cosine similarity as logits
        logit_scale=self.logit_scale.exp()
        logits_per_image=logit_scale*image_features@text_features.t()
        logits_per_text=logits_per_image.t()
        # shape = [global_batch_size, global_batch_size]
        return logits_per_image,logits_per_text

class Config:
    def __init__(self):
        self.vocab_size=49408
        self.image_dim=768
        self.num_patches=49
        self.patch_size=32
        self.hidden_size=512
        self.prompt_num=20
        self.max_position_embeddings=77
        self.num_hidden_layers=12
        self.num_attention_heads=8
        self.head_size=64
        self.layer_norm_eps=1e-5
        self.activation_function="Quickgelu"
        self.dtype=torch.float16
        self.device=torch.device("cuda:0")
        self.logit_scale_init=4.6052  # exp(4.6052) ~= 100, i.e. a temperature of 0.01
        self.num_virtual_tokens=20
        self.token_dim=self.hidden_size
        self.encoder_hidden_size=self.hidden_size
        
config=Config()
model=CLIP(config)
# Load pretrained weights (strict=True: the checkpoint must also contain the prompt_block weights)
model.load_state_dict(torch.load(r'./EmotionCLIP-V2.pth',weights_only=True,map_location='cpu'),strict=True)
"""
for name, param in model.named_parameters():
    if 'prefix' not in name and 'prompt' not  in name and 'ln' not in name:  # 如果参数名中不包含'prefix'
        print(name,"'s requires_grad turn off.")
        param.requires_grad = False  # 冻结该参数
    else:
        print(name,"'s requires_grad turn on.")
        param.requires_grad = True  # 允许该参数进行训练
"""

# Optionally compile the model
#model=torch.compile(model)
import pickle
from PIL import Image
import numpy as np
import clip
with open('./preprocess.pkl','rb') as f:
    preprocess = pickle.load(f)
with open('./tokenize.pkl','rb') as f:
    tokenizer=pickle.load(f)
device=config.device
image = preprocess(Image.open("Dog sad.jpg")).unsqueeze(0).to(device)
# Emotion recognition (zero-shot over emotion prompts)
labels=[
    'amusement',
    'anger',
    'awe',
    'contentment',
    'disgust',
    'excitement',
    'fear',
    'sadness',
    'neutral'
]
text_list=[ f"This picture conveys a sense of {label}" for label in labels]
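# context_length=57 = max_position_embeddings (77) - prompt_num (20), so encode_text
# takes the branch that splices in the 20 learned prompt tokens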
tokens= tokenizer(text_list,
                 context_length=57).to(device)
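
# Zero-shot classification: each row of logits_per_image holds the scaled cosine
# similarity between the image and every text prompt; softmax over the prompts
# turns it into a probability distribution over the candidate labels.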

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Get the predicted label
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("Emotion recognition probabilities:", probs)
print("Predicted emotion label:", predicted_label)

# Generalization check: zero-shot classification on generic object categories (use_emotion=False)
labels=[
    'spider',
    'dog',
    'cat',
    'fish'
]
text_list=[ f"This is a {label}" for label in labels]
tokens= tokenizer(text_list,context_length=57).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image.to(config.dtype), tokens, use_emotion=False)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Get the predicted label
predicted_index = np.argmax(probs, axis=1)
predicted_label=labels[predicted_index[0]]

print("Generalization probabilities:", probs)
print("Predicted generic label:", predicted_label)