|
|
|
|
|
import argparse
|
|
import sys
|
|
from train_agent import train_agent
|
|
from test_agent import TestAgent, run_test_session
|
|
from lightbulb import main as world_model_main
|
|
|
|
def parse_main_args(argv=None):
    """Parse the command-line options for the task menu.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), argparse falls back to ``sys.argv[1:]``,
            which matches the previous zero-argument behaviour.

    Returns:
        argparse.Namespace with ``task`` (one of ``train_llm_world``,
        ``train_agent``, ``test_agent``), the LLM training options
        ``model_name``, ``dataset_name``, ``dataset_config``, ``batch_size``,
        ``num_epochs``, ``max_length``, ``mode``, and ``query``.

    Raises:
        SystemExit: argparse's standard exit on missing/invalid arguments
            (e.g. no ``--task`` or an unknown choice) or after ``--help``.
    """
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument('--task', type=str, choices=['train_llm_world', 'train_agent', 'test_agent'],
                        required=True, help='Choose task to execute: train_llm_world, train_agent, test_agent')

    # LLM / world-model training options (used by the train_llm_world task).
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')

    # Single query for the test_agent task; empty means "interactive session".
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent')

    return parser.parse_args(argv)
|
|
|
|
def _run_world_model_training(args):
    """Run ``lightbulb.main`` with the training options taken from *args*.

    ``world_model_main`` is called with no arguments, so the options are handed
    over by temporarily rewriting ``sys.argv`` (it presumably parses
    ``sys.argv`` itself; confirm in lightbulb). The caller's ``sys.argv`` is
    restored afterwards, even if training raises or exits.
    """
    print("Starting LLM and World Model Training...")
    forwarded_argv = [
        'lightbulb.py',
        '--mode', args.mode,
        '--model_name', args.model_name,
        '--dataset_name', args.dataset_name,
        '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size),
        '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length),
    ]
    saved_argv = sys.argv
    sys.argv = forwarded_argv
    try:
        world_model_main()
    finally:
        sys.argv = saved_argv


def _report_training_failure(failure):
    """Errback for agent training: print the failure and its traceback.

    Returns ``None``, which marks the failure as handled so the following
    ``addBoth`` callback still runs and stops the reactor.
    """
    # ``print`` has no ``exc_info`` keyword (that is a ``logging`` option), so
    # the traceback is printed through the Twisted Failure object instead.
    print(f"An error occurred: {failure}", file=sys.stderr)
    failure.printTraceback(file=sys.stderr)


def _run_agent_training():
    """Run ``train_agent`` inside the Twisted reactor and stop it when done."""
    print("Starting Agent Training...")
    from twisted.internet import reactor, task

    # A zero delay schedules train_agent once the reactor starts. Because the
    # errback handles the failure, the addBoth callback runs on both success
    # and failure, stopping the reactor so reactor.run() returns.
    deferred = task.deferLater(reactor, 0, train_agent)
    deferred.addErrback(_report_training_failure)
    deferred.addBoth(lambda _: reactor.stop())
    reactor.run()


def _run_test_agent(args):
    """Answer ``args.query`` directly, or start a reactor-driven test session."""
    print("Starting Test Agent...")
    agent = TestAgent()
    if args.query:
        result = agent.process_query(args.query)
        print("\nAgent's response:")
        print(result)
        return

    # The reactor is not imported at module level, so it must be imported
    # here before scheduling the interactive session.
    from twisted.internet import reactor

    reactor.callWhenRunning(run_test_session)
    reactor.run()


def main():
    """Parse the command line and dispatch to the selected task."""
    args = parse_main_args()
    if args.task == 'train_llm_world':
        _run_world_model_training(args)
    elif args.task == 'train_agent':
        _run_agent_training()
    elif args.task == 'test_agent':
        _run_test_agent(args)


if __name__ == "__main__":
    main()
|
|
|