|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import os |
|
import json |
|
from tqdm import tqdm |
|
import ipdb |
|
import random |
|
from torch.nn.utils.rnn import pad_sequence |
|
from dataclasses import dataclass, field |
|
from typing import Callable, Dict, Sequence |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import transformers |
|
from torch.utils.data import Dataset |
|
from tqdm import tqdm |
|
|
|
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads a JSON annotation file containing a list of records, each with an
    "image_name" and a "conversation" field, and exposes (image path, caption)
    pairs for training.
    """

    def __init__(self, data_path: str, image_root_path: str):
        """Load the annotation file and resolve absolute image paths.

        Args:
            data_path: Path to a JSON file holding a list of dicts, each with
                "image_name" and "conversation" keys.
            image_root_path: Directory that contains the image files.
        """
        super(SupervisedDataset, self).__init__()

        # JSON is defined to be UTF-8; pin the encoding explicitly instead of
        # relying on the platform/locale default (which can differ, e.g. on
        # Windows, and break on non-ASCII captions).
        with open(data_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)

        self.image_path_list, self.caption_list = [], []
        for item in json_data:
            one_image_name, one_caption = item["image_name"], item["conversation"]

            # Annotations may omit the file extension; normalize to .jpg.
            # NOTE(review): a name like "foo.png" becomes "foo.png.jpg" here —
            # presumably every image in the dataset is a .jpg; confirm against
            # the actual annotation data.
            if not one_image_name.endswith('.jpg'):
                one_image_name += '.jpg'
            one_image_path = image_root_path + '/{}'.format(one_image_name)
            self.image_path_list.append(one_image_path)
            self.caption_list.append(one_caption)
        print(f'[!] collect {len(self.image_path_list)} samples for training')

    def __len__(self) -> int:
        """Return the number of (image, caption) samples."""
        return len(self.image_path_list)

    def __getitem__(self, i) -> dict:
        """Return the i-th sample as a dict with 'image_paths' and 'output_texts'."""
        return dict(image_paths=self.image_path_list[i], output_texts=self.caption_list[i])

    def collate(self, instances):
        """Batch a list of __getitem__ dicts into a dict of parallel lists.

        Args:
            instances: Sequence of dicts, each with "image_paths" and
                "output_texts" keys (as produced by ``__getitem__``).

        Returns:
            A dict mapping "image_paths" and "output_texts" to lists that
            preserve the input order.
        """
        image_paths, output_texts = tuple(
            [instance[key] for instance in instances]
            for key in ("image_paths", "output_texts")
        )
        return dict(
            image_paths=image_paths,
            output_texts=output_texts
        )
|
|