DekGenerate / app.py
Nattapong Tapachoom
Add data quality management features and update requirements
e7a189a
raw
history blame
36.2 kB
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import json
import io
import csv
from typing import List, Dict
import threading
import time
import queue
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
from data_quality import DataQualityManager, export_to_huggingface_format
# Predefined task templates with Thai language support
TASK_TEMPLATES = {
"text_generation": {
"name": "การสร้างข้อความ (Text Generation)",
"template": "เขียนเรื่องราวสร้างสรรค์เกี่ยวกับ {topic}",
"description": "สร้างข้อความสร้างสรรค์ภาษาไทยจากหัวข้อที่กำหนด"
},
"question_answering": {
"name": "คำถาม-คำตอบ (Question Answering)",
"template": "คำถาม: {question}\nคำตอบ:",
"description": "สร้างคู่คำถาม-คำตอบภาษาไทย"
},
"summarization": {
"name": "การสรุปข้อความ (Text Summarization)",
"template": "สรุปข้อความต่อไปนี้: {text}",
"description": "สร้างตัวอย่างการสรุปข้อความภาษาไทย"
},
"translation": {
"name": "การแปลภาษา (Translation)",
"template": "แปลจาก {source_lang} เป็น {target_lang}: {text}",
"description": "สร้างคู่ข้อมูลสำหรับการแปลภาษา"
},
"classification": {
"name": "การจำแนกข้อความ (Text Classification)",
"template": "จำแนกอารมณ์ของข้อความนี้: {text}\nอารมณ์:",
"description": "สร้างตัวอย่างการจำแนกอารมณ์หรือหมวดหมู่ของข้อความ"
},
"conversation": {
"name": "บทสนทนา (Conversation)",
"template": "มนุษย์: {input}\nผู้ช่วย:",
"description": "สร้างข้อมูลบทสนทนาภาษาไทย"
},
"instruction_following": {
"name": "การทำตามคำสั่ง (Instruction Following)",
"template": "คำสั่ง: {instruction}\nการตอบสนอง:",
"description": "สร้างคู่คำสั่ง-การตอบสนองภาษาไทย"
},
"thai_poetry": {
"name": "กวีนิพนธ์ไทย (Thai Poetry)",
"template": "แต่งกวีนิพนธ์เกี่ยวกับ {topic} ในรูปแบบ {style}",
"description": "สร้างกวีนิพนธ์ไทยในรูปแบบต่างๆ"
},
"thai_news": {
"name": "ข่าวภาษาไทย (Thai News)",
"template": "เขียนข่าวภาษาไทยเกี่ยวกับ {topic} ในหัวข้อ {category}",
"description": "สร้างข้อความข่าวภาษาไทยในหมวดหมู่ต่างๆ"
}
}
# Thai language models from Hugging Face
THAI_MODELS = {
"typhoon-7b": {
"name": "🌪️ Typhoon-7B (SCB10X)",
"model_id": "scb10x/typhoon-7b",
"description": "โมเดลภาษาไทยขนาด 7B พารามิเตอร์ ประสิทธิภาพสูง"
},
"openthaigpt": {
"name": "🇹🇭 OpenThaiGPT 1.5-7B",
"model_id": "openthaigpt/openthaigpt1.5-7b-instruct",
"description": "โมเดลภาษาไทยรองรับคำสั่งและบทสนทนาหลายรอบ"
},
"wangchanlion": {
"name": "🦁 Gemma2-9B WangchanLION",
"model_id": "aisingapore/Gemma2-9b-WangchanLIONv2-instruct",
"description": "โมเดลขนาด 9B รองรับไทย-อังกฤษ พัฒนาโดย AI Singapore"
},
"sambalingo": {
"name": "🌍 SambaLingo-Thai-Base",
"model_id": "sambanovasystems/SambaLingo-Thai-Base",
"description": "โมเดลภาษาไทยพื้นฐาน รองรับทั้งไทยและอังกฤษ"
},
"other": {
"name": "🔧 โมเดลอื่นๆ (Custom)",
"model_id": "custom",
"description": "ระบุชื่อโมเดลที่ต้องการใช้งานเอง"
}
}
def load_file_data(file_path: str) -> List[Dict]:
"""Load data from uploaded file"""
try:
if file_path.endswith('.csv'):
df = pd.read_csv(file_path)
return df.to_dict('records')
elif file_path.endswith('.json'):
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
elif file_path.endswith('.txt'):
with open(file_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
return [{'text': line.strip()} for line in lines if line.strip()]
else:
raise ValueError("Unsupported file format. Use CSV, JSON, or TXT files.")
except Exception as e:
raise Exception(f"Error reading file: {str(e)}")
def generate_from_template(template: str, data_row: Dict) -> str:
"""Generate prompt from template and data"""
try:
return template.format(**data_row)
except KeyError as e:
return f"Template error: Missing field {e}"
def load_model(model_name):
"""Load a Hugging Face model for text generation"""
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
return generator, None
except Exception as e:
return None, str(e)
def generate_dataset(model_name, prompt_template, num_samples, max_length, temperature, top_p):
"""Generate dataset using Hugging Face model"""
try:
generator, error = load_model(model_name)
if error:
return None, f"Error loading model: {error}"
dataset = []
for i in range(num_samples):
# Generate text
generated = generator(
prompt_template,
max_length=max_length,
temperature=temperature,
top_p=top_p,
num_return_sequences=1,
do_sample=True
)
generated_text = generated[0]['generated_text']
dataset.append({
'id': i + 1,
'prompt': prompt_template,
'generated_text': generated_text,
'full_text': generated_text
})
# Convert to DataFrame for display
df = pd.DataFrame(dataset)
# Create downloadable files
csv_data = df.to_csv(index=False)
json_data = json.dumps(dataset, indent=2, ensure_ascii=False)
return df, csv_data, json_data, None
except Exception as e:
return None, None, None, f"Error generating dataset: {str(e)}"
def generate_dataset_from_task(model_name, task_type, custom_template, file_data, num_samples, max_length, temperature, top_p):
"""Generate dataset using task templates or file input"""
try:
generator, error = load_model(model_name)
if error:
return None, f"Error loading model: {error}"
dataset = []
# Determine the template to use
if custom_template and custom_template.strip():
template = custom_template
elif task_type in TASK_TEMPLATES:
template = TASK_TEMPLATES[task_type]["template"]
else:
template = "Generate text: {input}"
# Generate samples
for i in range(num_samples):
if file_data and len(file_data) > 0:
# Use file data cyclically
data_row = file_data[i % len(file_data)]
prompt = generate_from_template(template, data_row)
else:
# Use template with placeholder values
prompt = template.replace("{topic}", "artificial intelligence") \
.replace("{question}", "What is machine learning?") \
.replace("{text}", "Sample text for processing") \
.replace("{input}", f"Sample input {i+1}") \
.replace("{instruction}", f"Complete this task {i+1}")
# Generate text
generated = generator(
prompt,
max_length=max_length,
temperature=temperature,
top_p=top_p,
num_return_sequences=1,
do_sample=True,
pad_token_id=generator.tokenizer.eos_token_id
)
generated_text = generated[0]['generated_text']
dataset.append({
'id': i + 1,
'task_type': task_type,
'prompt': prompt,
'generated_text': generated_text,
'original_data': data_row if file_data else None
})
# Convert to DataFrame for display
df = pd.DataFrame(dataset)
# Create downloadable files
csv_data = df.to_csv(index=False)
json_data = json.dumps(dataset, indent=2, ensure_ascii=False)
return df, csv_data, json_data, None
except Exception as e:
return None, None, None, f"Error generating dataset: {str(e)}"
# Multi-model generation status tracking
class ModelStatus:
def __init__(self):
self.models = {}
self.record_status = {} # record_id: {"status": "pending/processing/completed", "model": "model_name"}
self.completed_records = []
self.lock = threading.Lock()
def set_record_processing(self, record_id: int, model_name: str):
with self.lock:
self.record_status[record_id] = {"status": "processing", "model": model_name}
def set_record_completed(self, record_id: int, result: dict):
with self.lock:
self.record_status[record_id]["status"] = "completed"
self.completed_records.append(result)
def get_next_available_record(self, total_records: int, model_name: str) -> int:
with self.lock:
for i in range(total_records):
if i not in self.record_status or self.record_status[i]["status"] == "pending":
self.record_status[i] = {"status": "pending", "model": model_name}
return i
return -1 # No available records
def get_progress(self, total_records: int) -> dict:
with self.lock:
completed = len([r for r in self.record_status.values() if r["status"] == "completed"])
processing = len([r for r in self.record_status.values() if r["status"] == "processing"])
return {
"completed": completed,
"processing": processing,
"total": total_records,
"percentage": (completed / total_records * 100) if total_records > 0 else 0
}
def load_model_with_cache(model_name: str, cache: dict):
"""Load model with caching to avoid reloading"""
if model_name in cache:
return cache[model_name], None
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
cache[model_name] = generator
return generator, None
except Exception as e:
return None, str(e)
def generate_single_record(generator, prompt: str, record_id: int, model_name: str,
max_length: int, temperature: float, top_p: float,
task_type: str, original_data: dict, status_tracker: ModelStatus):
"""Generate a single record with the given model"""
try:
# Mark record as processing
status_tracker.set_record_processing(record_id, model_name)
# Generate text
generated = generator(
prompt,
max_length=max_length,
temperature=temperature,
top_p=top_p,
num_return_sequences=1,
do_sample=True,
pad_token_id=generator.tokenizer.eos_token_id if hasattr(generator.tokenizer, 'eos_token_id') else generator.tokenizer.pad_token_id
)
generated_text = generated[0]['generated_text']
result = {
'id': record_id + 1,
'model_used': model_name,
'task_type': task_type,
'prompt': prompt,
'generated_text': generated_text,
'original_data': original_data,
'generation_time': time.time()
}
# Mark record as completed
status_tracker.set_record_completed(record_id, result)
return result
except Exception as e:
# If generation fails, mark as pending again for other models to try
with status_tracker.lock:
if record_id in status_tracker.record_status:
status_tracker.record_status[record_id]["status"] = "pending"
return None
def model_worker(model_name: str, model_cache: dict, prompts: List[str],
task_type: str, original_data_list: List[dict],
max_length: int, temperature: float, top_p: float,
status_tracker: ModelStatus, progress_callback=None):
"""Worker function for each model to process available records"""
# Load model
generator, error = load_model_with_cache(model_name, model_cache)
if error:
return f"Error loading {model_name}: {error}"
total_records = len(prompts)
processed_count = 0
while True:
# Get next available record
record_id = status_tracker.get_next_available_record(total_records, model_name)
if record_id == -1: # No more records available
break
# Generate record
prompt = prompts[record_id]
original_data = original_data_list[record_id] if original_data_list else None
result = generate_single_record(
generator, prompt, record_id, model_name,
max_length, temperature, top_p, task_type,
original_data, status_tracker
)
if result:
processed_count += 1
# Update progress
if progress_callback:
progress = status_tracker.get_progress(total_records)
progress_callback(progress, model_name, processed_count)
return f"{model_name}: Processed {processed_count} records"
def generate_dataset_multi_model(selected_models: List[str], task_type: str, custom_template: str,
file_data: List[dict], num_samples: int, max_length: int,
temperature: float, top_p: float, progress_callback=None):
"""Generate dataset using multiple models collaboratively"""
try:
# Prepare prompts
prompts = []
original_data_list = []
# Determine template
if custom_template and custom_template.strip():
template = custom_template
elif task_type in TASK_TEMPLATES:
template = TASK_TEMPLATES[task_type]["template"]
else:
template = "Generate text: {input}"
# Generate prompts for all records
for i in range(num_samples):
if file_data and len(file_data) > 0:
data_row = file_data[i % len(file_data)]
prompt = generate_from_template(template, data_row)
original_data_list.append(data_row)
else:
# Use template with placeholder values
prompt = template.replace("{topic}", f"หัวข้อที่ {i+1}") \
.replace("{question}", f"คำถามที่ {i+1} เกี่ยวกับการเรียนรู้ของเครื่อง") \
.replace("{text}", f"ข้อความตัวอย่างที่ {i+1} สำหรับการประมวลผล") \
.replace("{input}", f"ข้อมูลนำเข้าที่ {i+1}") \
.replace("{instruction}", f"คำสั่งที่ {i+1}: ให้ทำงานนี้") \
.replace("{category}", "เทคโนโลยี") \
.replace("{style}", "โคลงสี่สุภาพ")
original_data_list.append(None)
prompts.append(prompt)
# Initialize status tracker
status_tracker = ModelStatus()
model_cache = {}
# Start worker threads for each model
with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
futures = []
for model_name in selected_models:
future = executor.submit(
model_worker, model_name, model_cache, prompts,
task_type, original_data_list, max_length,
temperature, top_p, status_tracker, progress_callback
)
futures.append((future, model_name))
# Wait for all workers to complete
for future, model_name in futures:
try:
result = future.result(timeout=300) # 5 minute timeout per model
print(f"Model {model_name} completed: {result}")
except Exception as e:
print(f"Model {model_name} failed: {str(e)}")
# Collect results
dataset = sorted(status_tracker.completed_records, key=lambda x: x['id'])
if not dataset:
return None, None, None, "ไม่สามารถสร้างข้อมูลได้"
# Convert to DataFrame
df = pd.DataFrame(dataset)
# Create downloadable files
csv_data = df.to_csv(index=False)
json_data = json.dumps(dataset, indent=2, ensure_ascii=False)
return df, csv_data, json_data, None
except Exception as e:
return None, None, None, f"Error in multi-model generation: {str(e)}"
def create_interface():
with gr.Blocks(title="🇹🇭 Thai Dataset Generator with Hugging Face", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🤗 เครื่องมือสร้างชุดข้อมูลภาษาไทยคุณภาพสูง")
gr.Markdown("สร้างชุดข้อมูลภาษาไทยคุณภาพสูง สะอาด และเป็นสากลด้วยโมเดลหลายตัว")
with gr.Row():
with gr.Column():
# Multi-model selection
gr.Markdown("### 🤖 เลือกโมเดลภาษาไทย (หลายตัว)")
model_checkboxes = gr.CheckboxGroup(
choices=[
("🌪️ Typhoon-7B (SCB10X)", "scb10x/typhoon-7b"),
("🇹🇭 OpenThaiGPT 1.5-7B", "openthaigpt/openthaigpt1.5-7b-instruct"),
("🦁 Gemma2-9B WangchanLION", "aisingapore/Gemma2-9b-WangchanLIONv2-instruct"),
("🌍 SambaLingo-Thai-Base", "sambanovasystems/SambaLingo-Thai-Base")
],
value=["scb10x/typhoon-7b"],
label="เลือกโมเดลที่ต้องการใช้งาน (สามารถเลือกหลายตัว)"
)
gr.Markdown("### 📊 โหมดการทำงาน")
work_mode = gr.Radio(
choices=[
("🔄 แบ่งงานกัน (Multi-Model Collaboration)", "collaborative"),
("📝 ใช้โมเดลเดียว (Single Model)", "single")
],
value="collaborative",
label="เลือกโหมดการทำงาน"
)
# Task selection with Thai tasks
gr.Markdown("### 📝 เลือกประเภทงาน")
task_dropdown = gr.Dropdown(
choices=[(v["name"], k) for k, v in TASK_TEMPLATES.items()],
value="text_generation",
label="ประเภทงานที่ต้องการ"
)
task_description = gr.Textbox(
label="คำอธิบายงาน",
value=TASK_TEMPLATES["text_generation"]["description"],
interactive=False
)
# File upload section
gr.Markdown("### 📁 อัปโหลดข้อมูลต้นฉบับ (ไม่บังคับ)")
gr.Markdown("อัปโหลดไฟล์ CSV, JSON หรือ TXT ที่มีข้อมูลต้นฉบับภาษาไทย")
file_upload = gr.File(
label="อัปโหลดไฟล์ข้อมูล",
file_types=[".csv", ".json", ".txt"]
)
file_preview = gr.Dataframe(
label="ตัวอย่างข้อมูลจากไฟล์",
visible=False,
max_rows=5
)
# Template customization
gr.Markdown("### 🎯 ปรับแต่งเทมเพลต")
gr.Markdown("ใช้ {ชื่อฟิลด์} สำหรับตัวแปรในเทมเพลต")
template_display = gr.Textbox(
label="เทมเพลตปัจจุบัน",
value=TASK_TEMPLATES["text_generation"]["template"],
interactive=False
)
custom_template = gr.Textbox(
label="เทมเพลตกำหนดเอง (ไม่บังคับ)",
lines=3,
placeholder="สร้างเทมเพลตของคุณเองที่นี่..."
)
# Generation parameters
gr.Markdown("### ⚙️ ตั้งค่าการสร้างข้อมูล")
with gr.Row():
num_samples = gr.Slider(
minimum=1,
maximum=100,
value=10,
step=1,
label="จำนวนข้อมูลที่ต้องการ"
)
max_length = gr.Slider(
minimum=10,
maximum=1000,
value=200,
step=10,
label="ความยาวสูงสุด (โทเคน)"
)
with gr.Row():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.8,
step=0.1,
label="ความคิดสร้างสรรค์ (Temperature)"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="ความหลากหลาย (Top-p)"
)
# Data Quality Settings
gr.Markdown("### 🧼 การจัดการคุณภาพข้อมูล")
enable_cleaning = gr.Checkbox(
label="เปิดใช้การทำความสะอาดข้อมูล",
value=True
)
remove_duplicates = gr.Checkbox(
label="ลบข้อมูลซ้ำซ้อน",
value=True
)
min_quality_score = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.5,
step=0.1,
label="คะแนนคุณภาพขั้นต่ำ (0-1)"
)
# Export Settings
gr.Markdown("### 📦 การส่งออกข้อมูล")
create_splits = gr.Checkbox(
label="แบ่งข้อมูล Train/Validation/Test",
value=True
)
export_format = gr.Radio(
choices=[
("📊 CSV + JSON (พื้นฐาน)", "standard"),
("🤗 Hugging Face Dataset (มาตรฐานสากล)", "huggingface"),
("📋 JSONL (สำหรับ Fine-tuning)", "jsonl")
],
value="huggingface",
label="รูปแบบการส่งออก"
)
generate_btn = gr.Button("🚀 สร้างชุดข้อมูลแบบทีมเวิร์ก", variant="primary", size="lg")
with gr.Column():
with gr.Tabs():
with gr.TabItem("📊 ตัวอย่างข้อมูล"):
dataset_preview = gr.Dataframe(
headers=["id", "task_type", "input", "output", "quality_score"],
interactive=False
)
with gr.TabItem("📈 รายงานคุณภาพ"):
quality_report = gr.JSON(
label="รายงานคุณภาพข้อมูล",
visible=True
)
quality_summary = gr.Markdown(
value="สร้างข้อมูลเสร็จแล้วจึงจะแสดงรายงานคุณภาพ"
)
with gr.TabItem("💾 ดาวน์โหลด"):
gr.Markdown("### 💾 ดาวน์โหลดชุดข้อมูลคุณภาพสูง")
download_info = gr.Markdown("สร้างข้อมูลเสร็จแล้วจึงจะสามารถดาวน์โหลดได้")
with gr.Row():
csv_btn = gr.Button("📄 ดาวน์โหลด CSV", variant="secondary")
json_btn = gr.Button("📋 ดาวน์โหลด JSON", variant="secondary")
hf_btn = gr.Button("🤗 ดาวน์โหลด HF Dataset", variant="secondary")
card_btn = gr.Button("📖 ดาวน์โหลด Dataset Card", variant="secondary")
csv_download = gr.File(
label="ไฟล์ CSV",
visible=False
)
json_download = gr.File(
label="ไฟล์ JSON",
visible=False
)
dataset_card_download = gr.File(
label="Dataset Card (README.md)",
visible=False
)
hf_dataset_download = gr.File(
label="Hugging Face Dataset",
visible=False
)
with gr.TabItem("📖 คู่มือการใช้งาน"):
gr.Markdown("""
## 📖 คู่มือการใช้งาน
### 🤖 เลือกโมเดล
1. **Typhoon-7B**: เหมาะสำหรับงานทั่วไป ประสิทธิภาพสูง
2. **OpenThaiGPT**: เหมาะสำหรับบทสนทนาและการทำตามคำสั่ง
3. **WangchanLION**: รองรับทั้งไทย-อังกฤษ
4. **SambaLingo**: โมเดลพื้นฐานที่เสถียร
### 📝 ประเภทงาน
- **การสร้างข้อความ**: สร้างเรื่องราว บทความ
- **คำถาม-คำตอบ**: สร้างคู่ Q&A
- **การสรุป**: สรุปข้อความยาว
- **บทสนทนา**: สร้างข้อมูลแชทบอท
- **กวีนิพนธ์**: สร้างบทกวีไทย
### 📁 การใช้ไฟล์ข้อมูล
- **CSV**: ต้องมีคอลัมน์ที่ตรงกับตัวแปรในเทมเพลต
- **JSON**: อาร์เรย์ของออบเจ็กต์
- **TXT**: แต่ละบรรทัดเป็นข้อมูลหนึ่งชิ้น
### ⚙️ พารามิเตอร์
- **Temperature**: 0.1-2.0 (ต่ำ=เสถียร, สูง=สร้างสรรค์)
- **Top-p**: 0.1-1.0 (ความหลากหลายของคำตอบ)
""")
status_message = gr.Textbox(
label="สถานะ",
visible=False,
lines=3
)
# Store data states
csv_data_state = gr.State()
json_data_state = gr.State()
file_data_state = gr.State([])
dataset_card_state = gr.State()
quality_report_state = gr.State()
def update_model_info(model_key):
if model_key in THAI_MODELS:
model_info = THAI_MODELS[model_key]
if model_key == "other":
return model_info["description"], ""
return model_info["description"], model_info["model_id"]
return "", ""
def update_task_info(task_type):
if task_type in TASK_TEMPLATES:
return (
TASK_TEMPLATES[task_type]["description"],
TASK_TEMPLATES[task_type]["template"]
)
return "", ""
def process_file(file):
if file is None:
return gr.update(visible=False), []
try:
data = load_file_data(file.name)
df = pd.DataFrame(data[:5])
return gr.update(visible=True, value=df), data
except Exception as e:
return gr.update(visible=False), []
def on_generate(model_name, task_type, custom_template, file_data, num_samples, max_length, temperature, top_p):
if not model_name.strip():
return (
gr.update(visible=False),
gr.update(visible=True, value="❌ กรุณาระบุชื่อโมเดล"),
None, None
)
df, csv_data, json_data, error = generate_dataset_from_task(
model_name, task_type, custom_template, file_data,
num_samples, max_length, temperature, top_p
)
if error:
return (
gr.update(visible=False),
gr.update(visible=True, value=f"❌ เกิดข้อผิดพลาด: {error}"),
csv_data, json_data
)
else:
success_msg = f"✅ สร้างข้อมูลสำเร็จ! ได้ {len(df)} รายการ"
return (
gr.update(visible=True, value=df),
gr.update(visible=True, value=success_msg),
csv_data, json_data
)
def download_csv(csv_data):
if csv_data:
return gr.update(visible=True, value=io.StringIO(csv_data))
return gr.update(visible=False)
def download_json(json_data):
if json_data:
return gr.update(visible=True, value=io.StringIO(json_data))
return gr.update(visible=False)
# Event connections
model_checkboxes.change(
fn=lambda *model_keys: ", ".join([THAI_MODELS[k]["name"] for k in model_keys]),
inputs=[model_checkboxes],
outputs=[status_message]
)
model_checkboxes.change(
fn=lambda *model_keys: [THAI_MODELS[k]["description"] for k in model_keys],
inputs=[model_checkboxes],
outputs=[model_description]
)
task_dropdown.change(
fn=update_task_info,
inputs=[task_dropdown],
outputs=[task_description, template_display]
)
file_upload.change(
fn=process_file,
inputs=[file_upload],
outputs=[file_preview, file_data_state]
)
generate_btn.click(
fn=on_generate_multi,
inputs=[model_checkboxes, work_mode, task_dropdown, custom_template, file_data_state,
num_samples, max_length, temperature, top_p],
outputs=[dataset_preview, status_message, csv_data_state, json_data_state]
)
csv_btn.click(
fn=download_csv,
inputs=[csv_data_state],
outputs=[csv_download]
)
json_btn.click(
fn=download_json,
inputs=[json_data_state],
outputs=[json_download]
)
return demo
demo = create_interface()
demo.launch()