George-API committed on
Commit
73ea801
·
verified ·
1 Parent(s): 93b2fec

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. run_transformers_training.py +113 -24
run_transformers_training.py CHANGED
@@ -337,6 +337,31 @@ def load_dataset_with_mapping(dataset_config):
337
  if len(dataset) == 0:
338
  raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  except Exception as dataset_error:
341
  logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
342
  logger.error("Make sure the dataset exists and you have proper access permissions")
@@ -478,32 +503,59 @@ class SimpleDataCollator:
478
  for example in features:
479
  try:
480
  # Get ID
481
- paper_id = example.get("id", "")
482
 
483
- # Get conversations - these should already contain role and content
484
- conversations = example.get("conversations", [])
485
- if not conversations:
 
486
  self.stats["skipped"] += 1
487
  continue
488
 
489
- # Directly use the conversations array as input to the model's chat template
490
- # This preserves the exact structure with roles and content as they are
491
  try:
492
- # Let tokenizer handle the content with the model's chat template
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  inputs = self.tokenizer.apply_chat_template(
494
- conversations,
495
  return_tensors=None,
496
  add_generation_prompt=False
497
  )
498
  except Exception as chat_error:
499
  # Fallback if apply_chat_template fails
500
- logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
501
 
502
- # Create a basic representation of the conversation
503
  conversation_text = ""
504
- for msg in conversations:
505
  if isinstance(msg, dict) and 'content' in msg:
506
- conversation_text += msg.get('content', '') + "\n\n"
 
 
507
 
508
  # Basic tokenization
509
  inputs = self.tokenizer(
@@ -537,7 +589,7 @@ class SimpleDataCollator:
537
  logger.info(f"Example {self.stats['processed']}:")
538
  logger.info(f"Paper ID: {paper_id}")
539
  logger.info(f"Token count: {len(inputs)}")
540
- logger.info(f"Conversation entries: {len(conversations)}")
541
  else:
542
  self.stats["skipped"] += 1
543
  except Exception as e:
@@ -1004,6 +1056,14 @@ def main():
1004
  """Custom dataloader that preserves original dataset order"""
1005
  log_info("Creating sequential dataloader to maintain original dataset order")
1006
 
 
 
 
 
 
 
 
 
1007
  # Create a simple sequential sampler
1008
  sequential_sampler = torch.utils.data.SequentialSampler(dataset)
1009
 
@@ -1018,10 +1078,16 @@ def main():
1018
  # Log our approach clearly
1019
  log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
1020
 
1021
- # Verify column order
1022
  expected_order = ["prompt_number", "article_id", "conversations"]
1023
  if hasattr(dataset, 'column_names'):
1024
  actual_order = dataset.column_names
 
 
 
 
 
 
1025
  if actual_order == expected_order:
1026
  log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
1027
  else:
@@ -1030,6 +1096,16 @@ def main():
1030
 
1031
  log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
1032
 
 
 
 
 
 
 
 
 
 
 
1033
  # Calculate batch size based on device availability
1034
  if getattr(training_args, "no_cuda", False):
1035
  batch_size = training_args.per_device_train_batch_size
@@ -1038,16 +1114,29 @@ def main():
1038
 
1039
  log_info(f"Using sequential sampler with batch size {batch_size}")
1040
 
1041
- # Return DataLoader with sequential sampler
1042
- return torch.utils.data.DataLoader(
1043
- dataset,
1044
- batch_size=batch_size,
1045
- sampler=sequential_sampler, # Always use sequential sampler
1046
- collate_fn=data_collator,
1047
- drop_last=training_args.dataloader_drop_last,
1048
- num_workers=training_args.dataloader_num_workers,
1049
- pin_memory=training_args.dataloader_pin_memory,
1050
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
1051
 
1052
  # Override the get_train_dataloader method
1053
  trainer.get_train_dataloader = custom_get_train_dataloader
 
337
  if len(dataset) == 0:
338
  raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
339
 
340
+ # Verify conversations field specifically - this is critical for training
341
+ if "conversations" not in dataset.column_names:
342
+ raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
343
+
344
+ # Check a sample of conversation entries to validate structure
345
+ logger.info("Validating conversation structure...")
346
+ for i in range(min(5, len(dataset))):
347
+ conv = dataset[i].get("conversations")
348
+ if conv is None:
349
+ logger.warning(f"Example {i} has None as 'conversations' value")
350
+ elif not isinstance(conv, list):
351
+ logger.warning(f"Example {i} has non-list 'conversations': {type(conv)}")
352
+ elif len(conv) == 0:
353
+ logger.warning(f"Example {i} has empty conversations list")
354
+ else:
355
+ # Look at the first conversation entry
356
+ first_entry = conv[0]
357
+ logger.info(f"Sample conversation: {str(first_entry)[:100]}...")
358
+
359
+ # Make sure content field exists
360
+ if isinstance(first_entry, dict) and "content" in first_entry:
361
+ logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
362
+ else:
363
+ logger.warning(f"Example {i} missing 'content' key in conversation")
364
+
365
  except Exception as dataset_error:
366
  logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
367
  logger.error("Make sure the dataset exists and you have proper access permissions")
 
503
  for example in features:
504
  try:
505
  # Get ID
506
+ paper_id = example.get("article_id", example.get("id", ""))
507
 
508
+ # Get conversations
509
+ raw_conversations = example.get("conversations", [])
510
+ if not raw_conversations:
511
+ logger.warning(f"Empty conversations for example {paper_id}")
512
  self.stats["skipped"] += 1
513
  continue
514
 
515
+ # Extract only the 'content' field from each conversation item
516
+ # This simplifies the structure and avoids potential NoneType errors
517
  try:
518
+ # Convert conversations to the simple format with only content
519
+ simplified_conversations = []
520
+ for item in raw_conversations:
521
+ if isinstance(item, dict) and "content" in item:
522
+ # Keep only the content field
523
+ content = item["content"]
524
+ simplified_conversations.append({"role": "user", "content": content})
525
+ elif isinstance(item, str):
526
+ # If it's just a string, treat it as content
527
+ simplified_conversations.append({"role": "user", "content": item})
528
+ else:
529
+ logger.warning(f"Skipping invalid conversation item: {item}")
530
+
531
+ # Skip if no valid conversations after filtering
532
+ if not simplified_conversations:
533
+ logger.warning(f"No valid conversations after filtering for example {paper_id}")
534
+ self.stats["skipped"] += 1
535
+ continue
536
+
537
+ # Log the simplified content for debugging
538
+ if len(simplified_conversations) > 0:
539
+ first_content = simplified_conversations[0]["content"]
540
+ logger.debug(f"First content: {first_content[:50]}...")
541
+
542
+ # Let tokenizer handle the simplified conversations
543
  inputs = self.tokenizer.apply_chat_template(
544
+ simplified_conversations,
545
  return_tensors=None,
546
  add_generation_prompt=False
547
  )
548
  except Exception as chat_error:
549
  # Fallback if apply_chat_template fails
550
+ logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)}")
551
 
552
+ # Create a basic representation of just the content
553
  conversation_text = ""
554
+ for msg in raw_conversations:
555
  if isinstance(msg, dict) and 'content' in msg:
556
+ conversation_text += msg['content'] + "\n\n"
557
+ elif isinstance(msg, str):
558
+ conversation_text += msg + "\n\n"
559
 
560
  # Basic tokenization
561
  inputs = self.tokenizer(
 
589
  logger.info(f"Example {self.stats['processed']}:")
590
  logger.info(f"Paper ID: {paper_id}")
591
  logger.info(f"Token count: {len(inputs)}")
592
+ logger.info(f"Conversation entries: {len(raw_conversations)}")
593
  else:
594
  self.stats["skipped"] += 1
595
  except Exception as e:
 
1056
  """Custom dataloader that preserves original dataset order"""
1057
  log_info("Creating sequential dataloader to maintain original dataset order")
1058
 
1059
+ # Safety check - make sure dataset exists and is not None
1060
+ if dataset is None:
1061
+ raise ValueError("Dataset is None - cannot create dataloader")
1062
+
1063
+ # Make sure dataset is not empty
1064
+ if len(dataset) == 0:
1065
+ raise ValueError("Dataset is empty - cannot create dataloader")
1066
+
1067
  # Create a simple sequential sampler
1068
  sequential_sampler = torch.utils.data.SequentialSampler(dataset)
1069
 
 
1078
  # Log our approach clearly
1079
  log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
1080
 
1081
+ # Verify column order and check for 'conversations' field
1082
  expected_order = ["prompt_number", "article_id", "conversations"]
1083
  if hasattr(dataset, 'column_names'):
1084
  actual_order = dataset.column_names
1085
+
1086
+ # Verify all required fields exist
1087
+ missing_fields = [field for field in ["conversations"] if field not in actual_order]
1088
+ if missing_fields:
1089
+ raise ValueError(f"Dataset missing critical fields: {missing_fields}")
1090
+
1091
  if actual_order == expected_order:
1092
  log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
1093
  else:
 
1096
 
1097
  log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
1098
 
1099
+ # Validate a few samples before proceeding
1100
+ for i in range(min(3, len(dataset))):
1101
+ sample = dataset[i]
1102
+ if "conversations" not in sample:
1103
+ log_info(f"WARNING: Sample {i} missing 'conversations' field")
1104
+ elif sample["conversations"] is None:
1105
+ log_info(f"WARNING: Sample {i} has None 'conversations' field")
1106
+ elif not isinstance(sample["conversations"], list):
1107
+ log_info(f"WARNING: Sample {i} has non-list 'conversations' field: {type(sample['conversations'])}")
1108
+
1109
  # Calculate batch size based on device availability
1110
  if getattr(training_args, "no_cuda", False):
1111
  batch_size = training_args.per_device_train_batch_size
 
1114
 
1115
  log_info(f"Using sequential sampler with batch size {batch_size}")
1116
 
1117
+ # Return DataLoader with sequential sampler and extra error handling
1118
+ try:
1119
+ return torch.utils.data.DataLoader(
1120
+ dataset,
1121
+ batch_size=batch_size,
1122
+ sampler=sequential_sampler, # Always use sequential sampler
1123
+ collate_fn=data_collator,
1124
+ drop_last=training_args.dataloader_drop_last,
1125
+ num_workers=training_args.dataloader_num_workers,
1126
+ pin_memory=training_args.dataloader_pin_memory,
1127
+ )
1128
+ except Exception as e:
1129
+ log_info(f"Error creating DataLoader: {str(e)}")
1130
+ # Try again with minimal settings
1131
+ log_info("Attempting to create DataLoader with minimal settings")
1132
+ return torch.utils.data.DataLoader(
1133
+ dataset,
1134
+ batch_size=1, # Minimal batch size
1135
+ sampler=sequential_sampler,
1136
+ collate_fn=data_collator,
1137
+ num_workers=0, # No parallel workers
1138
+ pin_memory=False,
1139
+ )
1140
 
1141
  # Override the get_train_dataloader method
1142
  trainer.get_train_dataloader = custom_get_train_dataloader