# main_menu.py
"""Command-line entry point that dispatches to one of three tasks:
LLM/world-model training, agent training, or agent testing."""

import argparse
import sys

from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main


def parse_main_args():
    """Parse the command-line arguments for the main menu.

    Returns:
        argparse.Namespace with the required ``task`` plus the optional
        model/dataset/training settings and ``query``.
    """
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument('--task', type=str, choices=['train_llm_world', 'train_agent', 'test_agent'], required=True,
                        help='Choose task to execute: train_llm_world, train_agent, test_agent')

    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1',
                        help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train',
                        help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent')

    return parser.parse_args()


def _run_train_llm_world(args):
    """Run lightbulb's main() with a command line built from *args*."""
    print("Starting LLM and World Model Training...")
    # world_model_main() takes no arguments, so its options are passed via a
    # synthetic sys.argv (presumably parsed by lightbulb's own argparse —
    # confirm). The real argv is restored afterwards so nothing later in this
    # process sees the forwarded command line.
    forwarded_argv = [
        'lightbulb.py',
        '--mode', args.mode,
        '--model_name', args.model_name,
        '--dataset_name', args.dataset_name,
        '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size),
        '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length),
    ]
    saved_argv = sys.argv
    sys.argv = forwarded_argv
    try:
        world_model_main()
    finally:
        sys.argv = saved_argv


def _report_failure(failure):
    """Errback: print a Twisted Failure together with its traceback.

    Returning None marks the failure as handled, so the following
    ``addBoth`` callback still runs and stops the reactor.
    """
    # print() has no exc_info option (that is a logging keyword); the
    # traceback comes from the Failure object itself.
    print(f"An error occurred: {failure}")
    print(failure.getTraceback())
    return None


def _run_train_agent():
    """Run train_agent once inside the Twisted reactor, then stop the reactor."""
    print("Starting Agent Training...")
    # Imported lazily so Twisted is only loaded for the tasks that use it.
    from twisted.internet import reactor, task

    # Schedule train_agent for the first reactor iteration, and stop the
    # reactor whether it succeeds or fails.
    deferred = task.deferLater(reactor, 0, train_agent)
    deferred.addErrback(_report_failure)
    deferred.addBoth(lambda _: reactor.stop())
    reactor.run()


def _run_test_agent(args):
    """Answer a single ``--query``, or start the interactive test session."""
    print("Starting Test Agent...")
    if args.query:
        # Directly process a single query
        test_agent = TestAgent()
        result = test_agent.process_query(args.query)
        print("\nAgent's response:")
        print(result)
        return

    # Run the interactive session inside the Twisted reactor. The reactor must
    # be imported here: this branch has no other binding for the name.
    from twisted.internet import reactor

    reactor.callWhenRunning(run_test_session)
    reactor.run()


def main():
    """Parse arguments and execute the selected task."""
    args = parse_main_args()

    # argparse `choices` guarantees task is one of these three values.
    if args.task == 'train_llm_world':
        _run_train_llm_world(args)
    elif args.task == 'train_agent':
        _run_train_agent()
    elif args.task == 'test_agent':
        _run_test_agent(args)


if __name__ == "__main__":
    main()