# Snapshot metadata: file size 3,023 bytes, revision e1392d6.
# main_menu.py
import argparse
import sys
from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main
def parse_main_args(argv=None):
    """Parse the command-line options for the main menu.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, so existing zero-argument callers behave
            exactly as before.

    Returns:
        argparse.Namespace with the required ``task`` attribute
        (``'train_llm_world'``, ``'train_agent'`` or ``'test_agent'``) and
        the optional ``model_name``, ``dataset_name``, ``dataset_config``,
        ``batch_size``, ``num_epochs``, ``max_length``, ``mode`` and
        ``query`` attributes, each falling back to its declared default.

    Raises:
        SystemExit: Raised by argparse when ``--task`` is missing or any
            argument is invalid.
    """
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument(
        '--task',
        type=str,
        choices=['train_llm_world', 'train_agent', 'test_agent'],
        required=True,
        help='Choose task to execute: train_llm_world, train_agent, test_agent',
    )

    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2',
                        help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext',
                        help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1',
                        help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3,
                        help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128,
                        help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'],
                        default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='',
                        help='Query for the test_agent')

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def main():
    """Dispatch to the task selected with ``--task``.

    * ``train_llm_world``: rewrites ``sys.argv`` with the LLM/world-model
      options and calls ``world_model_main()``.
    * ``train_agent``: schedules ``train_agent`` on the Twisted reactor,
      reports any failure, and stops the reactor when the call finishes.
    * ``test_agent``: answers a single ``--query`` if one was given,
      otherwise starts the interactive session on the reactor.
    """
    # Parse arguments for the main function
    args = parse_main_args()

    # Execute tasks based on user input
    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # world_model_main() is called without arguments, so the options are
        # forwarded by replacing sys.argv.
        # NOTE(review): assumes lightbulb.main parses sys.argv itself - confirm.
        sys.argv = [
            'lightbulb.py',
            '--mode', args.mode,
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length),
        ]
        world_model_main()

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Imported lazily so the train_llm_world path does not need the reactor.
        from twisted.internet import reactor, task

        def _report_failure(failure):
            # print() has no exc_info keyword (that belongs to logging), so
            # emit the message and the traceback explicitly on stderr.
            print(f"An error occurred: {failure}", file=sys.stderr)
            print(failure.getTraceback(), file=sys.stderr)

        # Call the train_agent function from train_agent.py once the reactor runs.
        d = task.deferLater(reactor, 0, train_agent)
        d.addErrback(_report_failure)
        # Stop the reactor whether training succeeded or failed.
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # Directly process a single query
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # The reactor must be imported in this branch too: a function-level
            # import elsewhere makes the name local to main(), so using it here
            # without importing would raise UnboundLocalError.
            from twisted.internet import reactor

            # Run the interactive session
            reactor.callWhenRunning(run_test_session)
            reactor.run()
# Script entry point: run the menu only when executed directly, not on import.
if __name__ == "__main__":
    main()
|