import os
import gc
import gradio as gr
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
from io import BytesIO
import traceback
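
# langdetect is stochastic by default; fixing the seed makes language
# detection deterministic across runs.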
DetectorFactory.seed = 0

CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "./tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000
DEFAULT_CHUNK_SIZE = 200000
BATCH_SIZE = 1000
NUM_WORKERS = 4
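
# Module-level flag set by the Stop button handler; collect_samples checks it
# between batches so a running collection can be interrupted.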
STOP_COLLECTION = False
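

# The checkpoint file is plain UTF-8 text with one sample per line, e.g.:
#   append_to_checkpoint(["πρώτο δείγμα", "second sample"])
#   load_checkpoint()  # -> ["πρώτο δείγμα", "second sample"]
# Texts containing embedded newlines would be split into separate samples on
# reload, because load_checkpoint() uses splitlines().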
def load_checkpoint():
    """Load already-collected texts from the checkpoint file."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []


def append_to_checkpoint(texts):
    """Append a batch of collected texts to the checkpoint file, one per line."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        batch = "\n".join(texts) + "\n"
        f.write(batch)


def create_iterator(dataset_name, configs, split):
    """Stream each requested dataset config and yield filtered texts, processed in batches."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            dataset = load_dataset(
                dataset_name,
                name=config,
                split=split,
                streaming=True,
                cache_dir="./dataset_cache"
            )
            while True:
                batch = list(dataset.take(BATCH_SIZE))
                if not batch:
                    break
                dataset = dataset.skip(BATCH_SIZE)
                with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
                    processed_texts = list(executor.map(process_example, batch))
                yield from filter(None, processed_texts)
        except Exception as e:
            print(f"⚠️ Σφάλμα φόρτωσης {config}: {e}")
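
# Note: with a streaming dataset, each dataset.take(BATCH_SIZE) opens a fresh
# pass over the stream, and the compounding dataset.skip() calls mean that, in
# effect, every new batch re-reads and discards all previously seen examples.
# This is functionally correct but gets slower as collection progresses;
# iterating the stream once and batching manually would avoid the repeated skips.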


def process_example(example):
    """Keep an example's text only if it is detected as Greek or English."""
    try:
        text = example.get('text', '').strip()
        if text and detect(text) in ['el', 'en']:
            return text
        return None
    except Exception:
        return None
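
# langdetect raises LangDetectException on empty or featureless strings, so
# detection failures simply drop the example rather than aborting the batch.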


def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
    """Collect samples from the streaming dataset, checkpointing one chunk at a time."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    total_processed = len(load_checkpoint())
    progress_messages = [f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}"]
    dataset_iterator = create_iterator(dataset_name, configs, split)
    chunk = []
    while not STOP_COLLECTION and total_processed < max_samples:
        try:
            while len(chunk) < chunk_size and not STOP_COLLECTION:
                text = next(dataset_iterator)
                if text:
                    chunk.append(text)
                    total_processed += 1
                    if total_processed >= max_samples:
                        break
            if chunk:
                append_to_checkpoint(chunk)
                progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})")
                chunk = []
                gc.collect()
        except StopIteration:
            if chunk:
                append_to_checkpoint(chunk)
                progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})")
            progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
            break
        except Exception as e:
            progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
            break
    return "\n".join(progress_messages)


def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """Train the tokenizer on the checkpointed texts and run a quick sanity check."""
    messages = ["🚀 Εκκίνηση εκπαίδευσης..."]
    try:
        messages.append("📚 Φόρτωση δεδομένων από checkpoint...")
        all_texts = load_checkpoint()
        train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS)
        messages.append("✅ Εκπαίδευση ολοκληρώθηκε!")
        trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
        encoded = trained_tokenizer.encode(test_text)
        decoded = trained_tokenizer.decode(encoded.ids)
        fig, ax = plt.subplots()
        ax.hist([len(token) for token in encoded.tokens], bins=20)
        ax.set_xlabel('Μήκος Token')
        ax.set_ylabel('Συχνότητα')
        img_buffer = BytesIO()
        fig.savefig(img_buffer, format='png')
        plt.close(fig)
        img_buffer.seek(0)
        return ("\n".join(messages), decoded, Image.open(img_buffer))
    except Exception as e:
        messages.append(f"❌ Σφάλμα: {str(e)}")
        return ("\n".join(messages), "", None)


def analyze_checkpoint():
    """Report basic statistics about the texts stored in the checkpoint."""
    messages = ["🔍 Έναρξη ανάλυσης..."]
    try:
        texts = load_checkpoint()
        if not texts:
            return "Δεν βρέθηκαν δεδομένα για ανάλυση."
        total_chars = sum(len(t) for t in texts)
        avg_length = total_chars / len(texts)
        sample = [t for t in texts[:1000] if len(t) > 20]
        languages = {}
        for t in sample:
            try:
                lang = detect(t)
                languages[lang] = languages.get(lang, 0) + 1
            except Exception as e:
                print(f"⚠️ Σφάλμα ανίχνευσης γλώσσας: {e}")
        detected = sum(languages.values()) or 1
        report = [
            f"📊 Σύνολο δειγμάτων: {len(texts)}",
            f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
            f"🌍 Γλώσσες (δείγμα {len(sample)}):",
            *[f"- {k}: {v} ({100 * v / detected:.1f}%)" for k, v in languages.items()]
        ]
        return "\n".join(messages + report)
    except Exception as e:
        messages.append(f"❌ Σφάλμα: {str(e)}")
        return "\n".join(messages)


def restart_collection():
    """Delete the checkpoint file and reset the collection state."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή."


with gr.Blocks() as demo:
    gr.Markdown("## Custom Tokenizer Trainer για GPT-2")
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
            configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
            split = gr.Dropdown(["train"], value="train", label="Split")
            chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
            vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Μέγεθος Λεξιλογίου")
            min_freq = gr.Slider(1, 10, value=3, label="Ελάχιστη Συχνότητα")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Μέγιστα Δείγματα")
            with gr.Row():
                start_btn = gr.Button("Start", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")
                restart_btn = gr.Button("Restart")
            analyze_btn = gr.Button("Analyze Data")
            train_btn = gr.Button("Train Tokenizer", variant="primary")
        with gr.Column(scale=3):
            progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
            gr.Markdown("### Αποτελέσματα")
            decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο")
            token_distribution = gr.Image(label="Κατανομή Tokens")
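
    # collect_samples runs on Gradio's queue; the Stop button below is wired
    # with queue=False so it can flip STOP_COLLECTION while a collection is
    # still running.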
    def stop_collection():
        """Set the stop flag so the running collection loop exits after the current batch."""
        global STOP_COLLECTION
        STOP_COLLECTION = True
        return "⏹️ Διακοπή συλλογής..."

    start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
    stop_btn.click(stop_collection, None, progress, queue=False)
    restart_btn.click(restart_collection, None, progress)
    analyze_btn.click(analyze_checkpoint, None, progress)
    train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
                    [progress, decoded_text, token_distribution])


demo.queue().launch()