#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
# Basic Python imports | |
import os | |
import sys | |
import json | |
import argparse | |
import logging | |
from datetime import datetime | |
import time | |
import warnings | |
import traceback | |
from importlib.util import find_spec | |
import multiprocessing | |
import torch | |
import random | |
import numpy as np | |
from tqdm import tqdm | |
# Check hardware capabilities first
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"
# Set the multiprocessing start method to 'spawn' for CUDA compatibility | |
if CUDA_AVAILABLE: | |
try: | |
multiprocessing.set_start_method('spawn', force=True) | |
print("Set multiprocessing start method to 'spawn' for CUDA compatibility") | |
except RuntimeError: | |
# Method already set, which is fine | |
print("Multiprocessing start method already set") | |
# Import order is important: unsloth should be imported before transformers | |
# Check for libraries without importing them | |
unsloth_available = find_spec("unsloth") is not None | |
if unsloth_available: | |
import unsloth | |
# torch was already imported above; now import transformers if available
transformers_available = find_spec("transformers") is not None | |
if transformers_available: | |
import transformers | |
from transformers import AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, set_seed | |
from torch.utils.data import DataLoader | |
peft_available = find_spec("peft") is not None | |
if peft_available: | |
import peft | |
# Only import HF datasets if available | |
datasets_available = find_spec("datasets") is not None | |
if datasets_available: | |
from datasets import load_dataset | |
# Set up the logger | |
logger = logging.getLogger(__name__) | |
log_handler = logging.StreamHandler() | |
log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') | |
log_handler.setFormatter(log_format) | |
logger.addHandler(log_handler) | |
logger.setLevel(logging.INFO) | |
# Define a clean logging function for HF Space compatibility | |
def log_info(message): | |
"""Log information in a format compatible with Hugging Face Spaces""" | |
# Just use the logger, but ensure consistent formatting | |
logger.info(message) | |
# Also ensure output is flushed immediately for streaming | |
sys.stdout.flush() | |
# Check for BitsAndBytes: the config class lives in transformers, but 4-bit loading also
# requires the bitsandbytes package itself
try:
    from transformers import BitsAndBytesConfig
    bitsandbytes_available = find_spec("bitsandbytes") is not None
except ImportError:
    bitsandbytes_available = False
if not bitsandbytes_available:
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
# Check for PEFT (refines the find_spec() check above by importing the symbols we need)
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
def load_env_variables(): | |
"""Load environment variables from system, .env file, or Hugging Face Space variables.""" | |
    # Check if we're running in a Hugging Face Space
    if os.environ.get("SPACE_ID"):
        logger.info("Running in Hugging Face Space")
        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        # If the username is not set, try to extract it from SPACE_ID
        if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
            username = os.environ.get("SPACE_ID").split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from a .env file if not in a Space
        try:
            from dotenv import load_dotenv
            # First check the current directory
            env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                # Fall back to the shared directory
                shared_env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
                if os.path.exists(shared_env_path):
                    load_dotenv(shared_env_path)
                    logger.info(f"Loaded environment variables from {shared_env_path}")
                    logger.info(f"HF_TOKEN loaded from shared .env file: {bool(os.environ.get('HF_TOKEN'))}")
                    logger.info(f"HF_USERNAME loaded from shared .env file: {bool(os.environ.get('HF_USERNAME'))}")
                    logger.info(f"HF_SPACE_NAME loaded from shared .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
                else:
                    logger.warning("No .env file found in the current or shared directory")
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")
if not os.environ.get("HF_TOKEN"): | |
logger.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.") | |
if not os.environ.get("HF_USERNAME"): | |
logger.warning("HF_USERNAME is not set. Using default username.") | |
if not os.environ.get("HF_SPACE_NAME"): | |
logger.warning("HF_SPACE_NAME is not set. Using default space name.") | |
# Set HF_TOKEN for huggingface_hub | |
if os.environ.get("HF_TOKEN"): | |
os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN") | |
def load_configs(base_path): | |
"""Load configuration from transformers_config.json file.""" | |
# Using a single consolidated config file | |
config_file = base_path | |
try: | |
with open(config_file, "r") as f: | |
config = json.load(f) | |
logger.info(f"Loaded configuration from {config_file}") | |
return config | |
except Exception as e: | |
logger.error(f"Error loading {config_file}: {e}") | |
raise | |
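# Illustrative sketch of the expected transformers_config.json structure, inferred from the
# keys this script reads via get_config_value(). Key names match the lookups below, but the
# values here are examples only, not the project's canonical configuration:
# {
#   "model": {"name": "unsloth/phi-4-unsloth-bnb-4bit", "use_fast_tokenizer": true, "trust_remote_code": true},
#   "model_revision": "main",
#   "unsloth": {"enabled": true, "r": 16, "alpha": 32, "dropout": 0, "target_modules": ["q_proj", "k_proj"]},
#   "tokenizer": {"max_seq_length": 2048, "add_eos_token": true, "padding_side": "right"},
#   "training": {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 8, "learning_rate": 2e-5, "num_train_epochs": 3},
#   "checkpointing": {"output_dir": "./results", "save_strategy": "steps", "save_steps": 500},
#   "dataset": {"name": "your-dataset-or-local-path", "split": "train"},
#   "huggingface_hub": {"push_to_hub": false}
# }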
def parse_args(): | |
""" | |
Parse command line arguments for the training script. | |
Returns: | |
argparse.Namespace: The parsed command line arguments | |
""" | |
parser = argparse.ArgumentParser(description="Run training for language models") | |
parser.add_argument( | |
"--config_file", | |
type=str, | |
default=None, | |
help="Path to the configuration file (default: transformers_config.json in script directory)" | |
) | |
parser.add_argument( | |
"--seed", | |
type=int, | |
default=None, | |
help="Random seed for reproducibility (default: based on current time)" | |
) | |
parser.add_argument( | |
"--log_level", | |
type=str, | |
choices=["debug", "info", "warning", "error", "critical"], | |
default="info", | |
help="Logging level (default: info)" | |
) | |
return parser.parse_args() | |
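# Example invocation (illustrative; replace <script> with this file's actual name):
#   python <script>.py --config_file transformers_config.json --seed 42 --log_level debug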
def load_model_and_tokenizer(config): | |
""" | |
Load the model and tokenizer according to the configuration. | |
Args: | |
config (dict): Complete configuration dictionary | |
Returns: | |
tuple: (model, tokenizer) - The loaded model and tokenizer | |
""" | |
# Extract model configuration | |
model_config = get_config_value(config, "model", {}) | |
model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit") | |
use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True) | |
trust_remote_code = get_config_value(model_config, "trust_remote_code", True) | |
model_revision = get_config_value(config, "model_revision", "main") | |
# Unsloth configuration | |
unsloth_config = get_config_value(config, "unsloth", {}) | |
unsloth_enabled = get_config_value(unsloth_config, "enabled", True) | |
# Tokenizer configuration | |
tokenizer_config = get_config_value(config, "tokenizer", {}) | |
max_seq_length = min( | |
get_config_value(tokenizer_config, "max_seq_length", 2048), | |
4096 # Maximum supported by most models | |
) | |
add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True) | |
chat_template = get_config_value(tokenizer_config, "chat_template", None) | |
padding_side = get_config_value(tokenizer_config, "padding_side", "right") | |
# Check for flash attention | |
use_flash_attention = get_config_value(config, "use_flash_attention", False) | |
flash_attention_available = False | |
try: | |
import flash_attn | |
flash_attention_available = True | |
log_info(f"Flash Attention detected (version: {flash_attn.__version__})") | |
except ImportError: | |
if use_flash_attention: | |
log_info("Flash Attention requested but not available") | |
log_info(f"Loading model: {model_name} (revision: {model_revision})") | |
log_info(f"Max sequence length: {max_seq_length}") | |
try: | |
if unsloth_enabled and unsloth_available: | |
log_info("Using Unsloth for 4-bit quantized model and LoRA") | |
# Load using Unsloth | |
from unsloth import FastLanguageModel | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name=model_name, | |
max_seq_length=max_seq_length, | |
                dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
revision=model_revision, | |
trust_remote_code=trust_remote_code, | |
use_flash_attention_2=use_flash_attention and flash_attention_available | |
) | |
# Configure tokenizer settings | |
tokenizer.padding_side = padding_side | |
if add_eos_token and tokenizer.eos_token is None: | |
log_info("Setting EOS token") | |
tokenizer.add_special_tokens({"eos_token": "</s>"}) | |
# Set chat template if specified | |
if chat_template: | |
log_info(f"Setting chat template: {chat_template}") | |
if hasattr(tokenizer, "chat_template"): | |
tokenizer.chat_template = chat_template | |
else: | |
log_info("Tokenizer does not support chat templates, using default formatting") | |
# Apply LoRA | |
lora_r = get_config_value(unsloth_config, "r", 16) | |
lora_alpha = get_config_value(unsloth_config, "alpha", 32) | |
lora_dropout = get_config_value(unsloth_config, "dropout", 0) | |
target_modules = get_config_value(unsloth_config, "target_modules", | |
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) | |
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") | |
model = FastLanguageModel.get_peft_model( | |
model, | |
r=lora_r, | |
target_modules=target_modules, | |
lora_alpha=lora_alpha, | |
lora_dropout=lora_dropout, | |
bias="none", | |
use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True), | |
random_state=0, | |
max_seq_length=max_seq_length, | |
modules_to_save=None | |
) | |
if use_flash_attention and flash_attention_available: | |
log_info("🚀 Using Flash Attention for faster training") | |
elif use_flash_attention and not flash_attention_available: | |
log_info("⚠️ Flash Attention requested but not available - using standard attention") | |
else: | |
# Standard HuggingFace loading | |
log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)") | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Check if flash attention should be enabled in config | |
use_attn_implementation = None | |
if use_flash_attention and flash_attention_available: | |
use_attn_implementation = "flash_attention_2" | |
log_info("🚀 Using Flash Attention for faster training") | |
# Load tokenizer first | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
trust_remote_code=trust_remote_code, | |
use_fast=use_fast_tokenizer, | |
revision=model_revision, | |
padding_side=padding_side | |
) | |
# Configure tokenizer settings | |
if add_eos_token and tokenizer.eos_token is None: | |
log_info("Setting EOS token") | |
tokenizer.add_special_tokens({"eos_token": "</s>"}) | |
# Set chat template if specified | |
if chat_template: | |
log_info(f"Setting chat template: {chat_template}") | |
if hasattr(tokenizer, "chat_template"): | |
tokenizer.chat_template = chat_template | |
else: | |
log_info("Tokenizer does not support chat templates, using default formatting") | |
# Now load model with updated tokenizer | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
trust_remote_code=trust_remote_code, | |
revision=model_revision, | |
torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16, | |
device_map="auto" if CUDA_AVAILABLE else None, | |
attn_implementation=use_attn_implementation | |
) | |
# Apply PEFT/LoRA if enabled but using standard loading | |
if peft_available and get_config_value(unsloth_config, "enabled", True): | |
log_info("Applying standard PEFT/LoRA configuration") | |
from peft import LoraConfig, get_peft_model | |
lora_r = get_config_value(unsloth_config, "r", 16) | |
lora_alpha = get_config_value(unsloth_config, "alpha", 32) | |
lora_dropout = get_config_value(unsloth_config, "dropout", 0) | |
target_modules = get_config_value(unsloth_config, "target_modules", | |
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) | |
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") | |
lora_config = LoraConfig( | |
r=lora_r, | |
lora_alpha=lora_alpha, | |
target_modules=target_modules, | |
lora_dropout=lora_dropout, | |
bias="none", | |
task_type="CAUSAL_LM" | |
) | |
model = get_peft_model(model, lora_config) | |
# Print model summary | |
log_info(f"Model loaded successfully: {model.__class__.__name__}") | |
if hasattr(model, "print_trainable_parameters"): | |
model.print_trainable_parameters() | |
else: | |
total_params = sum(p.numel() for p in model.parameters()) | |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})") | |
return model, tokenizer | |
except Exception as e: | |
log_info(f"Error loading model: {str(e)}") | |
traceback.print_exc() | |
return None, None | |
def load_dataset_with_mapping(config): | |
""" | |
Load dataset from Hugging Face or local files and apply necessary transformations. | |
Args: | |
config (dict): Dataset configuration dictionary | |
Returns: | |
Dataset: The loaded and processed dataset | |
""" | |
    # Extract dataset configuration. The caller may pass either the full config or just the
    # "dataset" section, so fall back to the argument itself if there is no "dataset" key.
    dataset_info = get_config_value(config, "dataset", config)
    dataset_name = get_config_value(dataset_info, "name", None)
    dataset_split = get_config_value(dataset_info, "split", "train")
# Data formatting configuration | |
formatting_config = get_config_value(config, "data_formatting", {}) | |
if not dataset_name: | |
raise ValueError("Dataset name not specified in config") | |
log_info(f"Loading dataset: {dataset_name} (split: {dataset_split})") | |
try: | |
# Load dataset from Hugging Face or local path | |
from datasets import load_dataset | |
# Check if it's a local path or Hugging Face dataset | |
if os.path.exists(dataset_name) or os.path.exists(os.path.join(os.getcwd(), dataset_name)): | |
log_info(f"Loading dataset from local path: {dataset_name}") | |
# Local dataset - check if it's a directory or file | |
if os.path.isdir(dataset_name): | |
# Directory - look for data files | |
dataset = load_dataset( | |
"json", | |
data_files={"train": os.path.join(dataset_name, "*.json")}, | |
split=dataset_split | |
) | |
else: | |
# Single file | |
dataset = load_dataset( | |
"json", | |
data_files={"train": dataset_name}, | |
split=dataset_split | |
) | |
else: | |
# Hugging Face dataset | |
log_info(f"Loading dataset from Hugging Face: {dataset_name}") | |
dataset = load_dataset(dataset_name, split=dataset_split) | |
log_info(f"Dataset loaded with {len(dataset)} examples") | |
# Check if dataset contains required fields | |
required_fields = ["conversations"] | |
missing_fields = [field for field in required_fields if field not in dataset.column_names] | |
if missing_fields: | |
log_info(f"WARNING: Dataset missing required fields: {missing_fields}") | |
log_info("Attempting to map dataset structure to required format") | |
# Implement conversion logic based on dataset structure | |
if "messages" in dataset.column_names: | |
log_info("Converting 'messages' field to 'conversations' format") | |
dataset = dataset.map( | |
lambda x: {"conversations": x["messages"]}, | |
remove_columns=["messages"] | |
) | |
elif "text" in dataset.column_names: | |
log_info("Converting plain text to conversations format") | |
dataset = dataset.map( | |
lambda x: {"conversations": [{"role": "user", "content": x["text"]}]}, | |
remove_columns=["text"] | |
) | |
else: | |
raise ValueError(f"Cannot convert dataset format - missing required fields and no conversion path available") | |
# Log dataset info | |
log_info(f"Dataset has {len(dataset)} examples and columns: {dataset.column_names}") | |
# Show a few examples for verification | |
for i in range(min(3, len(dataset))): | |
example = dataset[i] | |
log_info(f"Example {i}:") | |
for key, value in example.items(): | |
if key == "conversations": | |
log_info(f" conversations: {len(value)} messages") | |
# Show first message only to avoid cluttering logs | |
if value and len(value) > 0: | |
first_msg = value[0] | |
if isinstance(first_msg, dict) and "content" in first_msg: | |
content = first_msg["content"] | |
log_info(f" First message: {content[:50]}..." if len(content) > 50 else f" First message: {content}") | |
else: | |
log_info(f" {key}: {value}") | |
return dataset | |
except Exception as e: | |
log_info(f"Error loading dataset: {str(e)}") | |
traceback.print_exc() | |
return None | |
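# Illustrative example of a dataset record after loading/mapping, based on the fields read
# below ("conversations", plus optional "article_id" and "prompt_number" used by the
# collator). The real dataset schema may carry additional columns:
#   {
#       "article_id": "1234",
#       "prompt_number": 1,
#       "conversations": [
#           {"role": "user", "content": "Summarise the introduction of this paper ..."},
#           {"role": "assistant", "content": "The paper introduces ..."}
#       ]
#   }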
def format_phi_chat(messages, dataset_config): | |
"""Format messages according to phi-4's chat template and dataset config. | |
Only formats the conversation structure, preserves the actual content.""" | |
formatted_chat = "" | |
# Get role templates from config | |
roles = dataset_config.get("data_formatting", {}).get("roles", { | |
"system": "System: {content}\n\n", | |
"human": "Human: {content}\n\n", | |
"assistant": "Assistant: {content}\n\n" | |
}) | |
# Handle each message in the conversation | |
for message in messages: | |
if not isinstance(message, dict) or "content" not in message: | |
logger.warning(f"Skipping invalid message format: {message}") | |
continue | |
content = message.get("content", "") # Don't strip() - preserve exact content | |
# Skip empty content | |
if not content: | |
continue | |
# Only add role prefixes based on position/content | |
if "[RESEARCH INTRODUCTION]" in content: | |
# System message | |
template = roles.get("system", "System: {content}\n\n") | |
formatted_chat = template.format(content=content) + formatted_chat | |
        else:
            # Alternate between human and assistant for regular conversation turns.
            # In the phi-4 format, a human message comes first, followed by the assistant response.
            if formatted_chat.count("Human:") == formatted_chat.count("Assistant:"):
                # Equal numbers of Human and Assistant turns so far, so the next one is Human
                template = roles.get("human", "Human: {content}\n\n")
            else:
                # Otherwise the next turn is Assistant
                template = roles.get("assistant", "Assistant: {content}\n\n")
            formatted_chat += template.format(content=content)
return formatted_chat | |
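# Minimal sketch of what format_phi_chat produces with the default role templates
# (the messages here are hypothetical; real content comes from the dataset):
#   messages = [
#       {"role": "user", "content": "What is LoRA?"},
#       {"role": "assistant", "content": "LoRA adds low-rank adapter matrices."}
#   ]
#   format_phi_chat(messages, {})
#   -> "Human: What is LoRA?\n\nAssistant: LoRA adds low-rank adapter matrices.\n\n"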
class SimpleDataCollator: | |
    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        # Make sure a pad token exists so tokenizer.pad() can pad batches
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
        self.stats = {
            "processed": 0,
            "skipped": 0,
            "total_tokens": 0
        }
        logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")
def __call__(self, features): | |
# Initialize tensors on CPU to save GPU memory | |
batch = { | |
"input_ids": [], | |
"attention_mask": [], | |
"labels": [] | |
} | |
for feature in features: | |
paper_id = feature.get("article_id", "unknown") | |
prompt_num = feature.get("prompt_number", 0) | |
conversations = feature.get("conversations", []) | |
if not conversations: | |
logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}") | |
self.stats["skipped"] += 1 | |
continue | |
# Get the content directly | |
content = conversations[0].get("content", "") | |
if not content: | |
logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}") | |
self.stats["skipped"] += 1 | |
continue | |
# Process the content string by tokenizing it | |
if isinstance(content, str): | |
# Tokenize the content string | |
input_ids = self.tokenizer.encode(content, add_special_tokens=True) | |
else: | |
# If somehow the content is already tokenized (not a string), use it directly | |
input_ids = content | |
# Truncate if needed | |
if len(input_ids) > self.max_seq_length: | |
input_ids = input_ids[:self.max_seq_length] | |
logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}") | |
# Create attention mask (1s for all tokens) | |
attention_mask = [1] * len(input_ids) | |
# Add to batch | |
batch["input_ids"].append(input_ids) | |
batch["attention_mask"].append(attention_mask) | |
batch["labels"].append(input_ids.copy()) # For causal LM, labels = input_ids | |
self.stats["processed"] += 1 | |
self.stats["total_tokens"] += len(input_ids) | |
# Log statistics periodically | |
if self.stats["processed"] % 100 == 0: | |
avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"]) | |
logger.info(f"Data collation stats: processed={self.stats['processed']}, " | |
f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}") | |
        # Convert to padded tensors. Note that tokenizer.pad() does not pad the "labels" key,
        # so labels are rebuilt from the padded input_ids afterwards.
        if batch["input_ids"]:
            padded = self.tokenizer.pad(
                {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
                padding="max_length",
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            # For causal LM, labels mirror input_ids; padding positions are set to -100 so the
            # loss ignores them
            padded["labels"] = padded["input_ids"].clone()
            padded["labels"][padded["attention_mask"] == 0] = -100
            return padded
        else:
            # Return an empty batch if no valid examples
            return {k: [] for k in batch}
def log_gpu_memory_usage(step=None, frequency=50, clear_cache_threshold=0.9, label=None): | |
""" | |
Log GPU memory usage statistics with optional cache clearing | |
Args: | |
step: Current training step (if None, logs regardless of frequency) | |
frequency: How often to log when step is provided | |
clear_cache_threshold: Fraction of memory used that triggers cache clearing (0-1) | |
label: Optional label for the log message (e.g., "Initial", "Error", "Step") | |
""" | |
if not CUDA_AVAILABLE: | |
return | |
# Only log every 'frequency' steps if step is provided | |
if step is not None and frequency > 0 and step % frequency != 0: | |
return | |
# Get memory usage for each GPU | |
memory_info = [] | |
for i in range(NUM_GPUS): | |
allocated = torch.cuda.memory_allocated(i) / (1024 ** 2) # MB | |
reserved = torch.cuda.memory_reserved(i) / (1024 ** 2) # MB | |
max_mem = torch.cuda.max_memory_allocated(i) / (1024 ** 2) # MB | |
# Calculate percentage of reserved memory that's allocated | |
usage_percent = (allocated / reserved) * 100 if reserved > 0 else 0 | |
memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB ({usage_percent:.1f}%, max: {max_mem:.1f}MB)") | |
# Automatically clear cache if over threshold | |
if clear_cache_threshold > 0 and reserved > 0 and (allocated / reserved) > clear_cache_threshold: | |
log_info(f"Clearing CUDA cache for GPU {i} - high utilization ({allocated:.1f}/{reserved:.1f}MB)") | |
with torch.cuda.device(i): | |
torch.cuda.empty_cache() | |
prefix = f"{label} " if label else "" | |
log_info(f"{prefix}GPU Memory: {', '.join(memory_info)}") | |
class LoggingCallback(TrainerCallback): | |
""" | |
Custom callback for logging training progress and metrics. | |
Provides detailed information about training status, GPU memory usage, and model performance. | |
""" | |
def __init__(self, model=None, dataset=None): | |
# Ensure we have TrainerCallback | |
try: | |
super().__init__() | |
except Exception as e: | |
# Try to import directly if initial import failed | |
try: | |
from transformers.trainer_callback import TrainerCallback | |
self.__class__.__bases__ = (TrainerCallback,) | |
super().__init__() | |
log_info("Successfully imported TrainerCallback directly") | |
except ImportError as ie: | |
log_info(f"❌ Error: Could not import TrainerCallback: {str(ie)}") | |
log_info("Please ensure transformers is properly installed") | |
raise | |
self.training_started = time.time() | |
self.last_log_time = time.time() | |
self.last_step_time = None | |
self.step_durations = [] | |
self.best_loss = float('inf') | |
self.model = model | |
self.dataset = dataset | |
def on_train_begin(self, args, state, control, **kwargs): | |
"""Called at the beginning of training""" | |
try: | |
log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===") | |
# Log model info if available | |
if self.model is not None: | |
total_params = sum(p.numel() for p in self.model.parameters()) | |
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) | |
log_info(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable") | |
# Log dataset info if available | |
if self.dataset is not None: | |
log_info(f"Dataset size: {len(self.dataset)} examples") | |
            # Log important training parameters for visibility (guard against NUM_GPUS == 0 on CPU)
            effective_gpus = max(1, NUM_GPUS)
            total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * effective_gpus
            total_steps = int(len(self.dataset or []) / (args.per_device_train_batch_size * effective_gpus * args.gradient_accumulation_steps) * args.num_train_epochs)
            log_info(f"Training plan: {len(self.dataset or [])} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
            log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {effective_gpus} GPUs = {total_batch_size} total")
# Log initial GPU memory usage with label | |
log_gpu_memory_usage(label="Initial") | |
except Exception as e: | |
logger.warning(f"Error logging training begin statistics: {str(e)}") | |
def on_step_end(self, args, state, control, **kwargs): | |
"""Called at the end of each step""" | |
try: | |
if state.global_step == 1 or state.global_step % args.logging_steps == 0: | |
# Track step timing | |
current_time = time.time() | |
if self.last_step_time: | |
step_duration = current_time - self.last_step_time | |
self.step_durations.append(step_duration) | |
# Keep only last 100 steps for averaging | |
if len(self.step_durations) > 100: | |
self.step_durations.pop(0) | |
avg_step_time = sum(self.step_durations) / len(self.step_durations) | |
log_info(f"Step {state.global_step}: {step_duration:.2f}s (avg: {avg_step_time:.2f}s)") | |
self.last_step_time = current_time | |
# Log GPU memory usage with step number | |
log_gpu_memory_usage(state.global_step, args.logging_steps) | |
# Log loss | |
if state.log_history: | |
latest_logs = state.log_history[-1] if state.log_history else {} | |
if "loss" in latest_logs: | |
loss = latest_logs["loss"] | |
log_info(f"Step {state.global_step} loss: {loss:.4f}") | |
# Track best loss | |
if loss < self.best_loss: | |
self.best_loss = loss | |
log_info(f"New best loss: {loss:.4f}") | |
except Exception as e: | |
logger.warning(f"Error logging step end statistics: {str(e)}") | |
def on_train_end(self, args, state, control, **kwargs): | |
"""Called at the end of training""" | |
try: | |
# Calculate training duration | |
training_time = time.time() - self.training_started | |
hours, remainder = divmod(training_time, 3600) | |
minutes, seconds = divmod(remainder, 60) | |
log_info(f"=== Training completed at {time.strftime('%Y-%m-%d %H:%M:%S')} ===") | |
log_info(f"Training duration: {int(hours)}h {int(minutes)}m {int(seconds)}s") | |
log_info(f"Final step: {state.global_step}") | |
log_info(f"Best loss: {self.best_loss:.4f}") | |
# Log final GPU memory usage | |
log_gpu_memory_usage(label="Final") | |
except Exception as e: | |
logger.warning(f"Error logging training end statistics: {str(e)}") | |
# Other callback methods with proper error handling | |
def on_save(self, args, state, control, **kwargs): | |
"""Called when a checkpoint is saved""" | |
try: | |
log_info(f"Saving checkpoint at step {state.global_step}") | |
except Exception as e: | |
logger.warning(f"Error in on_save: {str(e)}") | |
def on_log(self, args, state, control, **kwargs): | |
"""Called when a log is created""" | |
pass | |
def on_evaluate(self, args, state, control, **kwargs): | |
"""Called when evaluation is performed""" | |
pass | |
# Only implement the methods we actually need, remove the others | |
def on_prediction_step(self, args, state, control, **kwargs): | |
"""Called when prediction is performed""" | |
pass | |
def on_save_model(self, args, state, control, **kwargs): | |
"""Called when model is saved""" | |
try: | |
# Log memory usage after saving | |
log_gpu_memory_usage(label=f"Save at step {state.global_step}") | |
except Exception as e: | |
logger.warning(f"Error in on_save_model: {str(e)}") | |
def on_epoch_end(self, args, state, control, **kwargs): | |
"""Called at the end of an epoch""" | |
try: | |
epoch = state.epoch | |
log_info(f"Completed epoch {epoch:.2f}") | |
log_gpu_memory_usage(label=f"Epoch {epoch:.2f}") | |
except Exception as e: | |
logger.warning(f"Error in on_epoch_end: {str(e)}") | |
def on_step_begin(self, args, state, control, **kwargs): | |
"""Called at the beginning of a step""" | |
pass | |
def install_flash_attention(): | |
""" | |
Attempt to install Flash Attention for improved performance. | |
Returns True if installation was successful, False otherwise. | |
""" | |
log_info("Attempting to install Flash Attention...") | |
# Check for CUDA before attempting installation | |
if not CUDA_AVAILABLE: | |
log_info("❌ Cannot install Flash Attention: CUDA not available") | |
return False | |
try: | |
# Check CUDA version to determine correct installation command | |
cuda_version = torch.version.cuda | |
if cuda_version is None: | |
log_info("❌ Cannot determine CUDA version for Flash Attention installation") | |
return False | |
import subprocess | |
# Use --no-build-isolation for better compatibility | |
install_cmd = [ | |
sys.executable, | |
"-m", | |
"pip", | |
"install", | |
"flash-attn", | |
"--no-build-isolation" | |
] | |
log_info(f"Running: {' '.join(install_cmd)}") | |
result = subprocess.run( | |
install_cmd, | |
capture_output=True, | |
text=True, | |
check=False | |
) | |
if result.returncode == 0: | |
log_info("✅ Flash Attention installed successfully!") | |
# Attempt to import to verify installation | |
try: | |
import flash_attn | |
log_info(f"✅ Flash Attention version {flash_attn.__version__} is now available") | |
return True | |
except ImportError: | |
log_info("⚠️ Flash Attention installed but import failed") | |
return False | |
else: | |
log_info(f"❌ Flash Attention installation failed with error: {result.stderr}") | |
return False | |
except Exception as e: | |
log_info(f"❌ Error installing Flash Attention: {str(e)}") | |
return False | |
def check_dependencies(): | |
""" | |
Check for required and optional dependencies, ensuring proper versions and import order. | |
Returns True if all required dependencies are present, False otherwise. | |
""" | |
# Define required packages with versions and descriptions | |
required_packages = { | |
"unsloth": {"version": ">=2024.3", "feature": "fast 4-bit quantization and LoRA"}, | |
"transformers": {"version": ">=4.38.0", "feature": "core model functionality"}, | |
"peft": {"version": ">=0.9.0", "feature": "parameter-efficient fine-tuning"}, | |
"accelerate": {"version": ">=0.27.0", "feature": "multi-GPU training"} | |
} | |
# Optional packages that enhance functionality | |
optional_packages = { | |
"flash_attn": {"feature": "faster attention computation"}, | |
"bitsandbytes": {"feature": "quantization support"}, | |
"optimum": {"feature": "model optimization"}, | |
"wandb": {"feature": "experiment tracking"} | |
} | |
# Store results | |
missing_packages = [] | |
package_versions = {} | |
order_issues = [] | |
missing_optional = [] | |
# Check required packages | |
log_info("Checking required dependencies...") | |
for package, info in required_packages.items(): | |
version_req = info["version"] | |
feature = info["feature"] | |
try: | |
# Special handling for packages we've already checked | |
if package == "unsloth" and not unsloth_available: | |
missing_packages.append(f"{package}{version_req}") | |
log_info(f"❌ {package} - {feature} MISSING") | |
continue | |
elif package == "peft" and not peft_available: | |
missing_packages.append(f"{package}{version_req}") | |
log_info(f"❌ {package} - {feature} MISSING") | |
continue | |
# Try to import and get version | |
module = __import__(package) | |
version = getattr(module, "__version__", "unknown") | |
package_versions[package] = version | |
log_info(f"✅ {package} v{version} - {feature}") | |
except ImportError: | |
missing_packages.append(f"{package}{version_req}") | |
log_info(f"❌ {package} - {feature} MISSING") | |
# Check optional packages | |
log_info("\nChecking optional dependencies...") | |
for package, info in optional_packages.items(): | |
feature = info["feature"] | |
try: | |
__import__(package) | |
log_info(f"✅ {package} - {feature} available") | |
except ImportError: | |
log_info(f"⚠️ {package} - {feature} not available") | |
missing_optional.append(package) | |
# Check import order for optimal performance | |
if "transformers" in package_versions and "unsloth" in package_versions: | |
        try:
            modules = list(sys.modules.keys())
transformers_idx = modules.index("transformers") | |
unsloth_idx = modules.index("unsloth") | |
if transformers_idx < unsloth_idx: | |
order_issue = "⚠️ For optimal performance, import unsloth before transformers" | |
order_issues.append(order_issue) | |
log_info(order_issue) | |
log_info("This might cause performance issues but won't prevent training") | |
else: | |
log_info("✅ Import order: unsloth before transformers (optimal)") | |
except (ValueError, IndexError) as e: | |
log_info(f"⚠️ Could not verify import order: {str(e)}") | |
# Try to install missing optional packages | |
if "flash_attn" in missing_optional and CUDA_AVAILABLE: | |
log_info("\nFlash Attention is missing but would improve performance.") | |
install_result = install_flash_attention() | |
if install_result: | |
missing_optional.remove("flash_attn") | |
# Report missing required packages | |
if missing_packages: | |
log_info("\n❌ Critical dependencies missing:") | |
for pkg in missing_packages: | |
log_info(f" - {pkg}") | |
log_info("Please install missing dependencies with:") | |
log_info(f" pip install {' '.join(missing_packages)}") | |
return False | |
log_info("\n✅ All required dependencies satisfied!") | |
return True | |
def get_config_value(config, path, default=None): | |
""" | |
Safely get a nested value from a config dictionary using a dot-separated path. | |
Args: | |
config: The configuration dictionary | |
path: Dot-separated path to the value (e.g., "training.optimizer.lr") | |
default: Default value to return if path doesn't exist | |
Returns: | |
The value at the specified path or the default value | |
""" | |
if not config: | |
return default | |
parts = path.split('.') | |
current = config | |
for part in parts: | |
if isinstance(current, dict) and part in current: | |
current = current[part] | |
else: | |
return default | |
return current | |
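# Usage sketch for get_config_value (values are illustrative):
#   get_config_value({"training": {"optimizer": {"lr": 2e-5}}}, "training.optimizer.lr")  -> 2e-05
#   get_config_value({"training": {}}, "training.optimizer.lr", default=1e-4)             -> 0.0001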
def update_huggingface_space(): | |
"""Update the Hugging Face Space with the current code.""" | |
log_info("Updating Hugging Face Space...") | |
update_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "update_space.py") | |
if not os.path.exists(update_script): | |
logger.warning(f"Update space script not found at {update_script}") | |
return False | |
try: | |
import subprocess | |
# Explicitly set space_name to ensure we're targeting the right Space | |
result = subprocess.run( | |
[sys.executable, update_script, "--force", "--space_name", "phi4training"], | |
capture_output=True, text=True, check=False | |
) | |
if result.returncode == 0: | |
log_info("Hugging Face Space updated successfully!") | |
log_info(f"Space URL: https://huggingface.co/spaces/George-API/phi4training") | |
return True | |
else: | |
logger.error(f"Failed to update Hugging Face Space: {result.stderr}") | |
return False | |
except Exception as e: | |
logger.error(f"Error updating Hugging Face Space: {str(e)}") | |
return False | |
def validate_huggingface_credentials(): | |
"""Validate Hugging Face credentials to ensure they work correctly.""" | |
if not os.environ.get("HF_TOKEN"): | |
logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.") | |
return False | |
try: | |
# Import here to avoid requiring huggingface_hub if not needed | |
from huggingface_hub import HfApi, login | |
# Try to login with the token | |
login(token=os.environ.get("HF_TOKEN")) | |
# Check if we can access the API | |
api = HfApi() | |
username = os.environ.get("HF_USERNAME", "George-API") | |
space_name = os.environ.get("HF_SPACE_NAME", "phi4training") | |
# Try to get whoami info | |
user_info = api.whoami() | |
logger.info(f"Successfully authenticated with Hugging Face as {user_info['name']}") | |
# Check if we're using the expected Space | |
expected_space_id = "George-API/phi4training" | |
actual_space_id = f"{username}/{space_name}" | |
if actual_space_id != expected_space_id: | |
logger.warning(f"Using Space '{actual_space_id}' instead of the expected '{expected_space_id}'") | |
logger.warning(f"Make sure this is intentional. To use the correct Space, update your .env file.") | |
else: | |
logger.info(f"Confirmed using Space: {expected_space_id}") | |
# Check if the space exists | |
try: | |
space_id = f"{username}/{space_name}" | |
space_info = api.space_info(repo_id=space_id) | |
logger.info(f"Space {space_id} is accessible at: https://huggingface.co/spaces/{space_id}") | |
return True | |
except Exception as e: | |
logger.warning(f"Could not access Space {username}/{space_name}: {str(e)}") | |
logger.warning("Space updating may not work correctly") | |
return False | |
except ImportError: | |
logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.") | |
return False | |
except Exception as e: | |
logger.warning(f"Error validating Hugging Face credentials: {str(e)}") | |
return False | |
def setup_environment(args): | |
""" | |
Set up the training environment including logging, seed, and configurations. | |
Args: | |
args: Command line arguments | |
Returns: | |
tuple: (transformers_config, seed) - The loaded configuration and random seed | |
""" | |
    # Load environment variables first
    load_env_variables()
    # Apply the requested log level from the command line
    logger.setLevel(getattr(logging, args.log_level.upper(), logging.INFO))
    # Set random seed for reproducibility
    seed = args.seed if args.seed is not None else int(time.time()) % 10000
    set_seed(seed)
    log_info(f"Using random seed: {seed}")
# Load configuration | |
base_path = os.path.dirname(os.path.abspath(__file__)) | |
config_file = args.config_file or os.path.join(base_path, "transformers_config.json") | |
if not os.path.exists(config_file): | |
raise FileNotFoundError(f"Config file not found: {config_file}") | |
log_info(f"Loading configuration from {config_file}") | |
transformers_config = load_configs(config_file) | |
    # Set up hardware environment variables if CUDA is available
    if CUDA_AVAILABLE:
        memory_fraction = get_config_value(transformers_config, "hardware.system_settings.cuda_memory_fraction", 0.75)
        if memory_fraction < 1.0:
            # Cap per-process GPU memory to the configured fraction and make the CUDA
            # allocator friendlier to fragmentation
            for gpu_id in range(NUM_GPUS):
                torch.cuda.set_per_process_memory_fraction(memory_fraction, gpu_id)
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
            log_info(f"Set CUDA memory fraction to {memory_fraction} with max_split_size_mb:128 and expandable segments")
# Check dependencies and install optional ones if needed | |
if not check_dependencies(): | |
raise RuntimeError("Critical dependencies missing") | |
# Check if flash attention was successfully installed | |
flash_attention_available = False | |
try: | |
import flash_attn | |
flash_attention_available = True | |
log_info(f"Flash Attention will be used (version: {flash_attn.__version__})") | |
# Update config to use flash attention | |
if "use_flash_attention" not in transformers_config: | |
transformers_config["use_flash_attention"] = True | |
except ImportError: | |
log_info("Flash Attention not available, will use standard attention mechanism") | |
return transformers_config, seed | |
def setup_model_and_tokenizer(config): | |
""" | |
Load and configure the model and tokenizer. | |
Args: | |
config (dict): Complete configuration dictionary | |
Returns: | |
tuple: (model, tokenizer) - The loaded model and tokenizer | |
""" | |
# Extract model configuration | |
model_config = get_config_value(config, "model", {}) | |
model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit") | |
use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True) | |
trust_remote_code = get_config_value(model_config, "trust_remote_code", True) | |
model_revision = get_config_value(config, "model_revision", "main") | |
# Detect if model is already pre-quantized (includes '4bit', 'bnb', or 'int4' in name) | |
is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant']) | |
if is_prequantized: | |
log_info("⚠️ Detected pre-quantized model. No additional quantization will be applied.") | |
# Unsloth configuration | |
unsloth_config = get_config_value(config, "unsloth", {}) | |
unsloth_enabled = get_config_value(unsloth_config, "enabled", True) | |
# Tokenizer configuration | |
tokenizer_config = get_config_value(config, "tokenizer", {}) | |
max_seq_length = min( | |
get_config_value(tokenizer_config, "max_seq_length", 2048), | |
4096 # Maximum supported by most models | |
) | |
add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True) | |
chat_template = get_config_value(tokenizer_config, "chat_template", None) | |
padding_side = get_config_value(tokenizer_config, "padding_side", "right") | |
# Check for flash attention | |
use_flash_attention = get_config_value(config, "use_flash_attention", False) | |
flash_attention_available = False | |
try: | |
import flash_attn | |
flash_attention_available = True | |
log_info(f"Flash Attention detected (version: {flash_attn.__version__})") | |
except ImportError: | |
if use_flash_attention: | |
log_info("Flash Attention requested but not available") | |
log_info(f"Loading model: {model_name} (revision: {model_revision})") | |
log_info(f"Max sequence length: {max_seq_length}") | |
try: | |
if unsloth_enabled and unsloth_available: | |
log_info("Using Unsloth for LoRA fine-tuning") | |
if is_prequantized: | |
log_info("Using pre-quantized model - no additional quantization will be applied") | |
else: | |
log_info("Using 4-bit quantization for efficient training") | |
# Load using Unsloth | |
from unsloth import FastLanguageModel | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name=model_name, | |
max_seq_length=max_seq_length, | |
                dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
revision=model_revision, | |
trust_remote_code=trust_remote_code, | |
use_flash_attention_2=use_flash_attention and flash_attention_available | |
) | |
# Configure tokenizer settings | |
tokenizer.padding_side = padding_side | |
if add_eos_token and tokenizer.eos_token is None: | |
log_info("Setting EOS token") | |
tokenizer.add_special_tokens({"eos_token": "</s>"}) | |
# Set chat template if specified | |
if chat_template: | |
log_info(f"Setting chat template: {chat_template}") | |
if hasattr(tokenizer, "chat_template"): | |
tokenizer.chat_template = chat_template | |
else: | |
log_info("Tokenizer does not support chat templates, using default formatting") | |
# Apply LoRA | |
lora_r = get_config_value(unsloth_config, "r", 16) | |
lora_alpha = get_config_value(unsloth_config, "alpha", 32) | |
lora_dropout = get_config_value(unsloth_config, "dropout", 0) | |
target_modules = get_config_value(unsloth_config, "target_modules", | |
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) | |
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") | |
model = FastLanguageModel.get_peft_model( | |
model, | |
r=lora_r, | |
target_modules=target_modules, | |
lora_alpha=lora_alpha, | |
lora_dropout=lora_dropout, | |
bias="none", | |
use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True), | |
random_state=0, | |
max_seq_length=max_seq_length, | |
modules_to_save=None | |
) | |
if use_flash_attention and flash_attention_available: | |
log_info("🚀 Using Flash Attention for faster training") | |
elif use_flash_attention and not flash_attention_available: | |
log_info("⚠️ Flash Attention requested but not available - using standard attention") | |
else: | |
# Standard HuggingFace loading | |
log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)") | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Check if flash attention should be enabled in config | |
use_attn_implementation = None | |
if use_flash_attention and flash_attention_available: | |
use_attn_implementation = "flash_attention_2" | |
log_info("🚀 Using Flash Attention for faster training") | |
# Load tokenizer first | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
trust_remote_code=trust_remote_code, | |
use_fast=use_fast_tokenizer, | |
revision=model_revision, | |
padding_side=padding_side | |
) | |
# Configure tokenizer settings | |
if add_eos_token and tokenizer.eos_token is None: | |
log_info("Setting EOS token") | |
tokenizer.add_special_tokens({"eos_token": "</s>"}) | |
# Set chat template if specified | |
if chat_template: | |
log_info(f"Setting chat template: {chat_template}") | |
if hasattr(tokenizer, "chat_template"): | |
tokenizer.chat_template = chat_template | |
else: | |
log_info("Tokenizer does not support chat templates, using default formatting") | |
# Only apply quantization config if model is not already pre-quantized | |
quantization_config = None | |
if not is_prequantized and CUDA_AVAILABLE: | |
try: | |
from transformers import BitsAndBytesConfig | |
log_info("Using 4-bit quantization (BitsAndBytes) for efficient training") | |
quantization_config = BitsAndBytesConfig( | |
load_in_4bit=True, | |
bnb_4bit_quant_type="nf4", | |
bnb_4bit_compute_dtype=torch.float16, | |
bnb_4bit_use_double_quant=True | |
) | |
except ImportError: | |
log_info("BitsAndBytes not available - quantization disabled") | |
# Now load model with updated tokenizer | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
trust_remote_code=trust_remote_code, | |
revision=model_revision, | |
torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16, | |
device_map="auto" if CUDA_AVAILABLE else None, | |
attn_implementation=use_attn_implementation, | |
quantization_config=quantization_config | |
) | |
# Apply PEFT/LoRA if enabled but using standard loading | |
if peft_available and get_config_value(unsloth_config, "enabled", True): | |
log_info("Applying standard PEFT/LoRA configuration") | |
from peft import LoraConfig, get_peft_model | |
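                # Sketch of a commonly recommended extra step (an assumption, not part of the
                # original flow): when the model was just loaded with a BitsAndBytes 4-bit
                # config, PEFT's prepare_model_for_kbit_training() readies it for LoRA training.
                if quantization_config is not None:
                    model = prepare_model_for_kbit_training(model)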
lora_r = get_config_value(unsloth_config, "r", 16) | |
lora_alpha = get_config_value(unsloth_config, "alpha", 32) | |
lora_dropout = get_config_value(unsloth_config, "dropout", 0) | |
target_modules = get_config_value(unsloth_config, "target_modules", | |
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) | |
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") | |
lora_config = LoraConfig( | |
r=lora_r, | |
lora_alpha=lora_alpha, | |
target_modules=target_modules, | |
lora_dropout=lora_dropout, | |
bias="none", | |
task_type="CAUSAL_LM" | |
) | |
model = get_peft_model(model, lora_config) | |
# Print model summary | |
log_info(f"Model loaded successfully: {model.__class__.__name__}") | |
if hasattr(model, "print_trainable_parameters"): | |
model.print_trainable_parameters() | |
else: | |
total_params = sum(p.numel() for p in model.parameters()) | |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})") | |
return model, tokenizer | |
except Exception as e: | |
log_info(f"Error loading model: {str(e)}") | |
traceback.print_exc() | |
return None, None | |
def setup_dataset_and_collator(config, tokenizer): | |
""" | |
Load and configure the dataset and data collator. | |
Args: | |
config: Complete configuration dictionary | |
tokenizer: The tokenizer for the data collator | |
Returns: | |
tuple: (dataset, data_collator) - The loaded dataset and configured data collator | |
""" | |
dataset_config = get_config_value(config, "dataset", {}) | |
log_info("Loading dataset...") | |
dataset = load_dataset_with_mapping(dataset_config) | |
# Validate dataset | |
if dataset is None: | |
raise ValueError("Dataset is None! Cannot proceed with training.") | |
if not hasattr(dataset, '__len__') or len(dataset) == 0: | |
raise ValueError("Dataset is empty! Cannot proceed with training.") | |
log_info(f"Dataset loaded with {len(dataset)} examples") | |
# Create data collator | |
data_collator = SimpleDataCollator(tokenizer, dataset_config) | |
return dataset, data_collator | |
def create_training_arguments(config, dataset): | |
""" | |
Create and configure training arguments for the Trainer. | |
Args: | |
config: Complete configuration dictionary | |
dataset: The dataset to determine total steps | |
Returns: | |
TrainingArguments: Configured training arguments | |
""" | |
# Extract configuration sections | |
training_config = get_config_value(config, "training", {}) | |
hardware_config = get_config_value(config, "hardware", {}) | |
huggingface_config = get_config_value(config, "huggingface_hub", {}) | |
distributed_config = get_config_value(config, "distributed_training", {}) | |
# Extract key training parameters | |
per_device_batch_size = get_config_value(training_config, "per_device_train_batch_size", 4) | |
gradient_accumulation_steps = get_config_value(training_config, "gradient_accumulation_steps", 8) | |
learning_rate = get_config_value(training_config, "learning_rate", 2e-5) | |
num_train_epochs = get_config_value(training_config, "num_train_epochs", 3) | |
# Extract hardware settings | |
dataloader_workers = get_config_value(hardware_config, "system_settings.dataloader_num_workers", | |
get_config_value(distributed_config, "dataloader_num_workers", 2)) | |
pin_memory = get_config_value(hardware_config, "system_settings.dataloader_pin_memory", True) | |
# BF16/FP16 settings - ensure only one is enabled | |
use_bf16 = get_config_value(training_config, "bf16", False) | |
use_fp16 = get_config_value(training_config, "fp16", False) if not use_bf16 else False | |
# Configure distributed training | |
fsdp_config = get_config_value(distributed_config, "fsdp_config", {}) | |
fsdp_enabled = get_config_value(fsdp_config, "enabled", False) | |
ddp_config = get_config_value(distributed_config, "ddp_config", {}) | |
ddp_find_unused_parameters = get_config_value(ddp_config, "find_unused_parameters", False) | |
    # Set up FSDP args if enabled. TrainingArguments expects string options for `fsdp`
    # (e.g. "full_shard auto_wrap") and a plain dict for `fsdp_config`, rather than an
    # accelerate FullyShardedDataParallelPlugin instance.
    fsdp_args = None
    if fsdp_enabled and NUM_GPUS > 1:
        sharding_strategy = get_config_value(fsdp_config, "sharding_strategy", "full_shard")
        fsdp_args = {
            "fsdp": f"{sharding_strategy.lower()} auto_wrap",
            "fsdp_config": {
                "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer", "PhiDecoderLayer"],
            },
        }
# Create and return training arguments | |
training_args = TrainingArguments( | |
output_dir=get_config_value(config, "checkpointing.output_dir", "./results"), | |
overwrite_output_dir=True, | |
num_train_epochs=num_train_epochs, | |
per_device_train_batch_size=per_device_batch_size, | |
gradient_accumulation_steps=gradient_accumulation_steps, | |
learning_rate=learning_rate, | |
weight_decay=get_config_value(training_config, "weight_decay", 0.01), | |
max_grad_norm=get_config_value(training_config, "max_grad_norm", 1.0), | |
warmup_ratio=get_config_value(training_config, "warmup_ratio", 0.03), | |
lr_scheduler_type=get_config_value(training_config, "lr_scheduler_type", "cosine"), | |
logging_steps=get_config_value(training_config, "logging_steps", 10), | |
save_strategy=get_config_value(config, "checkpointing.save_strategy", "steps"), | |
save_steps=get_config_value(config, "checkpointing.save_steps", 500), | |
save_total_limit=get_config_value(config, "checkpointing.save_total_limit", 3), | |
bf16=use_bf16, | |
fp16=use_fp16, | |
push_to_hub=get_config_value(huggingface_config, "push_to_hub", False), | |
hub_model_id=get_config_value(huggingface_config, "hub_model_id", None), | |
hub_strategy=get_config_value(huggingface_config, "hub_strategy", "every_save"), | |
hub_private_repo=get_config_value(huggingface_config, "hub_private_repo", True), | |
gradient_checkpointing=get_config_value(training_config, "gradient_checkpointing", True), | |
dataloader_pin_memory=pin_memory, | |
optim=get_config_value(training_config, "optim", "adamw_torch"), | |
ddp_find_unused_parameters=ddp_find_unused_parameters, | |
dataloader_drop_last=False, | |
dataloader_num_workers=dataloader_workers, | |
        no_cuda=not CUDA_AVAILABLE,
**({} if fsdp_args is None else fsdp_args) | |
) | |
log_info("Training arguments created successfully") | |
return training_args | |
def configure_custom_dataloader(trainer, dataset, config, training_args): | |
""" | |
Configure a custom dataloader for the trainer if needed. | |
Args: | |
trainer: The Trainer instance to configure | |
dataset: The dataset to use | |
config: Complete configuration dictionary | |
training_args: The training arguments | |
Returns: | |
None (modifies trainer in-place) | |
""" | |
dataset_config = get_config_value(config, "dataset", {}) | |
# Check if we need a custom dataloader | |
if get_config_value(dataset_config, "data_loading.sequential_processing", True): | |
log_info("Using custom sequential dataloader") | |
# Create sequential sampler to maintain dataset order | |
sequential_sampler = torch.utils.data.SequentialSampler(dataset) | |
log_info("Sequential sampler created") | |
# Define custom dataloader getter | |
def custom_get_train_dataloader(): | |
"""Create a custom dataloader that maintains dataset order""" | |
# Get configuration values | |
batch_size = training_args.per_device_train_batch_size | |
drop_last = get_config_value(dataset_config, "data_loading.drop_last", False) | |
num_workers = training_args.dataloader_num_workers | |
pin_memory = training_args.dataloader_pin_memory | |
prefetch_factor = get_config_value(dataset_config, "data_loading.prefetch_factor", 2) | |
persistent_workers = get_config_value(dataset_config, "data_loading.persistent_workers", False) | |
# Create DataLoader with sequential sampler | |
return DataLoader( | |
dataset, | |
batch_size=batch_size, | |
sampler=sequential_sampler, | |
collate_fn=trainer.data_collator, | |
drop_last=drop_last, | |
num_workers=num_workers, | |
pin_memory=pin_memory, | |
prefetch_factor=prefetch_factor if num_workers > 0 else None, | |
persistent_workers=persistent_workers if num_workers > 0 else False, | |
) | |
# Override the default dataloader | |
trainer.get_train_dataloader = custom_get_train_dataloader | |
def run_training(trainer, tokenizer, training_args): | |
""" | |
Run the training process and handle model saving. | |
Args: | |
trainer: Configured Trainer instance | |
tokenizer: The tokenizer to save with the model | |
training_args: Training arguments | |
Returns: | |
int: 0 for success, 1 for failure | |
""" | |
log_info("Starting training...") | |
trainer.train() | |
log_info("Training complete! Saving final model...") | |
trainer.save_model() | |
tokenizer.save_pretrained(training_args.output_dir) | |
# Push to Hub if configured | |
if training_args.push_to_hub: | |
log_info(f"Pushing model to Hugging Face Hub: {training_args.hub_model_id}") | |
trainer.push_to_hub() | |
log_info("Training completed successfully!") | |
return 0 | |
def main(): | |
""" | |
Main entry point for the training script. | |
Returns: | |
int: 0 for success, non-zero for failure | |
""" | |
# Set up logging | |
logger.info("Starting training process") | |
try: | |
# Verify critical imports are available | |
if not transformers_available: | |
log_info("❌ Error: transformers library not available. Please install it with: pip install transformers") | |
return 1 | |
# Check for required classes | |
for required_class in ["Trainer", "TrainingArguments", "TrainerCallback"]: | |
if not hasattr(transformers, required_class): | |
log_info(f"❌ Error: {required_class} not found in transformers. Please update transformers.") | |
return 1 | |
# Check for potential import order issue and warn early | |
if "transformers" in sys.modules and "unsloth" in sys.modules: | |
if list(sys.modules.keys()).index("transformers") < list(sys.modules.keys()).index("unsloth"): | |
log_info("⚠️ Warning: transformers was imported before unsloth. This may affect performance.") | |
log_info(" For optimal performance in future runs, import unsloth first.") | |
# Parse command line arguments | |
args = parse_args() | |
# Set up environment and load configuration | |
transformers_config, seed = setup_environment(args) | |
# Load model and tokenizer | |
try: | |
model, tokenizer = setup_model_and_tokenizer(transformers_config) | |
except Exception as e: | |
logger.error(f"Error setting up model: {str(e)}") | |
return 1 | |
# Load dataset and create data collator | |
try: | |
dataset, data_collator = setup_dataset_and_collator(transformers_config, tokenizer) | |
except Exception as e: | |
logger.error(f"Error setting up dataset: {str(e)}") | |
return 1 | |
# Configure training arguments | |
try: | |
training_args = create_training_arguments(transformers_config, dataset) | |
except Exception as e: | |
logger.error(f"Error configuring training arguments: {str(e)}") | |
return 1 | |
# Initialize trainer with callbacks | |
log_info("Initializing Trainer") | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=dataset, | |
data_collator=data_collator, | |
callbacks=[LoggingCallback(model=model, dataset=dataset)], | |
) | |
# Configure custom dataloader if needed | |
try: | |
configure_custom_dataloader(trainer, dataset, transformers_config, training_args) | |
except Exception as e: | |
logger.error(f"Error configuring custom dataloader: {str(e)}") | |
return 1 | |
# Run training process | |
try: | |
return run_training(trainer, tokenizer, training_args) | |
except Exception as e: | |
logger.error(f"Training failed with error: {str(e)}") | |
# Log GPU memory for debugging | |
log_gpu_memory_usage(label="Error") | |
# Print full stack trace | |
traceback.print_exc() | |
return 1 | |
except Exception as e: | |
logger.error(f"Error in main function: {str(e)}") | |
traceback.print_exc() | |
return 1 | |
if __name__ == "__main__": | |
sys.exit(main()) | |