GPT2-PBE / app.py
tymbos's picture
Update app.py
72577d1 verified
# -*- coding: utf-8 -*-
import os
import gc
import gradio as gr
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
from io import BytesIO
import traceback
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "./tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000
DEFAULT_CHUNK_SIZE = 200000
BATCH_SIZE = 1000
NUM_WORKERS = 4
# Παγκόσμια μεταβλητή ελέγχου
STOP_COLLECTION = False
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων με ομαδοποίηση."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
batch = "\n".join(texts) + "\n"
f.write(batch)
def create_iterator(dataset_name, configs, split):
"""Βελτιωμένο iterator με batch φόρτωση και caching."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(
dataset_name,
name=config,
split=split,
streaming=True,
cache_dir="./dataset_cache"
)
while True:
batch = list(dataset.take(BATCH_SIZE))
if not batch:
break
dataset = dataset.skip(BATCH_SIZE)
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
processed_texts = list(executor.map(process_example, batch))
yield from filter(None, processed_texts)
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης {config}: {e}")
def process_example(example):
"""Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
try:
text = example.get('text', '').strip()
if text and detect(text) in ['el', 'en']:
return text
return None
except:
return None
def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
"""Συλλογή δεδομένων με streaming και checkpoints."""
global STOP_COLLECTION
STOP_COLLECTION = False
total_processed = len(load_checkpoint())
progress_messages = [f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}"]
dataset_iterator = create_iterator(dataset_name, configs, split)
chunk = []
while not STOP_COLLECTION and total_processed < max_samples:
try:
while len(chunk) < chunk_size:
text = next(dataset_iterator)
if text:
chunk.append(text)
total_processed += 1
if total_processed >= max_samples:
break
if chunk:
append_to_checkpoint(chunk)
progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})")
chunk = []
gc.collect()
except StopIteration:
progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
break
except Exception as e:
progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
break
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Εκπαίδευση του tokenizer και έλεγχος ποιότητας."""
messages = ["🚀 Εκκίνηση εκπαίδευσης..."]
try:
all_texts = load_checkpoint()
messages.append("📚 Φόρτωση δεδομένων από checkpoint...")
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS)
messages.append("✅ Εκπαίδευση ολοκληρώθηκε!")
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
fig, ax = plt.subplots()
ax.hist([len(t) for t in encoded.tokens], bins=20)
ax.set_xlabel('Μήκος Token')
ax.set_ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return ("\n".join(messages), decoded, Image.open(img_buffer))
except Exception as e:
messages.append(f"❌ Σφάλμα: {str(e)}")
return ("\n".join(messages), "", None)
def analyze_checkpoint():
"""Ανάλυση δεδομένων από το checkpoint."""
messages = ["🔍 Έναρξη ανάλυσης..."]
try:
texts = load_checkpoint()
if not texts:
return "Δεν βρέθηκαν δεδομένα για ανάλυση."
total_chars = sum(len(t) for t in texts)
avg_length = total_chars / len(texts) if texts else 0
languages = {}
for t in texts[:1000]:
if len(t) > 20:
try:
lang = detect(t)
languages[lang] = languages.get(lang, 0) + 1
except Exception as e:
print(f"⚠️ Σφάλμα ανίχνευσης γλώσσας: {e}")
report = [
f"📊 Σύνολο δειγμάτων: {len(texts)}",
f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
"🌍 Γλώσσες (δείγμα 1000):",
*[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
]
return "\n".join(messages + report)
except Exception as e:
messages.append(f"❌ Σφάλμα: {str(e)}")
return "\n".join(messages)
def restart_collection():
"""Διαγραφή checkpoint και επανεκκίνηση."""
global STOP_COLLECTION
STOP_COLLECTION = False
if os.path.exists(CHECKPOINT_FILE):
os.remove(CHECKPOINT_FILE)
return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή."
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Custom Tokenizer Trainer για GPT-2")
with gr.Row():
with gr.Column(scale=2):
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
split = gr.Dropdown(["train"], value="train", label="Split")
chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Μέγεθος Λεξιλογίου")
min_freq = gr.Slider(1, 10, value=3, label="Ελάχιστη Συχνότητα")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Μέγιστα Δείγματα")
with gr.Row():
start_btn = gr.Button("Start", variant="primary")
stop_btn = gr.Button("Stop", variant="stop")
restart_btn = gr.Button("Restart")
analyze_btn = gr.Button("Analyze Data")
train_btn = gr.Button("Train Tokenizer", variant="primary")
with gr.Column(scale=3):
progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
gr.Markdown("### Αποτελέσματα")
decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο")
token_distribution = gr.Image(label="Κατανομή Tokens")
# Event handlers
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
stop_btn.click(lambda: globals().update(STOP_COLLECTION=True) or "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
restart_btn.click(restart_collection, None, progress)
analyze_btn.click(analyze_checkpoint, None, progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
demo.queue().launch()