Commit 71642d9 (verified), committed by George-API
Parent(s): 90530d1

Upload folder using huggingface_hub

Files changed (3):
  1. requirements.txt +1 -1
  2. run_transformers_training.py +44 -158
  3. update_space.py +12 -4
requirements.txt CHANGED
@@ -4,7 +4,6 @@ bitsandbytes>=0.41.0
 datasets>=2.15.0
 einops>=0.7.0
 filelock>=3.13.1
-flash-attn==2.5.2
 gradio>=5.17.0
 huggingface-hub>=0.19.0
 matplotlib>=3.7.0
@@ -23,3 +22,4 @@ tqdm>=4.65.0
 transformers>=4.36.0
 typing-extensions>=4.8.0
 unsloth>=2024.3
+flash-attn==2.5.2
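Moving flash-attn to the end of the file matters because its build compiles against an already-installed torch; the training script's own warnings point at the same constraint ("pip install flash-attn --no-build-isolation"). A minimal sketch of an install sequence that respects this ordering, assuming the pins above (the torch pin is the default used in update_space.py; the two-step procedure itself is an assumption, not part of the commit):

# Sketch: install torch before flash-attn, mirroring the new ordering above.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "torch>=2.0.0"])
subprocess.check_call([sys.executable, "-m", "pip", "install",
                       "flash-attn==2.5.2", "--no-build-isolation"])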
run_transformers_training.py CHANGED
@@ -158,38 +158,13 @@ def load_model_and_tokenizer(config):
 
     logger.info("Using Unsloth optimizations with pre-quantized model")
 
-    # Check for flash attention
-    use_flash_attention = config.get("use_flash_attention", True)
-    if use_flash_attention and not find_spec("flash_attn"):
-        logger.warning("flash-attn not found. Will continue without flash attention.")
-        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
-        use_flash_attention = False
-
     # First detect if we have a GPU
     if torch.cuda.is_available():
         gpu_count = torch.cuda.device_count()
-        logger.info(f"CUDA available, found {gpu_count} GPU(s)")
-
-        # Log GPU info
-        for i in range(gpu_count):
-            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
-            logger.info(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
-
-        # Create an optimized device map for better balance
-        if gpu_count > 1:
-            logger.info(f"Creating balanced device map for {gpu_count} GPUs")
-            # Use auto mapping but with memory tracking
-            device_map = "auto"
-            # Set max memory for better balancing
-            max_memory = {i: f"{int(torch.cuda.get_device_properties(i).total_memory * 0.85 / 1024**3)}GiB" for i in range(gpu_count)}
-            logger.info(f"Max memory settings: {max_memory}")
-        else:
-            device_map = "auto"
-            max_memory = None
+        logger.info(f"Found {gpu_count} CUDA devices")
     else:
-        logger.warning("No CUDA available, falling back to CPU")
-        device_map = {"": "cpu"}  # Force CPU placement
-        max_memory = None
+        logger.warning("No CUDA devices detected. Training will be slow on CPU!")
+        gpu_count = 0
 
     # Set default dtype for better numerics
     if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
@@ -205,6 +180,13 @@ def load_model_and_tokenizer(config):
         dtype = None
         logger.info("Using default precision (CPU)")
 
+    # Check for flash attention as the last dependency check
+    use_flash_attention = config.get("use_flash_attention", True)
+    if use_flash_attention and not find_spec("flash_attn"):
+        logger.warning("flash-attn not found. Will continue without flash attention.")
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
+        use_flash_attention = False
+
     # Load model with proper error handling for out-of-memory
     try:
         # Improved memory settings for multi-GPU setup
@@ -300,6 +282,16 @@ def load_dataset_with_mapping(dataset_config):
     else:
         logger.info(f"Dataset has all required fields: {required_fields}")
 
+    # Verify that column order matches our expectation
+    expected_order = ["prompt_number", "article_id", "conversations"]
+    actual_order = dataset.column_names
+
+    if actual_order == expected_order:
+        logger.info("Dataset column order matches expected order (prompt_number, article_id, conversations)")
+    else:
+        logger.warning(f"Dataset column order ({', '.join(actual_order)}) differs from expected order ({', '.join(expected_order)})")
+        logger.warning("This should not affect processing but is noted for debugging purposes")
+
     # Log a few samples for verification
     if len(dataset) > 0:
         sample_indices = range(min(5, len(dataset)))
@@ -524,54 +516,15 @@ class LoggingCallback(TrainerCallback):
         log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
         log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
 
-        # Set up sequence verification with actual sample capturing
-        try:
-            self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False)
-            if self.verify_sequence:
-                log_info("Sequence integrity verification enabled during training")
-
-                # Save actual samples for later verification
-                if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
-                    # Get some reference samples from the beginning of the dataset defensively
-                    self.sample_indices = []
-                    self.sequence_samples = []
-
-                    max_samples = min(5, len(trainer.train_dataset))
-                    for i in range(max_samples):
-                        try:
-                            if i < len(trainer.train_dataset):
-                                self.sample_indices.append(i)
-                                self.sequence_samples.append(trainer.train_dataset[i])
-                        except Exception as e:
-                            log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
-
-                    if self.sequence_samples:
-                        log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
-
-                        # Log sample prompt numbers for debugging
-                        sample_prompt_numbers = []
-                        for s in self.sequence_samples:
-                            if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
-                                sample_prompt_numbers.append(s.get('prompt_number'))
-
-                        if sample_prompt_numbers:
-                            log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
-                            if sample_prompt_numbers == list(range(1, len(sample_prompt_numbers) + 1)):
-                                log_info("Prompt numbers are sequential (1-indexed) - sequence integrity confirmed")
-                            else:
-                                log_info("Prompt numbers are not in expected sequence - will verify during training")
-                        else:
-                            log_info("Warning: No reference samples were captured")
-                else:
-                    log_info("Warning: Could not capture reference samples - verification will be limited")
-        except Exception as e:
-            log_info(f"Warning: Could not set up sequence integrity verification: {e}")
-            self.verify_sequence = False
+        # Disable sequence verification
+        self.verify_sequence = False
 
         log_info("=== Training is starting ===")
 
         # Log important training parameters for visibility
         total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
+        total_steps = int(len(dataset) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
+        log_info(f"Training plan: {len(dataset)} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
         log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
         log_info(f"Learning rate: {args.learning_rate}")
         log_info(f"Epochs: {args.num_train_epochs}")
@@ -585,90 +538,12 @@ class LoggingCallback(TrainerCallback):
             memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
 
         log_info(f"Initial memory usage - {', '.join(memory_info)}")
-
+
     def on_step_end(self, args, state, control, **kwargs):
        # Log every 50 steps or every 5 minutes, whichever comes first
         current_time = time.time()
 
-        # Perform actual sequence integrity verification if enabled
-        if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples:
-            try:
-                # Get a batch of data without disturbing the training
-                train_dataloader = trainer.get_train_dataloader()
-                if train_dataloader is None:
-                    log_info("Warning: Could not get train dataloader for verification")
-                else:
-                    batch_iterator = iter(train_dataloader)
-                    if batch_iterator is None:
-                        log_info("Warning: Could not get batch iterator for verification")
-                    else:
-                        try:
-                            batch = next(batch_iterator)
-                            if batch is None:
-                                log_info("Warning: Could not get batch for verification")
-                            elif 'input_ids' in batch and 'labels' in batch:
-                                log_info("Verifying data sequence integrity...")
-
-                                # Check if we can access some of our reference samples
-                                if not hasattr(trainer, 'train_dataset') or trainer.train_dataset is None:
-                                    log_info("Warning: Train dataset is not available")
-                                else:
-                                    # Get current samples defensively
-                                    current_samples = []
-                                    current_indices = list(range(min(3, len(trainer.train_dataset))))
-
-                                    for idx in current_indices:
-                                        try:
-                                            if idx < len(trainer.train_dataset):
-                                                current_samples.append(trainer.train_dataset[idx])
-                                        except Exception as e:
-                                            log_info(f"Warning: Error accessing dataset at index {idx}: {e}")
-
-                                    # Only proceed if we have samples to compare
-                                    if current_samples and self.sequence_samples:
-                                        # Compare current samples with our reference samples from training start
-                                        is_sequence_maintained = True
-
-                                        for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
-                                            # Check if sample index is valid
-                                            if i < len(current_samples):
-                                                current_sample = current_samples[i]
-
-                                                # Compare prompt numbers if available - this is our primary check now
-                                                if ('prompt_number' in orig_sample and
-                                                        'prompt_number' in current_sample and
-                                                        orig_sample['prompt_number'] is not None and
-                                                        current_sample['prompt_number'] is not None):
-
-                                                    if orig_sample['prompt_number'] != current_sample['prompt_number']:
-                                                        log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}")
-                                                        is_sequence_maintained = False
-                                                    else:
-                                                        # This is now our primary verification
-                                                        log_info(f"Prompt number match confirmed for sample {i}: {orig_sample['prompt_number']}")
-
-                                                # Also compare article_id as a backup check
-                                                elif ('article_id' in orig_sample and
-                                                        'article_id' in current_sample and
-                                                        orig_sample['article_id'] is not None and
-                                                        current_sample['article_id'] is not None):
-
-                                                    if orig_sample['article_id'] != current_sample['article_id']:
-                                                        log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
-                                                        is_sequence_maintained = False
-
-                                        if is_sequence_maintained:
-                                            log_info("Data sequence integrity check: OK - prompt numbers preserved")
-                                        else:
-                                            log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
-                                    else:
-                                        log_info("Warning: Not enough samples available for sequence verification")
-                        except StopIteration:
-                            log_info("Warning: No batches available in the dataloader")
-                        except Exception as e:
-                            log_info(f"Warning: Error iterating through dataloader: {e}")
-            except Exception as e:
-                log_info(f"Warning: Couldn't verify sequence integrity: {e}")
+        # Sequence verification removed
 
         # Log progress at regular intervals
         if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
@@ -708,13 +583,6 @@ def check_dependencies():
     if not peft_available:
        missing_packages.append("peft>=0.9.0")
 
-    # Optional packages - don't add to missing list, just log
-    if find_spec("flash_attn"):
-        logger.info("flash-attn found. Flash attention will be used for faster training.")
-    else:
-        logger.warning("flash-attn not found. Training will work but may be slower.")
-        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
-
     # If critical packages are missing, exit with instructions
     if missing_packages:
         logger.error("Critical dependencies missing:")
@@ -723,6 +591,13 @@ def check_dependencies():
         logger.error("Please ensure the space has these packages in requirements.txt")
         return False
 
+    # Optional packages - moved to the end
+    if find_spec("flash_attn"):
+        logger.info("flash-attn found. Flash attention will be used for faster training.")
+    else:
+        logger.warning("flash-attn not found. Training will work but may be slower.")
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
+
     return True
 
 def main():
@@ -934,6 +809,17 @@ def main():
 
     # Log our approach clearly
     log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
+
+    # Verify column order
+    expected_order = ["prompt_number", "article_id", "conversations"]
+    if hasattr(dataset, 'column_names'):
+        actual_order = dataset.column_names
+        if actual_order == expected_order:
+            log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
+        else:
+            log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
+            log_info("This is handled correctly by field-based access, but noting for clarity")
+
     log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
 
     # Calculate batch size based on device availability
update_space.py CHANGED
@@ -121,17 +121,25 @@ def update_requirements():
     # Add new requirements
     updated_requirements = existing_requirements.union(required_packages)
 
-    # Write updated requirements with torch first
+    # Write updated requirements with torch first and flash-attn last
     with open(req_path, 'w') as f:
         # Ensure torch is first
         torch_req = next((req for req in updated_requirements if req.startswith("torch")), "torch>=2.0.0")
         f.write(f"{torch_req}\n")
 
-        # Write remaining requirements
-        for req in sorted(r for r in updated_requirements if not r.startswith("torch")):
+        # Extract flash-attn to add it last
+        flash_attn_req = next((req for req in updated_requirements if req.startswith("flash-attn")), None)
+
+        # Write all other requirements (excluding torch and flash-attn)
+        for req in sorted(r for r in updated_requirements
+                          if not r.startswith("torch") and not r.startswith("flash-attn")):
             f.write(f"{req}\n")
+
+        # Add flash-attn as the very last package
+        if flash_attn_req:
+            f.write(f"{flash_attn_req}\n")
 
-    logger.info("Updated requirements.txt with necessary packages (torch listed first)")
+    logger.info("Updated requirements.txt with torch listed first and flash-attn listed last")
 
 def create_space(username, space_name):
     """Create or get a Hugging Face Space."""