def format_phi_chat(messages, dataset_config):
    """Format messages according to phi-4's chat template and dataset config.

    Role templates come from ``dataset_config["data_formatting"]["roles"]``;
    each template is filled via ``str.format(content=...)``. When that key is
    missing, plain ``System:`` / ``Human:`` / ``Assistant:`` templates are used.

    Ordering rules:
      * The first message whose string content contains
        ``[RESEARCH INTRODUCTION]`` is rendered first with the system template
        and removed from the remaining messages.
      * ``system`` messages are prepended to everything rendered so far.
      * ``human``/``user`` and ``assistant``/``bot`` messages are appended.
      * Any other role is rendered with the system template and appended in
        conversation order, after a warning is logged.
      * Non-dict entries and dicts without ``content`` are skipped with a warning.

    Args:
        messages: Iterable of message dicts with ``role`` and ``content`` keys.
        dataset_config: Dataset configuration mapping.

    Returns:
        The formatted conversation with surrounding whitespace stripped.
    """
    formatted_chat = ""

    # Get role templates from config
    default_roles = {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    }
    roles = dataset_config.get("data_formatting", {}).get("roles", default_roles)
    system_template = roles.get("system", "System: {content}\n\n")
    user_template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
    assistant_template = roles.get("assistant", "Assistant: {content}\n\n")

    # Handle research introduction metadata first. Only string content is
    # inspected so a message with content=None cannot abort the whole format.
    metadata = next(
        (
            msg
            for msg in messages
            if isinstance(msg, dict)
            and isinstance(msg.get("content"), str)
            and "[RESEARCH INTRODUCTION]" in msg["content"]
        ),
        None,
    )
    if metadata is not None:
        formatted_chat = system_template.format(content=metadata["content"])
        messages = [msg for msg in messages if msg != metadata]

    # Process remaining messages
    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue

        # `or ""` tolerates an explicit role=None instead of crashing on .lower().
        role = str(message.get("role") or "").lower()
        content = message.get("content", "")

        if role in ("human", "user"):
            formatted_chat += user_template.format(content=content)
        elif role in ("assistant", "bot"):
            formatted_chat += assistant_template.format(content=content)
        elif role == "system":
            # System messages are moved ahead of everything rendered so far.
            formatted_chat = system_template.format(content=content) + formatted_chat
        else:
            # Unknown roles reuse the system template but keep their position.
            logger.warning(f"Unknown role '{role}' - treating as system message")
            formatted_chat += system_template.format(content=content)

    return formatted_chat.strip()


class SimpleDataCollator:
    """Tokenize pre-audited conversation examples and pad them into a batch.

    Each feature is read with ``.get("id")`` and ``.get("conversations")``;
    the conversations list is passed unchanged to the tokenizer's chat
    template, falling back to plain concatenated message contents if the
    template call raises. Sequences longer than ``max_seq_length`` are
    truncated, then the batch is right-padded and converted to ``torch.long``
    tensors.
    """

    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Highest multiple of log_interval (as processed // log_interval) whose
        # stats have already been logged.
        self._last_logged_bucket = 0
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.max_seq_length = (
            dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        )
        logger.info(
            f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}"
        )
        logger.info("Using exact dataset structure without reformatting")

        # Check if we're on GPU (recorded and logged; not read elsewhere in this class)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"SimpleDataCollator using device: {self.device}")

    @staticmethod
    def _as_token_ids(encoded):
        """Return a new ``list`` of token ids from a tokenizer result.

        ``apply_chat_template(..., return_tensors=None)`` yields a list of ids,
        while calling a Hugging Face tokenizer directly yields a dict-like
        ``BatchEncoding``; the ``input_ids`` entry is extracted in that case.
        """
        if hasattr(encoded, "keys") and "input_ids" in encoded:
            encoded = encoded["input_ids"]
        return list(encoded)

    def _tokenize_conversation(self, conversations, paper_id):
        """Tokenize one conversation, falling back to plain text if the chat template fails."""
        try:
            # Let tokenizer handle the content with the model's chat template
            encoded = self.tokenizer.apply_chat_template(
                conversations,
                return_tensors=None,
                add_generation_prompt=False,
            )
        except Exception as chat_error:
            # Deliberate best-effort fallback: concatenate raw message contents.
            logger.warning(
                f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}"
            )
            conversation_text = "".join(
                msg.get("content", "") + "\n\n"
                for msg in conversations
                if isinstance(msg, dict) and "content" in msg
            )
            encoded = self.tokenizer(
                conversation_text,
                add_special_tokens=True,
                return_tensors=None,
            )
        return self._as_token_ids(encoded)

    def _maybe_log_stats(self):
        """Log running stats once each time ``processed`` crosses a multiple of ``log_interval``.

        Crossing (rather than exact ``%`` equality) is used because
        ``processed`` advances by a whole batch at a time.
        """
        log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
        processed = self.stats["processed"]
        if log_interval <= 0 or processed <= 0:
            return
        bucket = processed // log_interval
        if bucket > getattr(self, "_last_logged_bucket", 0):
            self._last_logged_bucket = bucket
            logger.info(
                f"Data collator stats: processed={self.stats['processed']}, "
                f"skipped={self.stats['skipped']}, "
                f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}"
            )

    def __call__(self, features):
        """Process examples preserving exact JSONL structure.

        Args:
            features: Iterable of examples, each read via ``.get("id")`` and
                ``.get("conversations")``.

        Returns:
            Dict of ``input_ids``, ``attention_mask`` and ``labels`` long
            tensors of shape (num_kept_examples, longest_length). Labels equal
            the input ids with padding set to -100, the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``. If no example
            survives, 1x1 all-zero tensors are returned.
        """
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)

        for example in features:
            try:
                paper_id = example.get("id", "")
                conversations = example.get("conversations", [])
                if not conversations:
                    self.stats["skipped"] += 1
                    continue

                inputs = self._tokenize_conversation(conversations, paper_id)

                # Apply length cap if needed (shouldn't be necessary for pre-audited data)
                if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
                    logger.warning(
                        f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})"
                    )
                    inputs = inputs[: self.max_seq_length]

                if not inputs:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(inputs)
                batch["attention_mask"].append([1] * len(inputs))
                # For causal language modeling, labels are the same as inputs
                batch["labels"].append(list(inputs))

                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(inputs)

                # Debug logging for first few examples
                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']}:")
                    logger.info(f"Paper ID: {paper_id}")
                    logger.info(f"Token count: {len(inputs)}")
                    logger.info(f"Conversation entries: {len(conversations)}")
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                # Guard so a non-mapping example cannot raise again inside the handler.
                example_id = example.get("id", "unknown") if hasattr(example, "get") else "unknown"
                logger.warning(f"Problematic example ID: {example_id}")
                self.stats["skipped"] += 1

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long),
            }

        # Right-pad every sequence to the longest one in the batch.
        max_length = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = max_length - len(ids)
            ids.extend([self.pad_token_id] * padding_length)
            mask.extend([0] * padding_length)
            labels.extend([-100] * padding_length)

        # Convert to tensors
        batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}

        self._maybe_log_stats()
        return batch


class LoggingCallback(TrainerCallback):
    """Trainer callback emitting compact loss and GPU-memory logs via ``log_info``."""

    def __init__(self):
        self.last_log_time = time.time()
        self.last_memory_log_time = time.time()

    @staticmethod
    def _latest_loss(state):
        """Return the most recent ``loss`` value in ``state.log_history``, or ``"N/A"``.

        Searches backwards because the last entry may lack a ``loss`` key
        (e.g. evaluation or end-of-training summary entries).
        """
        return next(
            (entry["loss"] for entry in reversed(state.log_history) if "loss" in entry),
            "N/A",
        )

    @staticmethod
    def _gpu_memory_summary(show_peak):
        """Return a comma-separated per-GPU memory summary in MB.

        With ``show_peak`` each entry reads ``allocated (max: peak allocated)``;
        otherwise ``allocated/reserved``.
        """
        entries = []
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**2
            if show_peak:
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                entries.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            else:
                reserved = torch.cuda.memory_reserved(i) / 1024**2
                entries.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB")
        return ", ".join(entries)

    def on_step_end(self, args, state, control, **kwargs):
        """Log loss every 50 steps or 5 minutes, and GPU memory every 15 minutes."""
        current_time = time.time()

        # Log loss every 50 steps or 5 minutes, whichever comes first
        if state.global_step % 50 == 0 or current_time - self.last_log_time > 300:
            if state.log_history:
                # Use simple formatting for better HF Space log compatibility
                log_info(f"Step {state.global_step}: Loss {self._latest_loss(state)}")
            else:
                log_info(f"Step {state.global_step}: No loss data available")
            self.last_log_time = current_time

        # Log memory usage every 15 minutes
        if current_time - self.last_memory_log_time > 900:
            if torch.cuda.is_available():
                log_info(f"Memory usage - {self._gpu_memory_summary(show_peak=False)}")
            self.last_memory_log_time = current_time

    def on_train_begin(self, args, state, control, **kwargs):
        """Log key training parameters and initial GPU memory."""
        log_info("=== Training is starting ===")

        # Log important training parameters for visibility
        log_info(
            f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {max(1, torch.cuda.device_count())} GPUs"
        )
        log_info(f"Learning rate: {args.learning_rate}")
        log_info(f"Epochs: {args.num_train_epochs}")

        if torch.cuda.is_available():
            log_info(f"Initial memory usage - {self._gpu_memory_summary(show_peak=True)}")

    def on_train_end(self, args, state, control, **kwargs):
        """Log final GPU memory, total steps and the last recorded training loss."""
        log_info("=== Training completed ===")
        if torch.cuda.is_available():
            log_info(f"Final memory usage - {self._gpu_memory_summary(show_peak=True)}")

        log_info(f"Total steps: {state.global_step}")
        log_info(f"Final loss: {self._latest_loss(state)}")