George-API committed on
Commit
1cf4e07
·
verified ·
1 Parent(s): ae57ea2

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. run_transformers_training.py +240 -69
run_transformers_training.py CHANGED
@@ -285,7 +285,7 @@ def load_model_and_tokenizer(config):
285
  raise
286
 
287
  def load_dataset_with_mapping(dataset_config):
288
- """Load and prepare dataset with proper column mapping."""
289
  try:
290
  # Load dataset
291
  dataset_name = dataset_config.get("dataset", {}).get("name", "")
@@ -319,6 +319,45 @@ def load_dataset_with_mapping(dataset_config):
319
  if source != target: # Only rename if names are different
320
  dataset = dataset.rename_column(source, target)
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  # Verify expected columns exist
323
  expected_columns = {"id", "conversations"}
324
  for col in expected_columns:
@@ -369,40 +408,105 @@ def load_dataset_with_mapping(dataset_config):
369
 
370
  # Verify the IDs are in sequential order if they're numeric
371
  try:
372
- if len(dataset) > 1 and all(isinstance(example.get('id', ''), (int, str)) for example in dataset.select(range(min(10, len(dataset))))):
373
- sample_ids = [example['id'] for example in dataset.select(range(min(10, len(dataset))))]
374
- logger.info(f"Verifying sequential integrity with first few IDs: {sample_ids}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
- # Check if IDs are numeric and ordered
377
- if all(isinstance(id, int) or id.isdigit() for id in sample_ids):
378
- numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
379
- is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
380
- if not is_ordered:
381
- logger.warning("WARNING: Sample IDs are not in sequential order.")
382
  logger.warning("This may indicate that data sequence is not preserved.")
383
  else:
384
- logger.info("Sample IDs appear to be in sequential order.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  except Exception as e:
386
  logger.warning(f"Could not verify sequential integrity: {e}")
387
 
388
- # Log examples without printing full content
389
  if "conversations" in dataset.column_names:
390
- sample_ids = [example['id'] for example in dataset.select(range(min(5, len(dataset))))]
391
- logger.info(f"First few IDs: {sample_ids}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
 
393
- # Log conversation structure without full content
394
- if len(dataset) > 0:
395
- sample_conv_structure = []
396
- for msg in dataset["conversations"][0]:
397
- if isinstance(msg, dict):
398
- content = msg.get('content', '')
399
- preview = content[:50] + "..." if len(content) > 50 else content
400
- sample_conv_structure.append({
401
- "role": msg.get('role', ''),
402
- "content_length": len(content),
403
- "preview": preview
404
- })
405
- logger.info(f"Conversation structure: {sample_conv_structure}")
 
 
 
 
 
 
 
 
406
 
407
  logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
408
  logger.info(f"Dataset columns: {dataset.column_names}")
@@ -597,39 +701,88 @@ class LoggingCallback(TrainerCallback):
597
  if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples:
598
  try:
599
  # Get a batch of data without disturbing the training
600
- batch = next(iter(trainer.get_train_dataloader()))
601
- if 'input_ids' in batch and 'labels' in batch:
602
- log_info("Verifying data sequence integrity...")
603
-
604
- # Check if we can access some of our reference samples
605
- current_indices = list(range(min(3, len(trainer.train_dataset))))
606
- current_samples = [trainer.train_dataset[i] for i in current_indices]
607
-
608
- # Compare current samples with our reference samples from training start
609
- is_sequence_maintained = True
610
- for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
611
- # Check if sample IDs still match our reference
612
- if orig_idx < len(current_samples):
613
- current_sample = current_samples[i]
614
-
615
- # Compare IDs if available
616
- if 'id' in orig_sample and 'id' in current_sample:
617
- if orig_sample['id'] != current_sample['id']:
618
- log_info(f"WARNING: Sequence integrity compromised! Sample {i} ID changed from {orig_sample['id']} to {current_sample['id']}")
619
- is_sequence_maintained = False
620
-
621
- # Compare input fingerprints
622
- if 'conversations' in orig_sample and 'conversations' in current_sample:
623
- orig_len = len(orig_sample['conversations'])
624
- curr_len = len(current_sample['conversations'])
625
- if orig_len != curr_len:
626
- log_info(f"WARNING: Sequence integrity compromised! Sample {i} conversation length changed from {orig_len} to {curr_len}")
627
- is_sequence_maintained = False
628
-
629
- if is_sequence_maintained:
630
- log_info("Data sequence integrity check: OK")
631
  else:
632
- log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
  except Exception as e:
634
  log_info(f"Warning: Couldn't verify sequence integrity: {e}")
635
 
@@ -666,16 +819,33 @@ class LoggingCallback(TrainerCallback):
666
  log_info("Sequence integrity verification enabled during training")
667
 
668
  # Save actual samples for later verification
669
- if trainer and trainer.train_dataset:
670
- # Get some reference samples from the beginning of the dataset
671
- self.sample_indices = list(range(min(5, len(trainer.train_dataset))))
672
- self.sequence_samples = [trainer.train_dataset[i] for i in self.sample_indices]
673
- log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
 
 
 
 
 
 
 
 
674
 
675
- # Log sample IDs for debugging
676
- if len(self.sequence_samples) > 0 and 'id' in self.sequence_samples[0]:
677
- sample_ids = [s.get('id') for s in self.sequence_samples if 'id' in s]
678
- log_info(f"Reference sample IDs: {sample_ids}")
 
 
 
 
 
 
 
 
 
679
  else:
680
  log_info("Warning: Could not capture reference samples - verification will be limited")
681
  except Exception as e:
@@ -685,7 +855,8 @@ class LoggingCallback(TrainerCallback):
685
  log_info("=== Training is starting ===")
686
 
687
  # Log important training parameters for visibility
688
- log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {max(1, torch.cuda.device_count())} GPUs")
 
689
  log_info(f"Learning rate: {args.learning_rate}")
690
  log_info(f"Epochs: {args.num_train_epochs}")
691
 
 
285
  raise
286
 
287
  def load_dataset_with_mapping(dataset_config):
288
+ """Load dataset and apply appropriate column mappings."""
289
  try:
290
  # Load dataset
291
  dataset_name = dataset_config.get("dataset", {}).get("name", "")
 
319
  if source != target: # Only rename if names are different
320
  dataset = dataset.rename_column(source, target)
321
 
322
+ # Add prompt_number field that increments based on original order
323
+ def add_prompt_numbers(examples, indices):
324
+ # Defensive check to ensure indices is not None
325
+ if indices is None:
326
+ logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
327
+ indices = []
328
+
329
+ # Create a new field with the dataset index as the prompt number, starting at 1
330
+ examples["prompt_number"] = [idx + 1 for idx in indices] # Adding 1 to make it 1-indexed
331
+ return examples
332
+
333
+ # Add prompt numbers to the dataset based on original order
334
+ logger.info("Adding prompt numbers based on original dataset order (starting at 1)")
335
+ try:
336
+ dataset = dataset.map(
337
+ add_prompt_numbers,
338
+ with_indices=True,
339
+ desc="Adding prompt numbers"
340
+ )
341
+ logger.info(f"Successfully added prompt_number field to dataset")
342
+ except Exception as e:
343
+ logger.error(f"Error adding prompt numbers: {e}")
344
+ # Create a fallback implementation that doesn't rely on with_indices
345
+ logger.info("Attempting fallback method for adding prompt numbers")
346
+
347
+ def add_prompt_numbers_fallback(example, idx):
348
+ example["prompt_number"] = idx + 1
349
+ return example
350
+
351
+ # Process each example one by one with explicit indices
352
+ updated_examples = []
353
+ for i, example in enumerate(dataset):
354
+ updated_examples.append(add_prompt_numbers_fallback(dict(example), i))
355
+
356
+ # Create a new dataset with the updated examples
357
+ from datasets import Dataset
358
+ dataset = Dataset.from_list(updated_examples)
359
+ logger.info(f"Successfully added prompt_number field using fallback method")
360
+
361
  # Verify expected columns exist
362
  expected_columns = {"id", "conversations"}
363
  for col in expected_columns:
 
408
 
409
  # Verify the IDs are in sequential order if they're numeric
410
  try:
411
+ if len(dataset) > 1:
412
+ # Check prompt numbers are sequential
413
+ sample_indices = range(min(10, len(dataset)))
414
+ sample_prompt_numbers = []
415
+
416
+ # Defensive collection of prompt numbers
417
+ for i in sample_indices:
418
+ try:
419
+ if i < len(dataset) and "prompt_number" in dataset[i]:
420
+ sample_prompt_numbers.append(dataset[i]["prompt_number"])
421
+ else:
422
+ # If prompt_number doesn't exist, use index+1 as fallback
423
+ sample_prompt_numbers.append(i + 1)
424
+ logger.warning(f"Sample at index {i} missing prompt_number, using {i+1} as fallback")
425
+ except Exception as e:
426
+ logger.warning(f"Error accessing sample at index {i}: {e}")
427
+ sample_prompt_numbers.append(i + 1) # Use fallback
428
+
429
+ logger.info(f"Verifying sequential integrity with prompt numbers: {sample_prompt_numbers}")
430
 
431
+ # Check if prompt numbers are sequential (1-indexed)
432
+ if sample_prompt_numbers:
433
+ is_sequential = all(sample_prompt_numbers[i] == i + 1 for i in range(len(sample_prompt_numbers)))
434
+ if not is_sequential:
435
+ logger.warning("WARNING: Prompt numbers are not in sequential order!")
 
436
  logger.warning("This may indicate that data sequence is not preserved.")
437
  else:
438
+ logger.info("Prompt numbers verify that samples are in sequential order.")
439
+ else:
440
+ logger.warning("Could not verify sequential integrity: no prompt numbers collected")
441
+
442
+ # Also check original IDs as a backup if numeric
443
+ try:
444
+ sample_examples = []
445
+ for i in sample_indices:
446
+ try:
447
+ if i < len(dataset):
448
+ sample_examples.append(dataset[i])
449
+ except Exception as e:
450
+ logger.warning(f"Error accessing dataset at index {i}: {e}")
451
+
452
+ if sample_examples:
453
+ if all(isinstance(example.get('id', ''), (int, str)) for example in sample_examples):
454
+ sample_ids = [example.get('id', '') for example in sample_examples if 'id' in example]
455
+
456
+ if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
457
+ numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
458
+ if len(numeric_ids) > 1:
459
+ is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
460
+ if not is_ordered:
461
+ logger.warning("WARNING: Sample IDs are not in sequential order.")
462
+ else:
463
+ logger.info("Sample IDs appear to be in sequential order.")
464
+ except Exception as e:
465
+ logger.warning(f"Error checking ID sequence: {e}")
466
  except Exception as e:
467
  logger.warning(f"Could not verify sequential integrity: {e}")
468
 
469
+ # Log examples without printing full content - with defensive coding
470
  if "conversations" in dataset.column_names:
471
+ try:
472
+ # Safely get first few samples
473
+ first_few_indices = range(min(5, len(dataset)))
474
+ sample_prompt_numbers = []
475
+ sample_ids = []
476
+
477
+ for i in first_few_indices:
478
+ try:
479
+ example = dataset[i]
480
+ if 'prompt_number' in example:
481
+ sample_prompt_numbers.append(example['prompt_number'])
482
+ if 'id' in example:
483
+ sample_ids.append(example['id'])
484
+ except Exception as e:
485
+ logger.warning(f"Error accessing sample at index {i}: {e}")
486
+
487
+ logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, IDs: {sample_ids}")
488
 
489
+ # Log conversation structure without full content
490
+ if len(dataset) > 0:
491
+ try:
492
+ sample_conv_structure = []
493
+ first_example = dataset[0]
494
+
495
+ if 'conversations' in first_example and first_example['conversations'] is not None:
496
+ for msg in first_example['conversations']:
497
+ if isinstance(msg, dict):
498
+ content = msg.get('content', '')
499
+ preview = content[:50] + "..." if len(content) > 50 else content
500
+ sample_conv_structure.append({
501
+ "role": msg.get('role', ''),
502
+ "content_length": len(content),
503
+ "preview": preview
504
+ })
505
+ logger.info(f"Conversation structure: {sample_conv_structure}")
506
+ except Exception as e:
507
+ logger.warning(f"Error logging conversation structure: {e}")
508
+ except Exception as e:
509
+ logger.warning(f"Error logging sample examples: {e}")
510
 
511
  logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
512
  logger.info(f"Dataset columns: {dataset.column_names}")
 
701
  if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples:
702
  try:
703
  # Get a batch of data without disturbing the training
704
+ train_dataloader = trainer.get_train_dataloader()
705
+ if train_dataloader is None:
706
+ log_info("Warning: Could not get train dataloader for verification")
707
+ else:
708
+ batch_iterator = iter(train_dataloader)
709
+ if batch_iterator is None:
710
+ log_info("Warning: Could not get batch iterator for verification")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
  else:
712
+ try:
713
+ batch = next(batch_iterator)
714
+ if batch is None:
715
+ log_info("Warning: Could not get batch for verification")
716
+ elif 'input_ids' in batch and 'labels' in batch:
717
+ log_info("Verifying data sequence integrity...")
718
+
719
+ # Check if we can access some of our reference samples
720
+ if not hasattr(trainer, 'train_dataset') or trainer.train_dataset is None:
721
+ log_info("Warning: Train dataset is not available")
722
+ else:
723
+ # Get current samples defensively
724
+ current_samples = []
725
+ current_indices = list(range(min(3, len(trainer.train_dataset))))
726
+
727
+ for idx in current_indices:
728
+ try:
729
+ if idx < len(trainer.train_dataset):
730
+ current_samples.append(trainer.train_dataset[idx])
731
+ except Exception as e:
732
+ log_info(f"Warning: Error accessing dataset at index {idx}: {e}")
733
+
734
+ # Only proceed if we have samples to compare
735
+ if current_samples and self.sequence_samples:
736
+ # Compare current samples with our reference samples from training start
737
+ is_sequence_maintained = True
738
+
739
+ for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
740
+ # Check if sample index is valid
741
+ if i < len(current_samples):
742
+ current_sample = current_samples[i]
743
+
744
+ # Compare prompt numbers if available
745
+ if ('prompt_number' in orig_sample and
746
+ 'prompt_number' in current_sample and
747
+ orig_sample['prompt_number'] is not None and
748
+ current_sample['prompt_number'] is not None):
749
+
750
+ if orig_sample['prompt_number'] != current_sample['prompt_number']:
751
+ log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}")
752
+ is_sequence_maintained = False
753
+
754
+ # Also compare IDs as a backup check
755
+ elif ('id' in orig_sample and
756
+ 'id' in current_sample and
757
+ orig_sample['id'] is not None and
758
+ current_sample['id'] is not None):
759
+
760
+ if orig_sample['id'] != current_sample['id']:
761
+ log_info(f"WARNING: Sequence integrity compromised! Sample {i} ID changed from {orig_sample['id']} to {current_sample['id']}")
762
+ is_sequence_maintained = False
763
+
764
+ # Compare input fingerprints
765
+ if ('conversations' in orig_sample and
766
+ 'conversations' in current_sample and
767
+ orig_sample['conversations'] is not None and
768
+ current_sample['conversations'] is not None):
769
+
770
+ orig_len = len(orig_sample['conversations'])
771
+ curr_len = len(current_sample['conversations'])
772
+ if orig_len != curr_len:
773
+ log_info(f"WARNING: Sequence integrity compromised! Sample {i} conversation length changed from {orig_len} to {curr_len}")
774
+ is_sequence_maintained = False
775
+
776
+ if is_sequence_maintained:
777
+ log_info("Data sequence integrity check: OK")
778
+ else:
779
+ log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
780
+ else:
781
+ log_info("Warning: Not enough samples available for sequence verification")
782
+ except StopIteration:
783
+ log_info("Warning: No batches available in the dataloader")
784
+ except Exception as e:
785
+ log_info(f"Warning: Error iterating through dataloader: {e}")
786
  except Exception as e:
787
  log_info(f"Warning: Couldn't verify sequence integrity: {e}")
788
 
 
819
  log_info("Sequence integrity verification enabled during training")
820
 
821
  # Save actual samples for later verification
822
+ if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
823
+ # Get some reference samples from the beginning of the dataset defensively
824
+ self.sample_indices = []
825
+ self.sequence_samples = []
826
+
827
+ max_samples = min(5, len(trainer.train_dataset))
828
+ for i in range(max_samples):
829
+ try:
830
+ if i < len(trainer.train_dataset):
831
+ self.sample_indices.append(i)
832
+ self.sequence_samples.append(trainer.train_dataset[i])
833
+ except Exception as e:
834
+ log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
835
 
836
+ if self.sequence_samples:
837
+ log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
838
+
839
+ # Log sample prompt numbers for debugging
840
+ sample_prompt_numbers = []
841
+ for s in self.sequence_samples:
842
+ if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
843
+ sample_prompt_numbers.append(s.get('prompt_number'))
844
+
845
+ if sample_prompt_numbers:
846
+ log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
847
+ else:
848
+ log_info("Warning: No reference samples were captured")
849
  else:
850
  log_info("Warning: Could not capture reference samples - verification will be limited")
851
  except Exception as e:
 
855
  log_info("=== Training is starting ===")
856
 
857
  # Log important training parameters for visibility
858
+ total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
859
+ log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
860
  log_info(f"Learning rate: {args.learning_rate}")
861
  log_info(f"Epochs: {args.num_train_epochs}")
862