Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- run_transformers_training.py +113 -24
run_transformers_training.py
CHANGED
@@ -337,6 +337,31 @@ def load_dataset_with_mapping(dataset_config):
|
|
337 |
if len(dataset) == 0:
|
338 |
raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
|
339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
except Exception as dataset_error:
|
341 |
logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
|
342 |
logger.error("Make sure the dataset exists and you have proper access permissions")
|
@@ -478,32 +503,59 @@ class SimpleDataCollator:
|
|
478 |
for example in features:
|
479 |
try:
|
480 |
# Get ID
|
481 |
-
paper_id = example.get("id", "")
|
482 |
|
483 |
-
# Get conversations
|
484 |
-
|
485 |
-
if not
|
|
|
486 |
self.stats["skipped"] += 1
|
487 |
continue
|
488 |
|
489 |
-
#
|
490 |
-
# This
|
491 |
try:
|
492 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
inputs = self.tokenizer.apply_chat_template(
|
494 |
-
|
495 |
return_tensors=None,
|
496 |
add_generation_prompt=False
|
497 |
)
|
498 |
except Exception as chat_error:
|
499 |
# Fallback if apply_chat_template fails
|
500 |
-
logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)
|
501 |
|
502 |
-
# Create a basic representation of the
|
503 |
conversation_text = ""
|
504 |
-
for msg in
|
505 |
if isinstance(msg, dict) and 'content' in msg:
|
506 |
-
conversation_text += msg
|
|
|
|
|
507 |
|
508 |
# Basic tokenization
|
509 |
inputs = self.tokenizer(
|
@@ -537,7 +589,7 @@ class SimpleDataCollator:
|
|
537 |
logger.info(f"Example {self.stats['processed']}:")
|
538 |
logger.info(f"Paper ID: {paper_id}")
|
539 |
logger.info(f"Token count: {len(inputs)}")
|
540 |
-
logger.info(f"Conversation entries: {len(
|
541 |
else:
|
542 |
self.stats["skipped"] += 1
|
543 |
except Exception as e:
|
@@ -1004,6 +1056,14 @@ def main():
|
|
1004 |
"""Custom dataloader that preserves original dataset order"""
|
1005 |
log_info("Creating sequential dataloader to maintain original dataset order")
|
1006 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1007 |
# Create a simple sequential sampler
|
1008 |
sequential_sampler = torch.utils.data.SequentialSampler(dataset)
|
1009 |
|
@@ -1018,10 +1078,16 @@ def main():
|
|
1018 |
# Log our approach clearly
|
1019 |
log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
|
1020 |
|
1021 |
-
# Verify column order
|
1022 |
expected_order = ["prompt_number", "article_id", "conversations"]
|
1023 |
if hasattr(dataset, 'column_names'):
|
1024 |
actual_order = dataset.column_names
|
|
|
|
|
|
|
|
|
|
|
|
|
1025 |
if actual_order == expected_order:
|
1026 |
log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
|
1027 |
else:
|
@@ -1030,6 +1096,16 @@ def main():
|
|
1030 |
|
1031 |
log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
|
1032 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1033 |
# Calculate batch size based on device availability
|
1034 |
if getattr(training_args, "no_cuda", False):
|
1035 |
batch_size = training_args.per_device_train_batch_size
|
@@ -1038,16 +1114,29 @@ def main():
|
|
1038 |
|
1039 |
log_info(f"Using sequential sampler with batch size {batch_size}")
|
1040 |
|
1041 |
-
# Return DataLoader with sequential sampler
|
1042 |
-
|
1043 |
-
|
1044 |
-
|
1045 |
-
|
1046 |
-
|
1047 |
-
|
1048 |
-
|
1049 |
-
|
1050 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1051 |
|
1052 |
# Override the get_train_dataloader method
|
1053 |
trainer.get_train_dataloader = custom_get_train_dataloader
|
|
|
337 |
if len(dataset) == 0:
|
338 |
raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
|
339 |
|
340 |
+
# Verify conversations field specifically - this is critical for training
|
341 |
+
if "conversations" not in dataset.column_names:
|
342 |
+
raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
|
343 |
+
|
344 |
+
# Check a sample of conversation entries to validate structure
|
345 |
+
logger.info("Validating conversation structure...")
|
346 |
+
for i in range(min(5, len(dataset))):
|
347 |
+
conv = dataset[i].get("conversations")
|
348 |
+
if conv is None:
|
349 |
+
logger.warning(f"Example {i} has None as 'conversations' value")
|
350 |
+
elif not isinstance(conv, list):
|
351 |
+
logger.warning(f"Example {i} has non-list 'conversations': {type(conv)}")
|
352 |
+
elif len(conv) == 0:
|
353 |
+
logger.warning(f"Example {i} has empty conversations list")
|
354 |
+
else:
|
355 |
+
# Look at the first conversation entry
|
356 |
+
first_entry = conv[0]
|
357 |
+
logger.info(f"Sample conversation: {str(first_entry)[:100]}...")
|
358 |
+
|
359 |
+
# Make sure content field exists
|
360 |
+
if isinstance(first_entry, dict) and "content" in first_entry:
|
361 |
+
logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
|
362 |
+
else:
|
363 |
+
logger.warning(f"Example {i} missing 'content' key in conversation")
|
364 |
+
|
365 |
except Exception as dataset_error:
|
366 |
logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
|
367 |
logger.error("Make sure the dataset exists and you have proper access permissions")
|
|
|
503 |
for example in features:
|
504 |
try:
|
505 |
# Get ID
|
506 |
+
paper_id = example.get("article_id", example.get("id", ""))
|
507 |
|
508 |
+
# Get conversations
|
509 |
+
raw_conversations = example.get("conversations", [])
|
510 |
+
if not raw_conversations:
|
511 |
+
logger.warning(f"Empty conversations for example {paper_id}")
|
512 |
self.stats["skipped"] += 1
|
513 |
continue
|
514 |
|
515 |
+
# Extract only the 'content' field from each conversation item
|
516 |
+
# This simplifies the structure and avoids potential NoneType errors
|
517 |
try:
|
518 |
+
# Convert conversations to the simple format with only content
|
519 |
+
simplified_conversations = []
|
520 |
+
for item in raw_conversations:
|
521 |
+
if isinstance(item, dict) and "content" in item:
|
522 |
+
# Keep only the content field
|
523 |
+
content = item["content"]
|
524 |
+
simplified_conversations.append({"role": "user", "content": content})
|
525 |
+
elif isinstance(item, str):
|
526 |
+
# If it's just a string, treat it as content
|
527 |
+
simplified_conversations.append({"role": "user", "content": item})
|
528 |
+
else:
|
529 |
+
logger.warning(f"Skipping invalid conversation item: {item}")
|
530 |
+
|
531 |
+
# Skip if no valid conversations after filtering
|
532 |
+
if not simplified_conversations:
|
533 |
+
logger.warning(f"No valid conversations after filtering for example {paper_id}")
|
534 |
+
self.stats["skipped"] += 1
|
535 |
+
continue
|
536 |
+
|
537 |
+
# Log the simplified content for debugging
|
538 |
+
if len(simplified_conversations) > 0:
|
539 |
+
first_content = simplified_conversations[0]["content"]
|
540 |
+
logger.debug(f"First content: {first_content[:50]}...")
|
541 |
+
|
542 |
+
# Let tokenizer handle the simplified conversations
|
543 |
inputs = self.tokenizer.apply_chat_template(
|
544 |
+
simplified_conversations,
|
545 |
return_tensors=None,
|
546 |
add_generation_prompt=False
|
547 |
)
|
548 |
except Exception as chat_error:
|
549 |
# Fallback if apply_chat_template fails
|
550 |
+
logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)}")
|
551 |
|
552 |
+
# Create a basic representation of just the content
|
553 |
conversation_text = ""
|
554 |
+
for msg in raw_conversations:
|
555 |
if isinstance(msg, dict) and 'content' in msg:
|
556 |
+
conversation_text += msg['content'] + "\n\n"
|
557 |
+
elif isinstance(msg, str):
|
558 |
+
conversation_text += msg + "\n\n"
|
559 |
|
560 |
# Basic tokenization
|
561 |
inputs = self.tokenizer(
|
|
|
589 |
logger.info(f"Example {self.stats['processed']}:")
|
590 |
logger.info(f"Paper ID: {paper_id}")
|
591 |
logger.info(f"Token count: {len(inputs)}")
|
592 |
+
logger.info(f"Conversation entries: {len(raw_conversations)}")
|
593 |
else:
|
594 |
self.stats["skipped"] += 1
|
595 |
except Exception as e:
|
|
|
1056 |
"""Custom dataloader that preserves original dataset order"""
|
1057 |
log_info("Creating sequential dataloader to maintain original dataset order")
|
1058 |
|
1059 |
+
# Safety check - make sure dataset exists and is not None
|
1060 |
+
if dataset is None:
|
1061 |
+
raise ValueError("Dataset is None - cannot create dataloader")
|
1062 |
+
|
1063 |
+
# Make sure dataset is not empty
|
1064 |
+
if len(dataset) == 0:
|
1065 |
+
raise ValueError("Dataset is empty - cannot create dataloader")
|
1066 |
+
|
1067 |
# Create a simple sequential sampler
|
1068 |
sequential_sampler = torch.utils.data.SequentialSampler(dataset)
|
1069 |
|
|
|
1078 |
# Log our approach clearly
|
1079 |
log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
|
1080 |
|
1081 |
+
# Verify column order and check for 'conversations' field
|
1082 |
expected_order = ["prompt_number", "article_id", "conversations"]
|
1083 |
if hasattr(dataset, 'column_names'):
|
1084 |
actual_order = dataset.column_names
|
1085 |
+
|
1086 |
+
# Verify all required fields exist
|
1087 |
+
missing_fields = [field for field in ["conversations"] if field not in actual_order]
|
1088 |
+
if missing_fields:
|
1089 |
+
raise ValueError(f"Dataset missing critical fields: {missing_fields}")
|
1090 |
+
|
1091 |
if actual_order == expected_order:
|
1092 |
log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
|
1093 |
else:
|
|
|
1096 |
|
1097 |
log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
|
1098 |
|
1099 |
+
# Validate a few samples before proceeding
|
1100 |
+
for i in range(min(3, len(dataset))):
|
1101 |
+
sample = dataset[i]
|
1102 |
+
if "conversations" not in sample:
|
1103 |
+
log_info(f"WARNING: Sample {i} missing 'conversations' field")
|
1104 |
+
elif sample["conversations"] is None:
|
1105 |
+
log_info(f"WARNING: Sample {i} has None 'conversations' field")
|
1106 |
+
elif not isinstance(sample["conversations"], list):
|
1107 |
+
log_info(f"WARNING: Sample {i} has non-list 'conversations' field: {type(sample['conversations'])}")
|
1108 |
+
|
1109 |
# Calculate batch size based on device availability
|
1110 |
if getattr(training_args, "no_cuda", False):
|
1111 |
batch_size = training_args.per_device_train_batch_size
|
|
|
1114 |
|
1115 |
log_info(f"Using sequential sampler with batch size {batch_size}")
|
1116 |
|
1117 |
+
# Return DataLoader with sequential sampler and extra error handling
|
1118 |
+
try:
|
1119 |
+
return torch.utils.data.DataLoader(
|
1120 |
+
dataset,
|
1121 |
+
batch_size=batch_size,
|
1122 |
+
sampler=sequential_sampler, # Always use sequential sampler
|
1123 |
+
collate_fn=data_collator,
|
1124 |
+
drop_last=training_args.dataloader_drop_last,
|
1125 |
+
num_workers=training_args.dataloader_num_workers,
|
1126 |
+
pin_memory=training_args.dataloader_pin_memory,
|
1127 |
+
)
|
1128 |
+
except Exception as e:
|
1129 |
+
log_info(f"Error creating DataLoader: {str(e)}")
|
1130 |
+
# Try again with minimal settings
|
1131 |
+
log_info("Attempting to create DataLoader with minimal settings")
|
1132 |
+
return torch.utils.data.DataLoader(
|
1133 |
+
dataset,
|
1134 |
+
batch_size=1, # Minimal batch size
|
1135 |
+
sampler=sequential_sampler,
|
1136 |
+
collate_fn=data_collator,
|
1137 |
+
num_workers=0, # No parallel workers
|
1138 |
+
pin_memory=False,
|
1139 |
+
)
|
1140 |
|
1141 |
# Override the get_train_dataloader method
|
1142 |
trainer.get_train_dataloader = custom_get_train_dataloader
|