George-API committed on
Commit
4ce739a
·
verified ·
1 Parent(s): 0360950

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. run_transformers_training.py +22 -14
run_transformers_training.py CHANGED
@@ -391,7 +391,8 @@ def load_dataset_with_mapping(dataset_config):
391
  return 1
392
 
393
  def format_phi_chat(messages, dataset_config):
394
- """Format messages according to phi-4's chat template and dataset config."""
 
395
  formatted_chat = ""
396
 
397
  # Get role templates from config
@@ -407,13 +408,13 @@ def format_phi_chat(messages, dataset_config):
407
  logger.warning(f"Skipping invalid message format: {message}")
408
  continue
409
 
410
- content = message.get("content", "").strip()
411
 
412
  # Skip empty content
413
  if not content:
414
  continue
415
 
416
- # Infer role based on content patterns
417
  if "[RESEARCH INTRODUCTION]" in content:
418
  # System message
419
  template = roles.get("system", "System: {content}\n\n")
@@ -429,7 +430,7 @@ def format_phi_chat(messages, dataset_config):
429
  template = roles.get("assistant", "Assistant: {content}\n\n")
430
  formatted_chat += template.format(content=content)
431
 
432
- return formatted_chat.strip()
433
 
434
  class SimpleDataCollator:
435
  def __init__(self, tokenizer, dataset_config):
@@ -459,17 +460,25 @@ class SimpleDataCollator:
459
  self.stats["skipped"] += 1
460
  continue
461
 
462
- # Format the conversation using phi chat template
463
- formatted_chat = format_phi_chat(conversations, self.dataset_config)
464
-
465
- # Skip if formatting resulted in empty content
466
- if not formatted_chat:
467
- logger.warning(f"Empty formatted chat for paper_id {paper_id}, prompt {prompt_num}")
468
  self.stats["skipped"] += 1
469
  continue
470
 
471
- # Create input IDs and attention mask
472
- input_ids = self.tokenizer.encode(formatted_chat, add_special_tokens=False)
 
 
 
 
 
 
 
 
 
473
 
474
  # Truncate if needed
475
  if len(input_ids) > self.max_seq_length:
@@ -489,8 +498,7 @@ class SimpleDataCollator:
489
 
490
  # Log first few examples for verification
491
  if self.stats["processed"] <= 3:
492
- logger.info(f"Sample {self.stats['processed']} formatted chat:")
493
- logger.info(f"{formatted_chat[:200]}...")
494
 
495
  except Exception as e:
496
  logger.warning(f"Error processing example {paper_id}, prompt {prompt_num}: {str(e)}")
 
391
  return 1
392
 
393
  def format_phi_chat(messages, dataset_config):
394
+ """Format messages according to phi-4's chat template and dataset config.
395
+ Only formats the conversation structure, preserves the actual content."""
396
  formatted_chat = ""
397
 
398
  # Get role templates from config
 
408
  logger.warning(f"Skipping invalid message format: {message}")
409
  continue
410
 
411
+ content = message.get("content", "") # Don't strip() - preserve exact content
412
 
413
  # Skip empty content
414
  if not content:
415
  continue
416
 
417
+ # Only add role prefixes based on position/content
418
  if "[RESEARCH INTRODUCTION]" in content:
419
  # System message
420
  template = roles.get("system", "System: {content}\n\n")
 
430
  template = roles.get("assistant", "Assistant: {content}\n\n")
431
  formatted_chat += template.format(content=content)
432
 
433
+ return formatted_chat
434
 
435
  class SimpleDataCollator:
436
  def __init__(self, tokenizer, dataset_config):
 
460
  self.stats["skipped"] += 1
461
  continue
462
 
463
+ # Get the pre-tokenized content directly
464
+ # The content should already be properly tokenized and formatted
465
+ content = conversations[0].get("content", "")
466
+ if not content:
467
+ logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
 
468
  self.stats["skipped"] += 1
469
  continue
470
 
471
+ # Convert string of numbers to list of integers if needed
472
+ if isinstance(content, str):
473
+ try:
474
+ # Assuming content is space-separated numbers
475
+ input_ids = [int(x) for x in content.split()]
476
+ except ValueError:
477
+ logger.warning(f"Invalid pre-tokenized content format for paper_id {paper_id}, prompt {prompt_num}")
478
+ self.stats["skipped"] += 1
479
+ continue
480
+ else:
481
+ input_ids = content
482
 
483
  # Truncate if needed
484
  if len(input_ids) > self.max_seq_length:
 
498
 
499
  # Log first few examples for verification
500
  if self.stats["processed"] <= 3:
501
+ logger.info(f"Sample {self.stats['processed']} token count: {len(input_ids)}")
 
502
 
503
  except Exception as e:
504
  logger.warning(f"Error processing example {paper_id}, prompt {prompt_num}: {str(e)}")