#!/usr/bin/env python3
# -*- coding: utf-8 -*-
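"""
Fine-tune a phi-4 style model with Unsloth + LoRA (falling back to standard
Hugging Face loading with PEFT) inside a Hugging Face Space, driven by a single
transformers_config.json configuration file.
"""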
# Basic Python imports
import os
import sys
import json
import argparse
import logging
from datetime import datetime
import time
import warnings
import traceback
from importlib.util import find_spec
import multiprocessing
import torch
import random
import numpy as np
from tqdm import tqdm
# Check hardware capabilities first
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"
# Set the multiprocessing start method to 'spawn' for CUDA compatibility
if CUDA_AVAILABLE:
try:
multiprocessing.set_start_method('spawn', force=True)
print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
except RuntimeError:
# Method already set, which is fine
print("Multiprocessing start method already set")
# Import order is important: unsloth should be imported before transformers
# Check for libraries without importing them
unsloth_available = find_spec("unsloth") is not None
if unsloth_available:
import unsloth
# torch is already imported above; import transformers next if available
transformers_available = find_spec("transformers") is not None
if transformers_available:
import transformers
from transformers import AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, set_seed
from torch.utils.data import DataLoader
peft_available = find_spec("peft") is not None
if peft_available:
import peft
# Only import HF datasets if available
datasets_available = find_spec("datasets") is not None
if datasets_available:
from datasets import load_dataset
# Set up the logger
logger = logging.getLogger(__name__)
log_handler = logging.StreamHandler()
log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
logger.setLevel(logging.INFO)
# Define a clean logging function for HF Space compatibility
def log_info(message):
"""Log information in a format compatible with Hugging Face Spaces"""
# Just use the logger, but ensure consistent formatting
logger.info(message)
# Also ensure output is flushed immediately for streaming
sys.stdout.flush()
# Check for BitsAndBytes
try:
from transformers import BitsAndBytesConfig
bitsandbytes_available = True
except ImportError:
bitsandbytes_available = False
logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
# Check for PEFT
try:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
peft_available = True
except ImportError:
peft_available = False
logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
def load_env_variables():
"""Load environment variables from system, .env file, or Hugging Face Space variables."""
# Check if we're running in a Hugging Face Space
if os.environ.get("SPACE_ID"):
logging.info("Running in Hugging Face Space")
# Log the presence of variables (without revealing values)
logging.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
logging.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
# If username is not set, try to extract from SPACE_ID
if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
username = os.environ.get("SPACE_ID").split("/")[0]
os.environ["HF_USERNAME"] = username
logging.info(f"Set HF_USERNAME from SPACE_ID: {username}")
else:
# Try to load from .env file if not in a Space
try:
from dotenv import load_dotenv
# First check the current directory
env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
if os.path.exists(env_path):
load_dotenv(env_path)
logging.info(f"Loaded environment variables from {env_path}")
logging.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
logging.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
logging.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
else:
# Try the shared directory as fallback
shared_env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
if os.path.exists(shared_env_path):
load_dotenv(shared_env_path)
logging.info(f"Loaded environment variables from {shared_env_path}")
logging.info(f"HF_TOKEN loaded from shared .env file: {bool(os.environ.get('HF_TOKEN'))}")
logging.info(f"HF_USERNAME loaded from shared .env file: {bool(os.environ.get('HF_USERNAME'))}")
logging.info(f"HF_SPACE_NAME loaded from shared .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
else:
logging.warning(f"No .env file found in current or shared directory")
except ImportError:
logging.warning("python-dotenv not installed, not loading from .env file")
if not os.environ.get("HF_TOKEN"):
logger.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.")
if not os.environ.get("HF_USERNAME"):
logger.warning("HF_USERNAME is not set. Using default username.")
if not os.environ.get("HF_SPACE_NAME"):
logger.warning("HF_SPACE_NAME is not set. Using default space name.")
# Set HF_TOKEN for huggingface_hub
if os.environ.get("HF_TOKEN"):
os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
"""Load configuration from transformers_config.json file."""
# Using a single consolidated config file
config_file = base_path
try:
with open(config_file, "r") as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_file}")
return config
except Exception as e:
logger.error(f"Error loading {config_file}: {e}")
raise
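# Illustrative sketch (not exhaustive) of the transformers_config.json keys this
# script reads via get_config_value(); the values below are placeholders, not
# authoritative defaults:
#
# {
#   "model": {"name": "unsloth/phi-4-unsloth-bnb-4bit", "trust_remote_code": true},
#   "model_revision": "main",
#   "torch_dtype": "bfloat16",
#   "use_flash_attention": true,
#   "unsloth": {"enabled": true, "r": 16, "alpha": 32, "dropout": 0,
#               "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]},
#   "tokenizer": {"max_seq_length": 2048, "padding_side": "right"},
#   "dataset": {"name": "<hf-dataset-or-local-path>", "split": "train",
#               "data_loading": {"sequential_processing": true}},
#   "training": {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 8,
#                "learning_rate": 2e-5, "num_train_epochs": 3, "bf16": true},
#   "checkpointing": {"output_dir": "./results", "save_steps": 500},
#   "huggingface_hub": {"push_to_hub": false},
#   "hardware": {"system_settings": {"cuda_memory_fraction": 0.75,
#                                    "dataloader_num_workers": 2}},
#   "distributed_training": {"fsdp_config": {"enabled": false}}
# }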
def parse_args():
"""
Parse command line arguments for the training script.
Returns:
argparse.Namespace: The parsed command line arguments
"""
parser = argparse.ArgumentParser(description="Run training for language models")
parser.add_argument(
"--config_file",
type=str,
default=None,
help="Path to the configuration file (default: transformers_config.json in script directory)"
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility (default: based on current time)"
)
parser.add_argument(
"--log_level",
type=str,
choices=["debug", "info", "warning", "error", "critical"],
default="info",
help="Logging level (default: info)"
)
return parser.parse_args()
def load_model_and_tokenizer(config):
"""
Load the model and tokenizer according to the configuration.
Args:
config (dict): Complete configuration dictionary
Returns:
tuple: (model, tokenizer) - The loaded model and tokenizer
"""
# Extract model configuration
model_config = get_config_value(config, "model", {})
model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
model_revision = get_config_value(config, "model_revision", "main")
# Unsloth configuration
unsloth_config = get_config_value(config, "unsloth", {})
unsloth_enabled = get_config_value(unsloth_config, "enabled", True)
# Tokenizer configuration
tokenizer_config = get_config_value(config, "tokenizer", {})
max_seq_length = min(
get_config_value(tokenizer_config, "max_seq_length", 2048),
4096 # Maximum supported by most models
)
add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
chat_template = get_config_value(tokenizer_config, "chat_template", None)
padding_side = get_config_value(tokenizer_config, "padding_side", "right")
# Check for flash attention
use_flash_attention = get_config_value(config, "use_flash_attention", False)
flash_attention_available = False
try:
import flash_attn
flash_attention_available = True
log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
except ImportError:
if use_flash_attention:
log_info("Flash Attention requested but not available")
log_info(f"Loading model: {model_name} (revision: {model_revision})")
log_info(f"Max sequence length: {max_seq_length}")
try:
if unsloth_enabled and unsloth_available:
log_info("Using Unsloth for 4-bit quantized model and LoRA")
# Load using Unsloth
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
                # Convert the config string to a torch dtype (mirrors the standard loading path below)
                dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
revision=model_revision,
trust_remote_code=trust_remote_code,
use_flash_attention_2=use_flash_attention and flash_attention_available
)
# Configure tokenizer settings
tokenizer.padding_side = padding_side
if add_eos_token and tokenizer.eos_token is None:
log_info("Setting EOS token")
tokenizer.add_special_tokens({"eos_token": "</s>"})
# Set chat template if specified
if chat_template:
log_info(f"Setting chat template: {chat_template}")
if hasattr(tokenizer, "chat_template"):
tokenizer.chat_template = chat_template
else:
log_info("Tokenizer does not support chat templates, using default formatting")
# Apply LoRA
lora_r = get_config_value(unsloth_config, "r", 16)
lora_alpha = get_config_value(unsloth_config, "alpha", 32)
lora_dropout = get_config_value(unsloth_config, "dropout", 0)
target_modules = get_config_value(unsloth_config, "target_modules",
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
model = FastLanguageModel.get_peft_model(
model,
r=lora_r,
target_modules=target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
random_state=0,
max_seq_length=max_seq_length,
modules_to_save=None
)
if use_flash_attention and flash_attention_available:
log_info("🚀 Using Flash Attention for faster training")
elif use_flash_attention and not flash_attention_available:
log_info("⚠️ Flash Attention requested but not available - using standard attention")
else:
# Standard HuggingFace loading
log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
from transformers import AutoModelForCausalLM, AutoTokenizer
# Check if flash attention should be enabled in config
use_attn_implementation = None
if use_flash_attention and flash_attention_available:
use_attn_implementation = "flash_attention_2"
log_info("🚀 Using Flash Attention for faster training")
# Load tokenizer first
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
revision=model_revision,
padding_side=padding_side
)
# Configure tokenizer settings
if add_eos_token and tokenizer.eos_token is None:
log_info("Setting EOS token")
tokenizer.add_special_tokens({"eos_token": "</s>"})
# Set chat template if specified
if chat_template:
log_info(f"Setting chat template: {chat_template}")
if hasattr(tokenizer, "chat_template"):
tokenizer.chat_template = chat_template
else:
log_info("Tokenizer does not support chat templates, using default formatting")
# Now load model with updated tokenizer
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
revision=model_revision,
torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
device_map="auto" if CUDA_AVAILABLE else None,
attn_implementation=use_attn_implementation
)
# Apply PEFT/LoRA if enabled but using standard loading
if peft_available and get_config_value(unsloth_config, "enabled", True):
log_info("Applying standard PEFT/LoRA configuration")
from peft import LoraConfig, get_peft_model
lora_r = get_config_value(unsloth_config, "r", 16)
lora_alpha = get_config_value(unsloth_config, "alpha", 32)
lora_dropout = get_config_value(unsloth_config, "dropout", 0)
target_modules = get_config_value(unsloth_config, "target_modules",
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# Print model summary
log_info(f"Model loaded successfully: {model.__class__.__name__}")
if hasattr(model, "print_trainable_parameters"):
model.print_trainable_parameters()
else:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
return model, tokenizer
except Exception as e:
log_info(f"Error loading model: {str(e)}")
traceback.print_exc()
return None, None
def load_dataset_with_mapping(config):
"""
Load dataset from Hugging Face or local files and apply necessary transformations.
Args:
config (dict): Dataset configuration dictionary
Returns:
Dataset: The loaded and processed dataset
"""
# Extract dataset configuration
dataset_info = get_config_value(config, "dataset", {})
dataset_name = get_config_value(dataset_info, "name", None)
dataset_split = get_config_value(dataset_info, "split", "train")
# Data formatting configuration
formatting_config = get_config_value(config, "data_formatting", {})
if not dataset_name:
raise ValueError("Dataset name not specified in config")
log_info(f"Loading dataset: {dataset_name} (split: {dataset_split})")
try:
# Load dataset from Hugging Face or local path
from datasets import load_dataset
# Check if it's a local path or Hugging Face dataset
if os.path.exists(dataset_name) or os.path.exists(os.path.join(os.getcwd(), dataset_name)):
log_info(f"Loading dataset from local path: {dataset_name}")
# Local dataset - check if it's a directory or file
if os.path.isdir(dataset_name):
# Directory - look for data files
dataset = load_dataset(
"json",
data_files={"train": os.path.join(dataset_name, "*.json")},
split=dataset_split
)
else:
# Single file
dataset = load_dataset(
"json",
data_files={"train": dataset_name},
split=dataset_split
)
else:
# Hugging Face dataset
log_info(f"Loading dataset from Hugging Face: {dataset_name}")
dataset = load_dataset(dataset_name, split=dataset_split)
log_info(f"Dataset loaded with {len(dataset)} examples")
# Check if dataset contains required fields
required_fields = ["conversations"]
missing_fields = [field for field in required_fields if field not in dataset.column_names]
if missing_fields:
log_info(f"WARNING: Dataset missing required fields: {missing_fields}")
log_info("Attempting to map dataset structure to required format")
# Implement conversion logic based on dataset structure
if "messages" in dataset.column_names:
log_info("Converting 'messages' field to 'conversations' format")
dataset = dataset.map(
lambda x: {"conversations": x["messages"]},
remove_columns=["messages"]
)
elif "text" in dataset.column_names:
log_info("Converting plain text to conversations format")
dataset = dataset.map(
lambda x: {"conversations": [{"role": "user", "content": x["text"]}]},
remove_columns=["text"]
)
else:
raise ValueError(f"Cannot convert dataset format - missing required fields and no conversion path available")
# Log dataset info
log_info(f"Dataset has {len(dataset)} examples and columns: {dataset.column_names}")
# Show a few examples for verification
for i in range(min(3, len(dataset))):
example = dataset[i]
log_info(f"Example {i}:")
for key, value in example.items():
if key == "conversations":
log_info(f" conversations: {len(value)} messages")
# Show first message only to avoid cluttering logs
if value and len(value) > 0:
first_msg = value[0]
if isinstance(first_msg, dict) and "content" in first_msg:
content = first_msg["content"]
log_info(f" First message: {content[:50]}..." if len(content) > 50 else f" First message: {content}")
else:
log_info(f" {key}: {value}")
return dataset
except Exception as e:
log_info(f"Error loading dataset: {str(e)}")
traceback.print_exc()
return None
def format_phi_chat(messages, dataset_config):
"""Format messages according to phi-4's chat template and dataset config.
Only formats the conversation structure, preserves the actual content."""
formatted_chat = ""
# Get role templates from config
roles = dataset_config.get("data_formatting", {}).get("roles", {
"system": "System: {content}\n\n",
"human": "Human: {content}\n\n",
"assistant": "Assistant: {content}\n\n"
})
# Handle each message in the conversation
for message in messages:
if not isinstance(message, dict) or "content" not in message:
logger.warning(f"Skipping invalid message format: {message}")
continue
content = message.get("content", "") # Don't strip() - preserve exact content
# Skip empty content
if not content:
continue
# Only add role prefixes based on position/content
if "[RESEARCH INTRODUCTION]" in content:
# System message
template = roles.get("system", "System: {content}\n\n")
formatted_chat = template.format(content=content) + formatted_chat
else:
            # Alternate between human and assistant for regular conversation turns.
            # In phi-4 format, human messages come first, followed by assistant responses.
            if formatted_chat.count("Human:") == formatted_chat.count("Assistant:"):
                # Equal numbers of Human and Assistant turns so far, so the next turn is Human
                template = roles.get("human", "Human: {content}\n\n")
            else:
                # Otherwise, the next turn is Assistant
                template = roles.get("assistant", "Assistant: {content}\n\n")
formatted_chat += template.format(content=content)
return formatted_chat
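# Illustrative example with the default role templates and an empty dataset_config:
#   format_phi_chat(
#       [{"role": "user", "content": "What is LoRA?"},
#        {"role": "assistant", "content": "A low-rank adaptation method."}],
#       {},
#   )
# returns "Human: What is LoRA?\n\nAssistant: A low-rank adaptation method.\n\n"
# (turns are alternated by counting existing "Human:"/"Assistant:" prefixes,
# not by inspecting the "role" field).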
class SimpleDataCollator:
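    """Minimal data collator: tokenizes each example's raw conversation content,
    truncates it to max_seq_length, and returns padded input_ids / attention_mask /
    labels for causal language modeling (labels start as a copy of input_ids)."""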
def __init__(self, tokenizer, dataset_config):
self.tokenizer = tokenizer
self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
self.stats = {
"processed": 0,
"skipped": 0,
"total_tokens": 0
}
logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")
def __call__(self, features):
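        """Tokenize, truncate, and batch a list of dataset examples for causal LM training."""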
# Initialize tensors on CPU to save GPU memory
batch = {
"input_ids": [],
"attention_mask": [],
"labels": []
}
for feature in features:
paper_id = feature.get("article_id", "unknown")
prompt_num = feature.get("prompt_number", 0)
conversations = feature.get("conversations", [])
if not conversations:
logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}")
self.stats["skipped"] += 1
continue
# Get the content directly
content = conversations[0].get("content", "")
if not content:
logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
self.stats["skipped"] += 1
continue
# Process the content string by tokenizing it
if isinstance(content, str):
# Tokenize the content string
input_ids = self.tokenizer.encode(content, add_special_tokens=True)
else:
# If somehow the content is already tokenized (not a string), use it directly
input_ids = content
# Truncate if needed
if len(input_ids) > self.max_seq_length:
input_ids = input_ids[:self.max_seq_length]
logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")
# Create attention mask (1s for all tokens)
attention_mask = [1] * len(input_ids)
# Add to batch
batch["input_ids"].append(input_ids)
batch["attention_mask"].append(attention_mask)
batch["labels"].append(input_ids.copy()) # For causal LM, labels = input_ids
self.stats["processed"] += 1
self.stats["total_tokens"] += len(input_ids)
# Log statistics periodically
if self.stats["processed"] % 100 == 0:
avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"])
logger.info(f"Data collation stats: processed={self.stats['processed']}, "
f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}")
        # Pad and convert to tensors. tokenizer.pad only pads input_ids/attention_mask,
        # so labels are padded manually with -100 so the loss ignores padding
        # positions (assumes right-side padding).
        if batch["input_ids"]:
            labels = batch.pop("labels")
            batch = self.tokenizer.pad(
                batch,
                padding="max_length",
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            padded_labels = torch.full((len(labels), self.max_seq_length), -100, dtype=torch.long)
            for i, label_ids in enumerate(labels):
                padded_labels[i, :len(label_ids)] = torch.tensor(label_ids, dtype=torch.long)
            batch["labels"] = padded_labels
            return batch
        else:
            # Return empty batch if no valid examples
            return {k: [] for k in batch}
def log_gpu_memory_usage(step=None, frequency=50, clear_cache_threshold=0.9, label=None):
"""
Log GPU memory usage statistics with optional cache clearing
Args:
step: Current training step (if None, logs regardless of frequency)
frequency: How often to log when step is provided
clear_cache_threshold: Fraction of memory used that triggers cache clearing (0-1)
label: Optional label for the log message (e.g., "Initial", "Error", "Step")
"""
if not CUDA_AVAILABLE:
return
# Only log every 'frequency' steps if step is provided
if step is not None and frequency > 0 and step % frequency != 0:
return
# Get memory usage for each GPU
memory_info = []
for i in range(NUM_GPUS):
allocated = torch.cuda.memory_allocated(i) / (1024 ** 2) # MB
reserved = torch.cuda.memory_reserved(i) / (1024 ** 2) # MB
max_mem = torch.cuda.max_memory_allocated(i) / (1024 ** 2) # MB
# Calculate percentage of reserved memory that's allocated
usage_percent = (allocated / reserved) * 100 if reserved > 0 else 0
memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB ({usage_percent:.1f}%, max: {max_mem:.1f}MB)")
# Automatically clear cache if over threshold
if clear_cache_threshold > 0 and reserved > 0 and (allocated / reserved) > clear_cache_threshold:
log_info(f"Clearing CUDA cache for GPU {i} - high utilization ({allocated:.1f}/{reserved:.1f}MB)")
with torch.cuda.device(i):
torch.cuda.empty_cache()
prefix = f"{label} " if label else ""
log_info(f"{prefix}GPU Memory: {', '.join(memory_info)}")
class LoggingCallback(TrainerCallback):
"""
Custom callback for logging training progress and metrics.
Provides detailed information about training status, GPU memory usage, and model performance.
"""
def __init__(self, model=None, dataset=None):
# Ensure we have TrainerCallback
try:
super().__init__()
except Exception as e:
# Try to import directly if initial import failed
try:
from transformers.trainer_callback import TrainerCallback
self.__class__.__bases__ = (TrainerCallback,)
super().__init__()
log_info("Successfully imported TrainerCallback directly")
except ImportError as ie:
log_info(f"❌ Error: Could not import TrainerCallback: {str(ie)}")
log_info("Please ensure transformers is properly installed")
raise
self.training_started = time.time()
self.last_log_time = time.time()
self.last_step_time = None
self.step_durations = []
self.best_loss = float('inf')
self.model = model
self.dataset = dataset
def on_train_begin(self, args, state, control, **kwargs):
"""Called at the beginning of training"""
try:
log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
# Log model info if available
if self.model is not None:
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
log_info(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")
# Log dataset info if available
if self.dataset is not None:
log_info(f"Dataset size: {len(self.dataset)} examples")
            # Log important training parameters for visibility (treat CPU-only runs as a single device)
            num_devices = max(1, NUM_GPUS)
            total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices
            total_steps = int(len(self.dataset or []) / (args.per_device_train_batch_size * num_devices * args.gradient_accumulation_steps) * args.num_train_epochs)
            log_info(f"Training plan: {len(self.dataset or [])} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
            log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {num_devices} device(s) = {total_batch_size} total")
# Log initial GPU memory usage with label
log_gpu_memory_usage(label="Initial")
except Exception as e:
logger.warning(f"Error logging training begin statistics: {str(e)}")
def on_step_end(self, args, state, control, **kwargs):
"""Called at the end of each step"""
try:
if state.global_step == 1 or state.global_step % args.logging_steps == 0:
# Track step timing
current_time = time.time()
if self.last_step_time:
step_duration = current_time - self.last_step_time
self.step_durations.append(step_duration)
# Keep only last 100 steps for averaging
if len(self.step_durations) > 100:
self.step_durations.pop(0)
avg_step_time = sum(self.step_durations) / len(self.step_durations)
log_info(f"Step {state.global_step}: {step_duration:.2f}s (avg: {avg_step_time:.2f}s)")
self.last_step_time = current_time
# Log GPU memory usage with step number
log_gpu_memory_usage(state.global_step, args.logging_steps)
# Log loss
if state.log_history:
latest_logs = state.log_history[-1] if state.log_history else {}
if "loss" in latest_logs:
loss = latest_logs["loss"]
log_info(f"Step {state.global_step} loss: {loss:.4f}")
# Track best loss
if loss < self.best_loss:
self.best_loss = loss
log_info(f"New best loss: {loss:.4f}")
except Exception as e:
logger.warning(f"Error logging step end statistics: {str(e)}")
def on_train_end(self, args, state, control, **kwargs):
"""Called at the end of training"""
try:
# Calculate training duration
training_time = time.time() - self.training_started
hours, remainder = divmod(training_time, 3600)
minutes, seconds = divmod(remainder, 60)
log_info(f"=== Training completed at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
log_info(f"Training duration: {int(hours)}h {int(minutes)}m {int(seconds)}s")
log_info(f"Final step: {state.global_step}")
log_info(f"Best loss: {self.best_loss:.4f}")
# Log final GPU memory usage
log_gpu_memory_usage(label="Final")
except Exception as e:
logger.warning(f"Error logging training end statistics: {str(e)}")
# Other callback methods with proper error handling
def on_save(self, args, state, control, **kwargs):
"""Called when a checkpoint is saved"""
try:
log_info(f"Saving checkpoint at step {state.global_step}")
except Exception as e:
logger.warning(f"Error in on_save: {str(e)}")
def on_log(self, args, state, control, **kwargs):
"""Called when a log is created"""
pass
def on_evaluate(self, args, state, control, **kwargs):
"""Called when evaluation is performed"""
pass
    # Remaining callback hooks are intentionally lightweight or no-ops
def on_prediction_step(self, args, state, control, **kwargs):
"""Called when prediction is performed"""
pass
def on_save_model(self, args, state, control, **kwargs):
"""Called when model is saved"""
try:
# Log memory usage after saving
log_gpu_memory_usage(label=f"Save at step {state.global_step}")
except Exception as e:
logger.warning(f"Error in on_save_model: {str(e)}")
def on_epoch_end(self, args, state, control, **kwargs):
"""Called at the end of an epoch"""
try:
epoch = state.epoch
log_info(f"Completed epoch {epoch:.2f}")
log_gpu_memory_usage(label=f"Epoch {epoch:.2f}")
except Exception as e:
logger.warning(f"Error in on_epoch_end: {str(e)}")
def on_step_begin(self, args, state, control, **kwargs):
"""Called at the beginning of a step"""
pass
def install_flash_attention():
"""
Attempt to install Flash Attention for improved performance.
Returns True if installation was successful, False otherwise.
"""
log_info("Attempting to install Flash Attention...")
# Check for CUDA before attempting installation
if not CUDA_AVAILABLE:
log_info("❌ Cannot install Flash Attention: CUDA not available")
return False
try:
# Check CUDA version to determine correct installation command
cuda_version = torch.version.cuda
if cuda_version is None:
log_info("❌ Cannot determine CUDA version for Flash Attention installation")
return False
import subprocess
# Use --no-build-isolation for better compatibility
install_cmd = [
sys.executable,
"-m",
"pip",
"install",
"flash-attn",
"--no-build-isolation"
]
log_info(f"Running: {' '.join(install_cmd)}")
result = subprocess.run(
install_cmd,
capture_output=True,
text=True,
check=False
)
if result.returncode == 0:
log_info("✅ Flash Attention installed successfully!")
# Attempt to import to verify installation
try:
import flash_attn
log_info(f"✅ Flash Attention version {flash_attn.__version__} is now available")
return True
except ImportError:
log_info("⚠️ Flash Attention installed but import failed")
return False
else:
log_info(f"❌ Flash Attention installation failed with error: {result.stderr}")
return False
except Exception as e:
log_info(f"❌ Error installing Flash Attention: {str(e)}")
return False
def check_dependencies():
"""
Check for required and optional dependencies, ensuring proper versions and import order.
Returns True if all required dependencies are present, False otherwise.
"""
# Define required packages with versions and descriptions
required_packages = {
"unsloth": {"version": ">=2024.3", "feature": "fast 4-bit quantization and LoRA"},
"transformers": {"version": ">=4.38.0", "feature": "core model functionality"},
"peft": {"version": ">=0.9.0", "feature": "parameter-efficient fine-tuning"},
"accelerate": {"version": ">=0.27.0", "feature": "multi-GPU training"}
}
# Optional packages that enhance functionality
optional_packages = {
"flash_attn": {"feature": "faster attention computation"},
"bitsandbytes": {"feature": "quantization support"},
"optimum": {"feature": "model optimization"},
"wandb": {"feature": "experiment tracking"}
}
# Store results
missing_packages = []
package_versions = {}
order_issues = []
missing_optional = []
# Check required packages
log_info("Checking required dependencies...")
for package, info in required_packages.items():
version_req = info["version"]
feature = info["feature"]
try:
# Special handling for packages we've already checked
if package == "unsloth" and not unsloth_available:
missing_packages.append(f"{package}{version_req}")
log_info(f"❌ {package} - {feature} MISSING")
continue
elif package == "peft" and not peft_available:
missing_packages.append(f"{package}{version_req}")
log_info(f"❌ {package} - {feature} MISSING")
continue
# Try to import and get version
module = __import__(package)
version = getattr(module, "__version__", "unknown")
package_versions[package] = version
log_info(f"✅ {package} v{version} - {feature}")
except ImportError:
missing_packages.append(f"{package}{version_req}")
log_info(f"❌ {package} - {feature} MISSING")
# Check optional packages
log_info("\nChecking optional dependencies...")
for package, info in optional_packages.items():
feature = info["feature"]
try:
__import__(package)
log_info(f"✅ {package} - {feature} available")
except ImportError:
log_info(f"⚠️ {package} - {feature} not available")
missing_optional.append(package)
# Check import order for optimal performance
if "transformers" in package_versions and "unsloth" in package_versions:
try:
import sys
modules = list(sys.modules.keys())
transformers_idx = modules.index("transformers")
unsloth_idx = modules.index("unsloth")
if transformers_idx < unsloth_idx:
order_issue = "⚠️ For optimal performance, import unsloth before transformers"
order_issues.append(order_issue)
log_info(order_issue)
log_info("This might cause performance issues but won't prevent training")
else:
log_info("✅ Import order: unsloth before transformers (optimal)")
except (ValueError, IndexError) as e:
log_info(f"⚠️ Could not verify import order: {str(e)}")
# Try to install missing optional packages
if "flash_attn" in missing_optional and CUDA_AVAILABLE:
log_info("\nFlash Attention is missing but would improve performance.")
install_result = install_flash_attention()
if install_result:
missing_optional.remove("flash_attn")
# Report missing required packages
if missing_packages:
log_info("\n❌ Critical dependencies missing:")
for pkg in missing_packages:
log_info(f" - {pkg}")
log_info("Please install missing dependencies with:")
log_info(f" pip install {' '.join(missing_packages)}")
return False
log_info("\n✅ All required dependencies satisfied!")
return True
def get_config_value(config, path, default=None):
"""
Safely get a nested value from a config dictionary using a dot-separated path.
Args:
config: The configuration dictionary
path: Dot-separated path to the value (e.g., "training.optimizer.lr")
default: Default value to return if path doesn't exist
Returns:
The value at the specified path or the default value
"""
if not config:
return default
parts = path.split('.')
current = config
for part in parts:
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
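# Example: get_config_value(config, "training.learning_rate", 2e-5) returns
# config["training"]["learning_rate"] when every level exists, otherwise 2e-5.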
def update_huggingface_space():
"""Update the Hugging Face Space with the current code."""
log_info("Updating Hugging Face Space...")
update_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "update_space.py")
if not os.path.exists(update_script):
logger.warning(f"Update space script not found at {update_script}")
return False
try:
import subprocess
# Explicitly set space_name to ensure we're targeting the right Space
result = subprocess.run(
[sys.executable, update_script, "--force", "--space_name", "phi4training"],
capture_output=True, text=True, check=False
)
if result.returncode == 0:
log_info("Hugging Face Space updated successfully!")
log_info(f"Space URL: https://huggingface.co/spaces/George-API/phi4training")
return True
else:
logger.error(f"Failed to update Hugging Face Space: {result.stderr}")
return False
except Exception as e:
logger.error(f"Error updating Hugging Face Space: {str(e)}")
return False
def validate_huggingface_credentials():
"""Validate Hugging Face credentials to ensure they work correctly."""
if not os.environ.get("HF_TOKEN"):
logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.")
return False
try:
# Import here to avoid requiring huggingface_hub if not needed
from huggingface_hub import HfApi, login
# Try to login with the token
login(token=os.environ.get("HF_TOKEN"))
# Check if we can access the API
api = HfApi()
username = os.environ.get("HF_USERNAME", "George-API")
space_name = os.environ.get("HF_SPACE_NAME", "phi4training")
# Try to get whoami info
user_info = api.whoami()
logger.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")
# Check if we're using the expected Space
expected_space_id = "George-API/phi4training"
actual_space_id = f"{username}/{space_name}"
if actual_space_id != expected_space_id:
logger.warning(f"Using Space '{actual_space_id}' instead of the expected '{expected_space_id}'")
logger.warning(f"Make sure this is intentional. To use the correct Space, update your .env file.")
else:
logger.info(f"Confirmed using Space: {expected_space_id}")
# Check if the space exists
try:
space_id = f"{username}/{space_name}"
space_info = api.space_info(repo_id=space_id)
logger.info(f"Space {space_id} is accessible at: https://huggingface.co/spaces/{space_id}")
return True
except Exception as e:
logger.warning(f"Could not access Space {username}/{space_name}: {str(e)}")
logger.warning("Space updating may not work correctly")
return False
except ImportError:
logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.")
return False
except Exception as e:
logger.warning(f"Error validating Hugging Face credentials: {str(e)}")
return False
def setup_environment(args):
"""
Set up the training environment including logging, seed, and configurations.
Args:
args: Command line arguments
Returns:
tuple: (transformers_config, seed) - The loaded configuration and random seed
"""
# Load environment variables first
load_env_variables()
# Set random seed for reproducibility
seed = args.seed if args.seed is not None else int(time.time()) % 10000
set_seed(seed)
log_info(f"Using random seed: {seed}")
# Load configuration
base_path = os.path.dirname(os.path.abspath(__file__))
config_file = args.config_file or os.path.join(base_path, "transformers_config.json")
if not os.path.exists(config_file):
raise FileNotFoundError(f"Config file not found: {config_file}")
log_info(f"Loading configuration from {config_file}")
transformers_config = load_configs(config_file)
# Set up hardware environment variables if CUDA is available
if CUDA_AVAILABLE:
        memory_fraction = get_config_value(transformers_config, "hardware.system_settings.cuda_memory_fraction", 0.75)
        if memory_fraction < 1.0:
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
            log_info("Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
# Check dependencies and install optional ones if needed
if not check_dependencies():
raise RuntimeError("Critical dependencies missing")
# Check if flash attention was successfully installed
flash_attention_available = False
try:
import flash_attn
flash_attention_available = True
log_info(f"Flash Attention will be used (version: {flash_attn.__version__})")
# Update config to use flash attention
if "use_flash_attention" not in transformers_config:
transformers_config["use_flash_attention"] = True
except ImportError:
log_info("Flash Attention not available, will use standard attention mechanism")
return transformers_config, seed
def setup_model_and_tokenizer(config):
"""
Load and configure the model and tokenizer.
Args:
config (dict): Complete configuration dictionary
Returns:
tuple: (model, tokenizer) - The loaded model and tokenizer
"""
# Extract model configuration
model_config = get_config_value(config, "model", {})
model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit")
use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True)
trust_remote_code = get_config_value(model_config, "trust_remote_code", True)
model_revision = get_config_value(config, "model_revision", "main")
# Detect if model is already pre-quantized (includes '4bit', 'bnb', or 'int4' in name)
is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant'])
if is_prequantized:
log_info("⚠️ Detected pre-quantized model. No additional quantization will be applied.")
# Unsloth configuration
unsloth_config = get_config_value(config, "unsloth", {})
unsloth_enabled = get_config_value(unsloth_config, "enabled", True)
# Tokenizer configuration
tokenizer_config = get_config_value(config, "tokenizer", {})
max_seq_length = min(
get_config_value(tokenizer_config, "max_seq_length", 2048),
4096 # Maximum supported by most models
)
add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True)
chat_template = get_config_value(tokenizer_config, "chat_template", None)
padding_side = get_config_value(tokenizer_config, "padding_side", "right")
# Check for flash attention
use_flash_attention = get_config_value(config, "use_flash_attention", False)
flash_attention_available = False
try:
import flash_attn
flash_attention_available = True
log_info(f"Flash Attention detected (version: {flash_attn.__version__})")
except ImportError:
if use_flash_attention:
log_info("Flash Attention requested but not available")
log_info(f"Loading model: {model_name} (revision: {model_revision})")
log_info(f"Max sequence length: {max_seq_length}")
try:
if unsloth_enabled and unsloth_available:
log_info("Using Unsloth for LoRA fine-tuning")
if is_prequantized:
log_info("Using pre-quantized model - no additional quantization will be applied")
else:
log_info("Using 4-bit quantization for efficient training")
# Load using Unsloth
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
                # Convert the config string to a torch dtype (mirrors the standard loading path below)
                dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
revision=model_revision,
trust_remote_code=trust_remote_code,
use_flash_attention_2=use_flash_attention and flash_attention_available
)
# Configure tokenizer settings
tokenizer.padding_side = padding_side
if add_eos_token and tokenizer.eos_token is None:
log_info("Setting EOS token")
tokenizer.add_special_tokens({"eos_token": "</s>"})
# Set chat template if specified
if chat_template:
log_info(f"Setting chat template: {chat_template}")
if hasattr(tokenizer, "chat_template"):
tokenizer.chat_template = chat_template
else:
log_info("Tokenizer does not support chat templates, using default formatting")
# Apply LoRA
lora_r = get_config_value(unsloth_config, "r", 16)
lora_alpha = get_config_value(unsloth_config, "alpha", 32)
lora_dropout = get_config_value(unsloth_config, "dropout", 0)
target_modules = get_config_value(unsloth_config, "target_modules",
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
model = FastLanguageModel.get_peft_model(
model,
r=lora_r,
target_modules=target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True),
random_state=0,
max_seq_length=max_seq_length,
modules_to_save=None
)
if use_flash_attention and flash_attention_available:
log_info("🚀 Using Flash Attention for faster training")
elif use_flash_attention and not flash_attention_available:
log_info("⚠️ Flash Attention requested but not available - using standard attention")
else:
# Standard HuggingFace loading
log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)")
from transformers import AutoModelForCausalLM, AutoTokenizer
# Check if flash attention should be enabled in config
use_attn_implementation = None
if use_flash_attention and flash_attention_available:
use_attn_implementation = "flash_attention_2"
log_info("🚀 Using Flash Attention for faster training")
# Load tokenizer first
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
revision=model_revision,
padding_side=padding_side
)
# Configure tokenizer settings
if add_eos_token and tokenizer.eos_token is None:
log_info("Setting EOS token")
tokenizer.add_special_tokens({"eos_token": "</s>"})
# Set chat template if specified
if chat_template:
log_info(f"Setting chat template: {chat_template}")
if hasattr(tokenizer, "chat_template"):
tokenizer.chat_template = chat_template
else:
log_info("Tokenizer does not support chat templates, using default formatting")
# Only apply quantization config if model is not already pre-quantized
quantization_config = None
if not is_prequantized and CUDA_AVAILABLE:
try:
from transformers import BitsAndBytesConfig
log_info("Using 4-bit quantization (BitsAndBytes) for efficient training")
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
except ImportError:
log_info("BitsAndBytes not available - quantization disabled")
# Now load model with updated tokenizer
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
revision=model_revision,
torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16,
device_map="auto" if CUDA_AVAILABLE else None,
attn_implementation=use_attn_implementation,
quantization_config=quantization_config
)
# Apply PEFT/LoRA if enabled but using standard loading
if peft_available and get_config_value(unsloth_config, "enabled", True):
log_info("Applying standard PEFT/LoRA configuration")
from peft import LoraConfig, get_peft_model
lora_r = get_config_value(unsloth_config, "r", 16)
lora_alpha = get_config_value(unsloth_config, "alpha", 32)
lora_dropout = get_config_value(unsloth_config, "dropout", 0)
target_modules = get_config_value(unsloth_config, "target_modules",
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# Print model summary
log_info(f"Model loaded successfully: {model.__class__.__name__}")
if hasattr(model, "print_trainable_parameters"):
model.print_trainable_parameters()
else:
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})")
return model, tokenizer
except Exception as e:
log_info(f"Error loading model: {str(e)}")
traceback.print_exc()
return None, None
def setup_dataset_and_collator(config, tokenizer):
"""
Load and configure the dataset and data collator.
Args:
config: Complete configuration dictionary
tokenizer: The tokenizer for the data collator
Returns:
tuple: (dataset, data_collator) - The loaded dataset and configured data collator
"""
dataset_config = get_config_value(config, "dataset", {})
log_info("Loading dataset...")
dataset = load_dataset_with_mapping(dataset_config)
# Validate dataset
if dataset is None:
raise ValueError("Dataset is None! Cannot proceed with training.")
if not hasattr(dataset, '__len__') or len(dataset) == 0:
raise ValueError("Dataset is empty! Cannot proceed with training.")
log_info(f"Dataset loaded with {len(dataset)} examples")
# Create data collator
data_collator = SimpleDataCollator(tokenizer, dataset_config)
return dataset, data_collator
def create_training_arguments(config, dataset):
"""
Create and configure training arguments for the Trainer.
Args:
config: Complete configuration dictionary
dataset: The dataset to determine total steps
Returns:
TrainingArguments: Configured training arguments
"""
# Extract configuration sections
training_config = get_config_value(config, "training", {})
hardware_config = get_config_value(config, "hardware", {})
huggingface_config = get_config_value(config, "huggingface_hub", {})
distributed_config = get_config_value(config, "distributed_training", {})
# Extract key training parameters
per_device_batch_size = get_config_value(training_config, "per_device_train_batch_size", 4)
gradient_accumulation_steps = get_config_value(training_config, "gradient_accumulation_steps", 8)
learning_rate = get_config_value(training_config, "learning_rate", 2e-5)
num_train_epochs = get_config_value(training_config, "num_train_epochs", 3)
# Extract hardware settings
dataloader_workers = get_config_value(hardware_config, "system_settings.dataloader_num_workers",
get_config_value(distributed_config, "dataloader_num_workers", 2))
pin_memory = get_config_value(hardware_config, "system_settings.dataloader_pin_memory", True)
# BF16/FP16 settings - ensure only one is enabled
use_bf16 = get_config_value(training_config, "bf16", False)
use_fp16 = get_config_value(training_config, "fp16", False) if not use_bf16 else False
# Configure distributed training
fsdp_config = get_config_value(distributed_config, "fsdp_config", {})
fsdp_enabled = get_config_value(fsdp_config, "enabled", False)
ddp_config = get_config_value(distributed_config, "ddp_config", {})
ddp_find_unused_parameters = get_config_value(ddp_config, "find_unused_parameters", False)
    # Set up FSDP args if enabled. TrainingArguments expects FSDP settings as a
    # string of options plus an fsdp_config dict, not an accelerate plugin object;
    # mixed precision is governed by the bf16/fp16 flags below.
    fsdp_args = None
    if fsdp_enabled and NUM_GPUS > 1:
        sharding_strategy = get_config_value(fsdp_config, "sharding_strategy", "full_shard").lower()
        fsdp_args = {
            "fsdp": f"{sharding_strategy} auto_wrap",
            "fsdp_config": {
                "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer", "PhiDecoderLayer"],
            },
        }
# Create and return training arguments
training_args = TrainingArguments(
output_dir=get_config_value(config, "checkpointing.output_dir", "./results"),
overwrite_output_dir=True,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
weight_decay=get_config_value(training_config, "weight_decay", 0.01),
max_grad_norm=get_config_value(training_config, "max_grad_norm", 1.0),
warmup_ratio=get_config_value(training_config, "warmup_ratio", 0.03),
lr_scheduler_type=get_config_value(training_config, "lr_scheduler_type", "cosine"),
logging_steps=get_config_value(training_config, "logging_steps", 10),
save_strategy=get_config_value(config, "checkpointing.save_strategy", "steps"),
save_steps=get_config_value(config, "checkpointing.save_steps", 500),
save_total_limit=get_config_value(config, "checkpointing.save_total_limit", 3),
bf16=use_bf16,
fp16=use_fp16,
push_to_hub=get_config_value(huggingface_config, "push_to_hub", False),
hub_model_id=get_config_value(huggingface_config, "hub_model_id", None),
hub_strategy=get_config_value(huggingface_config, "hub_strategy", "every_save"),
hub_private_repo=get_config_value(huggingface_config, "hub_private_repo", True),
gradient_checkpointing=get_config_value(training_config, "gradient_checkpointing", True),
dataloader_pin_memory=pin_memory,
optim=get_config_value(training_config, "optim", "adamw_torch"),
ddp_find_unused_parameters=ddp_find_unused_parameters,
dataloader_drop_last=False,
dataloader_num_workers=dataloader_workers,
        no_cuda=not CUDA_AVAILABLE,
**({} if fsdp_args is None else fsdp_args)
)
log_info("Training arguments created successfully")
return training_args
def configure_custom_dataloader(trainer, dataset, config, training_args):
"""
Configure a custom dataloader for the trainer if needed.
Args:
trainer: The Trainer instance to configure
dataset: The dataset to use
config: Complete configuration dictionary
training_args: The training arguments
Returns:
None (modifies trainer in-place)
"""
dataset_config = get_config_value(config, "dataset", {})
# Check if we need a custom dataloader
if get_config_value(dataset_config, "data_loading.sequential_processing", True):
log_info("Using custom sequential dataloader")
# Create sequential sampler to maintain dataset order
sequential_sampler = torch.utils.data.SequentialSampler(dataset)
log_info("Sequential sampler created")
# Define custom dataloader getter
def custom_get_train_dataloader():
"""Create a custom dataloader that maintains dataset order"""
# Get configuration values
batch_size = training_args.per_device_train_batch_size
drop_last = get_config_value(dataset_config, "data_loading.drop_last", False)
num_workers = training_args.dataloader_num_workers
pin_memory = training_args.dataloader_pin_memory
prefetch_factor = get_config_value(dataset_config, "data_loading.prefetch_factor", 2)
persistent_workers = get_config_value(dataset_config, "data_loading.persistent_workers", False)
# Create DataLoader with sequential sampler
return DataLoader(
dataset,
batch_size=batch_size,
sampler=sequential_sampler,
collate_fn=trainer.data_collator,
drop_last=drop_last,
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
persistent_workers=persistent_workers if num_workers > 0 else False,
)
# Override the default dataloader
trainer.get_train_dataloader = custom_get_train_dataloader
def run_training(trainer, tokenizer, training_args):
"""
Run the training process and handle model saving.
Args:
trainer: Configured Trainer instance
tokenizer: The tokenizer to save with the model
training_args: Training arguments
Returns:
int: 0 for success, 1 for failure
"""
log_info("Starting training...")
trainer.train()
log_info("Training complete! Saving final model...")
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)
# Push to Hub if configured
if training_args.push_to_hub:
log_info(f"Pushing model to Hugging Face Hub: {training_args.hub_model_id}")
trainer.push_to_hub()
log_info("Training completed successfully!")
return 0
def main():
"""
Main entry point for the training script.
Returns:
int: 0 for success, non-zero for failure
"""
# Set up logging
logger.info("Starting training process")
try:
# Verify critical imports are available
if not transformers_available:
log_info("❌ Error: transformers library not available. Please install it with: pip install transformers")
return 1
# Check for required classes
for required_class in ["Trainer", "TrainingArguments", "TrainerCallback"]:
if not hasattr(transformers, required_class):
log_info(f"❌ Error: {required_class} not found in transformers. Please update transformers.")
return 1
# Check for potential import order issue and warn early
if "transformers" in sys.modules and "unsloth" in sys.modules:
if list(sys.modules.keys()).index("transformers") < list(sys.modules.keys()).index("unsloth"):
log_info("⚠️ Warning: transformers was imported before unsloth. This may affect performance.")
log_info(" For optimal performance in future runs, import unsloth first.")
# Parse command line arguments
args = parse_args()
# Set up environment and load configuration
transformers_config, seed = setup_environment(args)
# Load model and tokenizer
try:
model, tokenizer = setup_model_and_tokenizer(transformers_config)
except Exception as e:
logger.error(f"Error setting up model: {str(e)}")
return 1
# Load dataset and create data collator
try:
dataset, data_collator = setup_dataset_and_collator(transformers_config, tokenizer)
except Exception as e:
logger.error(f"Error setting up dataset: {str(e)}")
return 1
# Configure training arguments
try:
training_args = create_training_arguments(transformers_config, dataset)
except Exception as e:
logger.error(f"Error configuring training arguments: {str(e)}")
return 1
# Initialize trainer with callbacks
log_info("Initializing Trainer")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=data_collator,
callbacks=[LoggingCallback(model=model, dataset=dataset)],
)
# Configure custom dataloader if needed
try:
configure_custom_dataloader(trainer, dataset, transformers_config, training_args)
except Exception as e:
logger.error(f"Error configuring custom dataloader: {str(e)}")
return 1
# Run training process
try:
return run_training(trainer, tokenizer, training_args)
except Exception as e:
logger.error(f"Training failed with error: {str(e)}")
# Log GPU memory for debugging
log_gpu_memory_usage(label="Error")
# Print full stack trace
traceback.print_exc()
return 1
except Exception as e:
logger.error(f"Error in main function: {str(e)}")
traceback.print_exc()
return 1
if __name__ == "__main__":
sys.exit(main())