Commit: Upload folder using huggingface_hub

Files changed:
- requirements.txt +1 -1
- run_transformers_training.py +44 -158
- update_space.py +12 -4
requirements.txt
CHANGED
@@ -4,7 +4,6 @@ bitsandbytes>=0.41.0
 datasets>=2.15.0
 einops>=0.7.0
 filelock>=3.13.1
-flash-attn==2.5.2
 gradio>=5.17.0
 huggingface-hub>=0.19.0
 matplotlib>=3.7.0
@@ -23,3 +22,4 @@ tqdm>=4.65.0
 transformers>=4.36.0
 typing-extensions>=4.8.0
 unsloth>=2024.3
+flash-attn==2.5.2
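The net change moves flash-attn==2.5.2 from the middle of requirements.txt to the very last line. flash-attn compiles against an already-installed torch, so listing it after torch and everything else reduces the chance of a build failure taking the Space down. A minimal sketch of the same idea at install time, reusing the pip hint the training script logs (the helper name is illustrative):

import importlib.util
import subprocess
import sys

def install_flash_attn_last() -> None:
    # Hypothetical helper: require torch before attempting the flash-attn build.
    if importlib.util.find_spec("torch") is None:
        raise RuntimeError("torch must be installed before flash-attn")
    # --no-build-isolation lets the build see the installed torch, matching the
    # command suggested in run_transformers_training.py's warnings.
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "--no-build-isolation", "flash-attn==2.5.2"])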
run_transformers_training.py
CHANGED
@@ -158,38 +158,13 @@ def load_model_and_tokenizer(config):
 
     logger.info("Using Unsloth optimizations with pre-quantized model")
 
-    # Check for flash attention
-    use_flash_attention = config.get("use_flash_attention", True)
-    if use_flash_attention and not find_spec("flash_attn"):
-        logger.warning("flash-attn not found. Will continue without flash attention.")
-        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
-        use_flash_attention = False
-
     # First detect if we have a GPU
     if torch.cuda.is_available():
         gpu_count = torch.cuda.device_count()
-        logger.info(f"Found {gpu_count} CUDA devices")
-
-        # Log GPU info
-        for i in range(gpu_count):
-            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
-            logger.info(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
-
-        # Create an optimized device map for better balance
-        if gpu_count > 1:
-            logger.info(f"Creating balanced device map for {gpu_count} GPUs")
-            # Use auto mapping but with memory tracking
-            device_map = "auto"
-            # Set max memory for better balancing
-            max_memory = {i: f"{int(torch.cuda.get_device_properties(i).total_memory * 0.85 / 1024**3)}GiB" for i in range(gpu_count)}
-            logger.info(f"Max memory settings: {max_memory}")
-        else:
-            device_map = "auto"
-            max_memory = None
+        logger.info(f"Found {gpu_count} CUDA devices")
     else:
-        logger.warning("No CUDA devices detected. Training will be slow on CPU!")
-
-        max_memory = None
+        logger.warning("No CUDA devices detected. Training will be slow on CPU!")
+        gpu_count = 0
 
     # Set default dtype for better numerics
     if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
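This hunk also drops the multi-GPU balancing logic that budgeted 85% of each device's memory alongside device_map="auto". For reference, the deleted computation reconstructs to roughly the sketch below (the helper name and fraction parameter are illustrative; the 0.85 mirrors the removed line):

import torch

def build_max_memory(fraction: float = 0.85) -> dict:
    # One "NGiB" cap per visible CUDA device, as the removed code computed it.
    return {
        i: f"{int(torch.cuda.get_device_properties(i).total_memory * fraction / 1024**3)}GiB"
        for i in range(torch.cuda.device_count())
    }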
@@ -205,6 +180,13 @@ def load_model_and_tokenizer(config):
         dtype = None
         logger.info("Using default precision (CPU)")
 
+    # Check for flash attention as the last dependency check
+    use_flash_attention = config.get("use_flash_attention", True)
+    if use_flash_attention and not find_spec("flash_attn"):
+        logger.warning("flash-attn not found. Will continue without flash attention.")
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
+        use_flash_attention = False
+
     # Load model with proper error handling for out-of-memory
     try:
         # Improved memory settings for multi-GPU setup
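The relocated check uses importlib's find_spec, which reports whether a module is importable without actually importing it, so a missing flash-attn downgrades the run instead of crashing it. A self-contained sketch of that probe (the config lookup is replaced by a literal):

from importlib.util import find_spec

use_flash_attention = True  # stands in for config.get("use_flash_attention", True)
if use_flash_attention and find_spec("flash_attn") is None:
    use_flash_attention = False  # degrade gracefully; warnings are logged instead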
@@ -300,6 +282,16 @@ def load_dataset_with_mapping(dataset_config):
     else:
         logger.info(f"Dataset has all required fields: {required_fields}")
 
+    # Verify that column order matches our expectation
+    expected_order = ["prompt_number", "article_id", "conversations"]
+    actual_order = dataset.column_names
+
+    if actual_order == expected_order:
+        logger.info("Dataset column order matches expected order (prompt_number, article_id, conversations)")
+    else:
+        logger.warning(f"Dataset column order ({', '.join(actual_order)}) differs from expected order ({', '.join(expected_order)})")
+        logger.warning("This should not affect processing but is noted for debugging purposes")
+
     # Log a few samples for verification
     if len(dataset) > 0:
         sample_indices = range(min(5, len(dataset)))
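The new check leans on the datasets library exposing column_names in schema order. A runnable toy illustration of the comparison (data values are made up):

from datasets import Dataset

ds = Dataset.from_dict({
    "prompt_number": [1, 2],
    "article_id": ["a1", "a2"],
    "conversations": [["hi"], ["hello"]],
})
# column_names reflects schema order, which is what the check compares.
assert ds.column_names == ["prompt_number", "article_id", "conversations"]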
@@ -524,54 +516,15 @@ class LoggingCallback(TrainerCallback):
         log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
         log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
 
-        # Set up sequence integrity verification
-        try:
-            self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False)
-            if self.verify_sequence:
-                log_info("Sequence integrity verification enabled during training")
-
-                # Save actual samples for later verification
-                if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
-                    # Get some reference samples from the beginning of the dataset defensively
-                    self.sample_indices = []
-                    self.sequence_samples = []
-
-                    max_samples = min(5, len(trainer.train_dataset))
-                    for i in range(max_samples):
-                        try:
-                            if i < len(trainer.train_dataset):
-                                self.sample_indices.append(i)
-                                self.sequence_samples.append(trainer.train_dataset[i])
-                        except Exception as e:
-                            log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
-
-                    if self.sequence_samples:
-                        log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
-
-                        # Log sample prompt numbers for debugging
-                        sample_prompt_numbers = []
-                        for s in self.sequence_samples:
-                            if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
-                                sample_prompt_numbers.append(s.get('prompt_number'))
-
-                        if sample_prompt_numbers:
-                            log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
-                            if sample_prompt_numbers == list(range(1, len(sample_prompt_numbers) + 1)):
-                                log_info("Prompt numbers are sequential (1-indexed) - sequence integrity confirmed")
-                            else:
-                                log_info("Prompt numbers are not in expected sequence - will verify during training")
-                    else:
-                        log_info("Warning: No reference samples were captured")
-                else:
-                    log_info("Warning: Could not capture reference samples - verification will be limited")
-        except Exception as e:
-            log_info(f"Warning: Could not set up sequence integrity verification: {e}")
-            self.verify_sequence = False
+        # Disable sequence verification
+        self.verify_sequence = False
 
         log_info("=== Training is starting ===")
 
         # Log important training parameters for visibility
         total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
+        total_steps = int(len(dataset) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
+        log_info(f"Training plan: {len(dataset)} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
         log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
         log_info(f"Learning rate: {args.learning_rate}")
         log_info(f"Epochs: {args.num_train_epochs}")
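The added total_steps line is the usual estimate: dataset size divided by the effective batch size (per-device batch × GPUs × accumulation steps), multiplied by epochs. With illustrative numbers:

dataset_len = 1000          # examples
per_device_batch = 4
num_gpus = 2
grad_accum = 8
epochs = 3

effective_batch = per_device_batch * num_gpus * grad_accum   # 64
total_steps = int(dataset_len / effective_batch * epochs)    # int(46.875) == 46
print(total_steps)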
@@ -585,90 +538,12 @@ class LoggingCallback(TrainerCallback):
             memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
 
         log_info(f"Initial memory usage - {', '.join(memory_info)}")
 
     def on_step_end(self, args, state, control, **kwargs):
         # Log every 50 steps or every 5 minutes, whichever comes first
         current_time = time.time()
 
-        # Periodically verify sequence integrity
-        if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples:
-            try:
-                # Get a batch of data without disturbing the training
-                train_dataloader = trainer.get_train_dataloader()
-                if train_dataloader is None:
-                    log_info("Warning: Could not get train dataloader for verification")
-                else:
-                    batch_iterator = iter(train_dataloader)
-                    if batch_iterator is None:
-                        log_info("Warning: Could not get batch iterator for verification")
-                    else:
-                        try:
-                            batch = next(batch_iterator)
-                            if batch is None:
-                                log_info("Warning: Could not get batch for verification")
-                            elif 'input_ids' in batch and 'labels' in batch:
-                                log_info("Verifying data sequence integrity...")
-
-                                # Check if we can access some of our reference samples
-                                if not hasattr(trainer, 'train_dataset') or trainer.train_dataset is None:
-                                    log_info("Warning: Train dataset is not available")
-                                else:
-                                    # Get current samples defensively
-                                    current_samples = []
-                                    current_indices = list(range(min(3, len(trainer.train_dataset))))
-
-                                    for idx in current_indices:
-                                        try:
-                                            if idx < len(trainer.train_dataset):
-                                                current_samples.append(trainer.train_dataset[idx])
-                                        except Exception as e:
-                                            log_info(f"Warning: Error accessing dataset at index {idx}: {e}")
-
-                                    # Only proceed if we have samples to compare
-                                    if current_samples and self.sequence_samples:
-                                        # Compare current samples with our reference samples from training start
-                                        is_sequence_maintained = True
-
-                                        for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
-                                            # Check if sample index is valid
-                                            if i < len(current_samples):
-                                                current_sample = current_samples[i]
-
-                                                # Compare prompt numbers if available - this is our primary check now
-                                                if ('prompt_number' in orig_sample and
-                                                        'prompt_number' in current_sample and
-                                                        orig_sample['prompt_number'] is not None and
-                                                        current_sample['prompt_number'] is not None):
-
-                                                    if orig_sample['prompt_number'] != current_sample['prompt_number']:
-                                                        log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}")
-                                                        is_sequence_maintained = False
-                                                    else:
-                                                        # This is now our primary verification
-                                                        log_info(f"Prompt number match confirmed for sample {i}: {orig_sample['prompt_number']}")
-
-                                                # Also compare article_id as a backup check
-                                                elif ('article_id' in orig_sample and
-                                                        'article_id' in current_sample and
-                                                        orig_sample['article_id'] is not None and
-                                                        current_sample['article_id'] is not None):
-
-                                                    if orig_sample['article_id'] != current_sample['article_id']:
-                                                        log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
-                                                        is_sequence_maintained = False
-
-                                        if is_sequence_maintained:
-                                            log_info("Data sequence integrity check: OK - prompt numbers preserved")
-                                        else:
-                                            log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
-                                    else:
-                                        log_info("Warning: Not enough samples available for sequence verification")
-                        except StopIteration:
-                            log_info("Warning: No batches available in the dataloader")
-                        except Exception as e:
-                            log_info(f"Warning: Error iterating through dataloader: {e}")
-            except Exception as e:
-                log_info(f"Warning: Couldn't verify sequence integrity: {e}")
+        # Sequence verification removed
 
         # Log progress at regular intervals
        if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
@@ -708,13 +583,6 @@ def check_dependencies():
     if not peft_available:
         missing_packages.append("peft>=0.9.0")
 
-    # Optional packages - don't add to missing list, just log
-    if find_spec("flash_attn"):
-        logger.info("flash-attn found. Flash attention will be used for faster training.")
-    else:
-        logger.warning("flash-attn not found. Training will work but may be slower.")
-        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
-
     # If critical packages are missing, exit with instructions
     if missing_packages:
         logger.error("Critical dependencies missing:")
@@ -723,6 +591,13 @@ def check_dependencies():
         logger.error("Please ensure the space has these packages in requirements.txt")
         return False
 
+    # Optional packages - moved to the end
+    if find_spec("flash_attn"):
+        logger.info("flash-attn found. Flash attention will be used for faster training.")
+    else:
+        logger.warning("flash-attn not found. Training will work but may be slower.")
+        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
+
     return True
 
 def main():
@@ -934,6 +809,17 @@ def main():
 
     # Log our approach clearly
     log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
+
+    # Verify column order
+    expected_order = ["prompt_number", "article_id", "conversations"]
+    if hasattr(dataset, 'column_names'):
+        actual_order = dataset.column_names
+        if actual_order == expected_order:
+            log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
+        else:
+            log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
+            log_info("This is handled correctly by field-based access, but noting for clarity")
+
     log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
 
     # Calculate batch size based on device availability
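The order guarantee the log line refers to comes from torch's SequentialSampler, which always yields indices 0 through N-1 in order (no shuffling). A minimal demonstration:

from torch.utils.data import DataLoader, SequentialSampler

rows = list(range(10))  # stands in for examples sorted by prompt_number
loader = DataLoader(rows, batch_size=4, sampler=SequentialSampler(rows))
print([batch.tolist() for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]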
update_space.py
CHANGED
@@ -121,17 +121,25 @@ def update_requirements():
     # Add new requirements
     updated_requirements = existing_requirements.union(required_packages)
 
-    # Write updated requirements with torch first
+    # Write updated requirements with torch first and flash-attn last
     with open(req_path, 'w') as f:
         # Ensure torch is first
         torch_req = next((req for req in updated_requirements if req.startswith("torch")), "torch>=2.0.0")
         f.write(f"{torch_req}\n")
 
-        # Write remaining requirements (excluding torch)
-        for req in sorted(r for r in updated_requirements if not r.startswith("torch")):
+        # Extract flash-attn to add it last
+        flash_attn_req = next((req for req in updated_requirements if req.startswith("flash-attn")), None)
+
+        # Write all other requirements (excluding torch and flash-attn)
+        for req in sorted(r for r in updated_requirements
+                          if not r.startswith("torch") and not r.startswith("flash-attn")):
             f.write(f"{req}\n")
+
+        # Add flash-attn as the very last package
+        if flash_attn_req:
+            f.write(f"{flash_attn_req}\n")
 
-    logger.info("Updated requirements.txt with torch listed first")
+    logger.info("Updated requirements.txt with torch listed first and flash-attn listed last")
 
 def create_space(username, space_name):
     """Create or get a Hugging Face Space."""
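Stripped of file I/O, the resulting ordering rule is easy to sanity-check in isolation (package pins below are made up for the demo):

updated_requirements = {"transformers>=4.36.0", "torch>=2.0.0",
                        "flash-attn==2.5.2", "datasets>=2.15.0"}
torch_req = next((r for r in updated_requirements if r.startswith("torch")), "torch>=2.0.0")
flash_req = next((r for r in updated_requirements if r.startswith("flash-attn")), None)
middle = sorted(r for r in updated_requirements
                if not r.startswith("torch") and not r.startswith("flash-attn"))
print([torch_req, *middle, *([flash_req] if flash_req else [])])
# ['torch>=2.0.0', 'datasets>=2.15.0', 'transformers>=4.36.0', 'flash-attn==2.5.2']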