import json
import os
import shutil
import threading
import time
from pathlib import Path

import git
import gradio as gr
from datasets import load_dataset
from huggingface_hub import CommitScheduler, HfApi

from utils import process_and_push_dataset  # used by the (currently disabled) background processor

api = HfApi(token=os.environ["HF_TOKEN"])

VALID_DATASET = load_dataset("taesiri/IERv2-Subset", split="train")
VALID_DATASET_POST_IDS = (
    load_dataset("taesiri/IERv2-Subset", split="train", columns=["post_id"])
    .to_pandas()["post_id"]
    .tolist()
)
POST_ID_TO_ID_MAP = {post_id: idx for idx, post_id in enumerate(VALID_DATASET_POST_IDS)}

DATASET_REPO = "taesiri/AIImageEditingResults_Intemediate"
FINAL_DATASET_REPO = "taesiri/AIImageEditingResults"
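
# Data flow: contributors' submissions are written under ./data, which the
# CommitScheduler below pushes to the intermediate repo (DATASET_REPO).
# A background thread is meant to periodically flatten ./data into the
# final repo (FINAL_DATASET_REPO); that call is currently disabled (see
# process_thread at the bottom of this file).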
# Download existing data from the hub
def sync_with_hub():
    """
    Synchronize local data with the hub by cloning the dataset repo and
    merging hub-only submission folders into ./data. Existing local
    folders are left untouched (local data wins on conflict).
    """
    print("Starting sync with hub...")
    data_dir = Path("./data")
    if data_dir.exists():
        # Back up existing data before merging
        backup_dir = Path("./data_backup")
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    # Clone/pull the latest data from the hub
    repo_url = f"https://huggingface.co/datasets/{DATASET_REPO}"
    hub_data_dir = Path("hub_data")
    if hub_data_dir.exists():
        # The repo is already cloned, so just pull
        print("Pulling latest changes...")
        repo = git.Repo(hub_data_dir)
        repo.remotes.origin.pull()
    else:
        print("Cloning repository...")
        git.Repo.clone_from(repo_url, hub_data_dir)

    # Merge hub data into local data
    hub_data_source = hub_data_dir / "data"
    if hub_data_source.exists():
        data_dir.mkdir(exist_ok=True)
        for item in hub_data_source.glob("*"):
            if item.is_dir():
                dest = data_dir / item.name
                if not dest.exists():  # Only copy folders that don't exist locally
                    shutil.copytree(item, dest)

    # Clean up the cloned repo
    if hub_data_dir.exists():
        shutil.rmtree(hub_data_dir)
    print("Finished syncing with hub!")

# Push the contents of ./data to the intermediate repo every minute
scheduler = CommitScheduler(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,  # minutes
)

def load_question_data(question_id):
    """
    Load a specific question's data from ./data/<question_id>/question.json.
    Returns a list of 11 form fields: categories, question, final answer,
    rationale text, up to four question images, up to two rationale images,
    and the question id.
    """
    if not question_id:
        return [None] * 11

    # Extract the ID part before the colon from the dropdown selection
    question_id = (
        question_id.split(":")[0].strip() if ":" in question_id else question_id
    )

    json_path = os.path.join("./data", question_id, "question.json")
    if not os.path.exists(json_path):
        print(f"Question file not found: {json_path}")
        return [None] * 11

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.loads(f.read().strip())

        def load_image(image_path):
            # Resolve an image path relative to the question folder
            if not image_path:
                return None
            full_path = os.path.join(
                "./data", question_id, os.path.basename(image_path)
            )
            return full_path if os.path.exists(full_path) else None

        question_images = data.get("question_images", [])
        rationale_images = data.get("rationale_images", [])

        return [
            (
                ",".join(data["question_categories"])
                if isinstance(data["question_categories"], list)
                else data["question_categories"]
            ),
            data["question"],
            data["final_answer"],
            data.get("rationale_text", ""),
            load_image(question_images[0] if question_images else None),
            load_image(question_images[1] if len(question_images) > 1 else None),
            load_image(question_images[2] if len(question_images) > 2 else None),
            load_image(question_images[3] if len(question_images) > 3 else None),
            load_image(rationale_images[0] if rationale_images else None),
            load_image(rationale_images[1] if len(rationale_images) > 1 else None),
            question_id,
        ]
    except Exception as e:
        print(f"Error loading question {question_id}: {str(e)}")
        return [None] * 11
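
# Expected question.json layout, inferred from the reads above (a sketch,
# not a guaranteed schema):
#
# {
#     "question_categories": ["category-a", "category-b"],
#     "question": "…",
#     "final_answer": "…",
#     "rationale_text": "…",
#     "question_images": ["q1.png", "q2.png"],
#     "rationale_images": ["r1.png"]
# }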

def load_post_image(post_id):
    """
    Load the source image for a post, plus any previously saved responses.
    Returns [source_image] followed by 10 (image, text) pairs, flattened.
    """
    if not post_id:
        return [None] * 21  # source image + 10 pairs of (image, text)

    idx = POST_ID_TO_ID_MAP[post_id]
    source_image = VALID_DATASET[idx]["image"]

    # Load existing responses, if any
    post_folder = os.path.join("./data", str(post_id))
    metadata_path = os.path.join(post_folder, "metadata.json")
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        # Fill in existing responses
        responses = [(None, "")] * 10
        for response in metadata["responses"]:
            response_idx = response["response_id"]
            if response_idx < 10:  # Ensure we don't exceed our UI limit
                image_path = os.path.join(post_folder, response["image_path"])
                responses[response_idx] = (image_path, response["answer_text"])

        # Flatten the (image, text) pairs for the Gradio outputs
        flat_responses = [item for pair in responses for item in pair]
        return [source_image] + flat_responses

    # No existing responses: return the source image and empty slots
    return [source_image] + [None] * 20

def generate_json_files(source_image, responses, post_id):
    """
    Save the source image and multiple responses to the data directory.

    Args:
        source_image: Path to the source image (or an in-memory PIL image)
        responses: List of (image, answer) tuples
        post_id: The post ID from the dataset
    """
    # Create the parent data folder if it doesn't exist
    parent_data_folder = "./data"
    os.makedirs(parent_data_folder, exist_ok=True)

    # Create/clear the post_id folder (resubmitting replaces earlier data)
    post_folder = os.path.join(parent_data_folder, str(post_id))
    if os.path.exists(post_folder):
        shutil.rmtree(post_folder)
    os.makedirs(post_folder)

    # Save the source image
    source_image_path = os.path.join(post_folder, "source_image.png")
    if isinstance(source_image, str):
        shutil.copy2(source_image, source_image_path)
    else:
        # Assume a PIL image and save it directly
        source_image.save(source_image_path)

    # Save each (image, text) response
    responses_data = []
    for idx, (response_image, answer_text) in enumerate(responses):
        if response_image and answer_text:  # Only process if both image and text exist
            response_folder = os.path.join(post_folder, f"response_{idx}")
            os.makedirs(response_folder)

            response_image_path = os.path.join(response_folder, "response_image.png")
            if isinstance(response_image, str):
                shutil.copy2(response_image, response_image_path)
            else:
                response_image.save(response_image_path)

            responses_data.append(
                {
                    "response_id": idx,
                    "answer_text": answer_text,
                    "image_path": f"response_{idx}/response_image.png",
                }
            )

    # Write metadata JSON
    metadata = {
        "post_id": post_id,
        "source_image": "source_image.png",
        "responses": responses_data,
    }
    with open(os.path.join(post_folder, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, ensure_ascii=False, indent=2)

    return post_folder
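
# Resulting on-disk layout for one submission:
#
#     data/<post_id>/
#         source_image.png
#         metadata.json
#         response_0/response_image.png
#         response_1/response_image.png
#         ...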

# Build the Gradio app
with gr.Blocks() as demo:
    gr.Markdown("# Image Response Collector")

    # Source image selection at the top
    with gr.Column():
        post_id_dropdown = gr.Dropdown(
            label="Select Post ID to Load Image",
            choices=VALID_DATASET_POST_IDS,
            type="value",
            allow_custom_value=False,
        )
        source_image = gr.Image(label="Source Image", type="filepath")

    # Responses in tabs
    with gr.Tabs() as response_tabs:
        responses = []
        for i in range(10):
            with gr.Tab(f"Response {i+1}"):
                img = gr.Image(label=f"Response Image {i+1}", type="filepath")
                txt = gr.Textbox(label=f"Model Name {i+1}", lines=2)
                responses.append((img, txt))

    with gr.Row():
        submit_btn = gr.Button("Submit All Responses")
        clear_btn = gr.Button("Clear Form")

    def submit_responses(source_img, post_id, *response_data):
        if not source_img:
            gr.Warning("Please select a source image first!")
            return
        if not post_id:
            gr.Warning("Please select a post ID first!")
            return

        # Convert flat response_data into (image, text) pairs
        response_pairs = list(zip(response_data[::2], response_data[1::2]))

        # Keep only responses that have both an image and text
        valid_responses = [
            (img, txt) for img, txt in response_pairs if img is not None and txt
        ]
        if not valid_responses:
            gr.Warning("Please provide at least one response (image + text)!")
            return

        generate_json_files(source_img, valid_responses, post_id)
        gr.Info("Responses saved successfully! 🎉")

    def clear_form():
        return [None] * (1 + 20)  # 1 source image + 10 pairs of (image, text)

    # Connect components
    post_id_dropdown.change(
        fn=load_post_image,
        inputs=[post_id_dropdown],
        outputs=[source_image] + [comp for pair in responses for comp in pair],
    )

    submit_inputs = [source_image, post_id_dropdown] + [
        comp for pair in responses for comp in pair
    ]
    submit_btn.click(fn=submit_responses, inputs=submit_inputs)

    clear_outputs = [source_image] + [comp for pair in responses for comp in pair]
    clear_btn.click(fn=clear_form, outputs=clear_outputs)
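    # Note: the flat component order here ([source_image, img_1, txt_1, ...])
    # must match the order of values returned by load_post_image and
    # clear_form, since Gradio maps returned values to outputs positionally.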

def process_thread():
    """
    Background loop intended to periodically flatten ./data into the final
    dataset repo. The push itself is currently disabled.
    """
    while True:
        try:
            pass
            # process_and_push_dataset(
            #     "./data",
            #     FINAL_DATASET_REPO,
            #     token=os.environ["HF_TOKEN"],
            #     private=True,
            # )
        except Exception as e:
            print(f"Error in process thread: {e}")
        time.sleep(120)  # Sleep for 2 minutes
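
# To re-enable the final-dataset push, uncomment the process_and_push_dataset
# call above; its implementation lives in utils.py (not shown here), and the
# arguments mirror the commented call: a local data folder, the target repo
# id, an HF token, and a privacy flag.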

if __name__ == "__main__":
    print("Initializing app...")
    sync_with_hub()  # Sync before launching the app

    print("Starting Gradio interface...")
    # Start the background processing thread before launching the UI
    processing_thread = threading.Thread(target=process_thread, daemon=True)
    processing_thread.start()

    demo.launch()