import os
import math
import copy

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from PIL import Image
from functools import partial
from typing import List, Optional, Tuple, Union, Dict
from dataclasses import dataclass

import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, Qwen2Config, SiglipVisionModel

from .adapters import AdapterSigLIP
from .mm_constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
from .processing_FlashVL import tokenizer_image_token_qwen
from .configuration_FlashVLStatic import FlashVLStaticConfig

@dataclass
class FlashVLStaticOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


class FlashVLStatic(PreTrainedModel):
    config_class = FlashVLStaticConfig

    def __init__(self, config):
        super().__init__(config)
        # Language model, vision tower, and the adapter that projects vision
        # features into the LLM embedding space.
        self.llm = AutoModelForCausalLM.from_config(config.llm_config, trust_remote_code=True)
        self.vit = SiglipVisionModel(config.vision_config).vision_model
        self.adp = AdapterSigLIP(config)
        self.image_token_num = config.image_token_num
        self.image_size = config.vision_config.image_size
        # Note: `self.tokenizer` and `self.im_trans` (image processor) are not
        # created here; `chat()` expects the caller to attach them.

    def merge_text_image_tokens(self, inputs):
        """Splice projected image features into the text embedding sequence.

        Each occurrence of IMAGE_TOKEN_INDEX in `input_ids` is treated as the
        start of an `image_token_num`-long placeholder span whose embeddings
        are replaced by the corresponding entry of `image_features`.
        """
        input_ids, image_features, targets, attn_mask, loss_mask = inputs
        micro_batch_size, tokens_len = input_ids.shape
        device = input_ids.device

        # Collect the placeholder start positions per row.
        img_rows, img_cols = torch.where(input_ids == IMAGE_TOKEN_INDEX)
        image_idxs = {i: [] for i in range(micro_batch_size)}
        for row, col in zip(img_rows.tolist(), img_cols.tolist()):
            image_idxs[row].append(col)
        for row in range(micro_batch_size):
            image_idxs[row] = sorted(image_idxs[row])

        # Split sizes alternate between text segments and image placeholder spans.
        split_sizes = []
        for row in range(micro_batch_size):
            image_num = len(image_idxs[row])
            if image_num == 0:
                split_sizes.append(tokens_len)
                continue

            if image_idxs[row][0] != 0:
                split_sizes.append(image_idxs[row][0])

            for idx in range(image_num - 1):
                split_sizes.append(self.image_token_num)
                if image_idxs[row][idx + 1] > image_idxs[row][idx] + self.image_token_num:
                    split_sizes.append(image_idxs[row][idx + 1] - (image_idxs[row][idx] + self.image_token_num))

            if image_idxs[row][image_num - 1] + self.image_token_num >= tokens_len:
                # The last image span is truncated by the end of the sequence.
                split_sizes.append(tokens_len - image_idxs[row][image_num - 1])
            else:
                split_sizes.append(self.image_token_num)
                split_sizes.append(tokens_len - (image_idxs[row][image_num - 1] + self.image_token_num))

        # Map the negative placeholder ids to a valid vocab id (151643, Qwen2's
        # <|endoftext|>) so they can pass through embed_tokens; those positions
        # are overwritten with image features below.
        input_ids_noim = torch.where(input_ids < 0, 151643, input_ids)
        input_ids_noim = input_ids_noim.view(-1)
        input_embeds = self.llm.model.embed_tokens(input_ids_noim)
        input_embeds_split = torch.split(input_embeds, split_sizes, dim=0)

        # Re-assemble the sequence, interleaving text embeddings and image features.
        vl_embeds_list = []
        cur_language_idx = 0
        cur_image_idx = 0
        for row in range(micro_batch_size):
            image_num = len(image_idxs[row])
            if image_num == 0:
                # Rows without images still consume one (dummy) image feature
                # so the image index stays aligned across the batch.
                vl_embeds_list.append(input_embeds_split[cur_language_idx])
                cur_language_idx += 1
                vl_embeds_list.append(image_features[cur_image_idx][0:0])
                cur_image_idx += 1
                continue

            if image_idxs[row][0] != 0:
                vl_embeds_list.append(input_embeds_split[cur_language_idx])
                cur_language_idx += 1

            for idx in range(image_num - 1):
                vl_embeds_list.append(image_features[cur_image_idx])
                cur_language_idx += 1
                cur_image_idx += 1

                if image_idxs[row][idx + 1] > image_idxs[row][idx] + self.image_token_num:
                    vl_embeds_list.append(input_embeds_split[cur_language_idx])
                    cur_language_idx += 1

            if image_idxs[row][image_num - 1] + self.image_token_num >= tokens_len:
                vl_embeds_list.append(image_features[cur_image_idx][0 : tokens_len - image_idxs[row][image_num - 1]])
                cur_language_idx += 1
                cur_image_idx += 1
            else:
                vl_embeds_list.append(image_features[cur_image_idx])
                cur_language_idx += 1
                cur_image_idx += 1
                vl_embeds_list.append(input_embeds_split[cur_language_idx])
                cur_language_idx += 1

        vl_embeds = torch.cat(vl_embeds_list)
        vl_embeds = vl_embeds.view(micro_batch_size, tokens_len, vl_embeds.shape[-1])
        return (input_ids, vl_embeds, targets, attn_mask, loss_mask)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        local_pos_batch: Optional[torch.LongTensor] = None,
        image_idx_batch: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
    ):
        inputs = [input_ids, pixel_values, labels, attention_mask, loss_mask]

        # The vision tower runs in bfloat16; pixel_values may be a list of per-sample tensors.
        if isinstance(inputs[1], list):
            pixel_values = [p.bfloat16() for p in inputs[1]]
        else:
            pixel_values = inputs[1].bfloat16()
        img_token = self.vit.forward(pixel_values)

        if hasattr(img_token, 'last_hidden_state'):
            img_token = img_token.last_hidden_state

        # Project the vision features and splice them into the text embeddings.
        inputs = self.adp(inputs[:1] + [img_token] + inputs[2:])
        inputs = self.merge_text_image_tokens(inputs)
        tokens, hidden_states, targets, attn_mask, loss_mask = inputs

        outputs = self.llm.forward(
            inputs_embeds=hidden_states,
            attention_mask=attn_mask,
            use_cache=use_cache)

        lm_logits = outputs.logits

        loss = None
        if targets is not None:
            # Shift so that tokens < n predict token n, then average the
            # per-token losses under the loss mask.
            labels = targets.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            batch_size = labels.size(0)
            loss_mask = loss_mask[:, 1:].to(loss.dtype)
            loss = (loss.view(batch_size, -1) * loss_mask).sum() / loss_mask.sum()

        return FlashVLStaticOutputWithPast(
            loss=loss,
            logits=lm_logits
        )

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        **kwargs
    ):
        image = pixel_values
        img_token = self.vit.forward(image.bfloat16())
        if hasattr(img_token, 'last_hidden_state'):
            img_token = img_token.last_hidden_state
        inputs = self.adp((
            input_ids.to(self.device),
            img_token,
            None, None, None))
        inputs = self.merge_text_image_tokens(inputs)
        tokens, hidden_states, targets, attn_mask, loss_mask = inputs

        # Drop keys that would conflict with the merged inputs_embeds call below.
        keys_to_pop = ['loss_mask', 'labels', 'attention_mask']
        kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_pop}

        outputs = self.llm.generate(
            inputs_embeds=hidden_states.bfloat16(),
            max_new_tokens=2048,
            do_sample=False,
            **kwargs
        )

        return outputs

    def chat(self, pil_image, messages, answer_prompt=None, do_sample=True, max_new_tokens=256):
        data = {}
        data['img'] = pil_image
        data['text_only'] = (pil_image is None)
        data['messages'] = messages

        sources = self.to_llava_format(data)
        sources = [sources]
        has_image = not sources[0]['text_only']

        if has_image:
            img_list = sources[0]['image']
            if not isinstance(img_list, list):
                img_list = [img_list]
            image = torch.stack([torch.from_numpy(self.im_trans(i)['pixel_values'][0]) for i in img_list], dim=0)

        sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = self.preprocess_qwen(
            sources,
            self.tokenizer,
            has_image=has_image,
        )

        input_ids_data = data_dict["input_ids"][0]
        data_dict["input_ids"] = [input_ids_data, ]

        if not has_image:
            # Text-only chat still runs a dummy image through the vision tower.
            image = torch.zeros(1, 3, self.image_size, self.image_size)
        data_dict = dict(tokens=data_dict["input_ids"][0], )

        img_token = self.vit.forward(image.cuda().bfloat16())

        if hasattr(img_token, 'last_hidden_state'):
            img_token = img_token.last_hidden_state

        inputs = self.adp((
            data_dict['tokens'].unsqueeze(0).to(self.device),
            img_token,
            None, None, None))

        inputs = self.merge_text_image_tokens(inputs)
        tokens, hidden_states, targets, attn_mask, loss_mask = inputs

        outputs = self.llm.generate(
            inputs_embeds=hidden_states.bfloat16(),
            return_dict_in_generate=False,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        decoded = self.tokenizer.decode(outputs[0])

        # Strip chat-control tokens from the decoded answer.
        stop_words_ids = [self.llm.generation_config.bos_token_id,
                          self.llm.generation_config.eos_token_id,
                          self.tokenizer.convert_tokens_to_ids('<|im_start|>')]
        stop_words = [self.tokenizer.decode(w) for w in stop_words_ids]

        for stop_word in stop_words:
            decoded = decoded.replace(stop_word, "").strip()

        return decoded

    def preprocess_qwen(
        self,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        has_image: bool = False,
        max_len=2048,
        system_message: str = "You are a helpful assistant.",
    ) -> Dict:
        roles = {"human": "user", "gpt": "assistant"}
        tokenizer = copy.deepcopy(tokenizer)

        # Register <image> as a special token so it tokenizes to a single id.
        tokenizer.add_tokens(["<image>"], special_tokens=True)
        image_token_index = tokenizer.convert_tokens_to_ids("<image>")
        im_start, im_end = tokenizer.additional_special_tokens_ids[:2]

        # Control-token ids that remain unmasked in the labels.
        unmask_tokens_idx = [198, im_start, im_end]
        nl_tokens = tokenizer("\n").input_ids

        chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
        tokenizer.chat_template = chat_template

        input_ids, targets = [], []
        for source in sources:
            if roles[source[0]["from"]] != roles["human"]:
                source = source[1:]
            input_id, target = [], []

            # The system turn is never a training target.
            input_id += tokenizer.apply_chat_template([{"role": "system", "content": system_message}])
            target += [IGNORE_INDEX] * len(input_id)

            for turn_idx, conv in enumerate(source):
                try:
                    role = conv["role"]
                    content = conv["content"]
                except KeyError:
                    role = conv["from"]
                    content = conv["value"]
                role = roles.get(role, role)

                conv = [{"role": role, "content": content}]
                if turn_idx == len(source) - 1:
                    # Last turn: append the generation prompt so the model
                    # continues as the assistant.
                    encode_id = tokenizer.apply_chat_template(conv, add_generation_prompt=True)
                else:
                    encode_id = tokenizer.apply_chat_template(conv)

                if image_token_index in encode_id:
                    # Replace the <image> marker with image placeholder ids
                    # (see tokenizer_image_token_qwen).
                    encode_id = tokenizer_image_token_qwen(encode_id, tokenizer, image_token_index, image_token_num=self.image_token_num)
                input_id += encode_id
                if role in ["user", "system"]:
                    target += [IGNORE_INDEX] * len(encode_id)
                else:
                    target += encode_id

            assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
            for idx, encode_id in enumerate(input_id):
                if encode_id in unmask_tokens_idx:
                    target[idx] = encode_id
                if encode_id == image_token_index:
                    input_id[idx] = IMAGE_TOKEN_INDEX
            input_ids.append(input_id)
            targets.append(target)

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        targets = torch.tensor(targets, dtype=torch.long)

        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def to_llava_format(self, data):
        img_pil = data['img']
        messages = data['messages']
        text_only = data['text_only']
        is_video = False
        if 'is_video' in data:
            is_video = data['is_video']

        # Close the conversation with an empty assistant turn and convert the
        # chat messages into llava-style "conversations".
        messages.append({'role': 'assistant', 'content': ''})
        conversations = []
        for i, m in enumerate(messages):
            if m['role'] == 'user':
                value = str(m['content']).replace('<image>', '')
                if i == 0 and not text_only:
                    value = '<image>\n' + value

                conversations.append({'from': 'human', 'value': value})
            elif m['role'] == 'assistant':
                conversations.append({'from': 'gpt', 'value': str(m['content']).replace('<image>', '')})
            else:
                raise ValueError(f"Wrong role in conversation: {m['role']}")

        return {'image': img_pil,
                'text_only': text_only,
                'is_video': is_video,
                'conversations': conversations}
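

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module). `chat()`
# relies on `self.tokenizer` and `self.im_trans` (an image processor) being
# attached by the caller, since `__init__` does not create them, and it moves
# images to CUDA. The checkpoint path, the SiglipImageProcessor choice, and
# loading via AutoModel (which presumes the checkpoint's auto_map registers
# this class) are illustrative placeholders, not values confirmed here.
if __name__ == "__main__":
    from transformers import AutoModel, SiglipImageProcessor

    ckpt = "/path/to/FlashVL-checkpoint"  # hypothetical local checkpoint
    model = AutoModel.from_pretrained(
        ckpt, trust_remote_code=True, torch_dtype=torch.bfloat16
    ).cuda().eval()
    model.tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
    model.im_trans = SiglipImageProcessor.from_pretrained(ckpt)  # assumed processor

    img = Image.open("example.jpg").convert("RGB")
    print(model.chat(img, [{"role": "user", "content": "Describe the image."}]))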