File size: 3,023 Bytes
e1392d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# main_menu.py

import argparse
import sys
from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main

def parse_main_args(argv=None):
    """Parse the command-line options for the main menu.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, exactly
            as before this parameter existed.

    Returns:
        argparse.Namespace with the required ``task`` plus the optional
        LLM/world-model training settings (``model_name``, ``dataset_name``,
        ``dataset_config``, ``batch_size``, ``num_epochs``, ``max_length``,
        ``mode``) and the ``query`` string used by ``test_agent``.

    Raises:
        SystemExit: on ``--help`` or when ``--task`` is missing/invalid
            (standard argparse behaviour).
    """
    # Single source of truth for the valid tasks, so ``choices`` and the help
    # text cannot drift apart.
    tasks = ('train_llm_world', 'train_agent', 'test_agent')

    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument('--task', type=str, choices=tasks, required=True,
                        help='Choose task to execute: ' + ', '.join(tasks))
    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent')
    return parser.parse_args(argv)

def _run_world_model(args):
    """Run ``lightbulb.main`` with the training options taken from *args*.

    ``world_model_main`` is called with no arguments. It presumably reads its
    options from ``sys.argv`` (TODO confirm in lightbulb.py), so ``sys.argv``
    is swapped for the forwarded options during the call. It is restored
    afterwards so the rewrite does not leak into the rest of the process.
    """
    forwarded_argv = [
        'lightbulb.py', '--mode', args.mode, '--model_name', args.model_name,
        '--dataset_name', args.dataset_name, '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size), '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length),
    ]
    saved_argv = sys.argv
    sys.argv = forwarded_argv
    try:
        world_model_main()
    finally:
        sys.argv = saved_argv


def _report_train_failure(failure):
    """Errback: print the Twisted ``Failure`` traceback to stderr.

    Returning ``None`` marks the error as handled, so the following
    ``addBoth`` callback still runs and stops the reactor.
    """
    # ``print`` has no ``exc_info`` keyword; the traceback text comes from the
    # Failure object itself.
    print(f"An error occurred: {failure.getTraceback()}", file=sys.stderr)


def _run_train_agent():
    """Run ``train_agent`` inside the Twisted reactor and stop it when done."""
    # Imported lazily so that tasks which never touch Twisted do not import
    # (and thereby install) the global reactor.
    from twisted.internet import reactor, task

    d = task.deferLater(reactor, 0, train_agent)
    d.addErrback(_report_train_failure)
    # Stop the reactor on success and on (already reported) failure alike.
    d.addBoth(lambda _: reactor.stop())
    reactor.run()


def _run_test_agent(args):
    """Answer ``args.query`` once, or run the interactive session if it is empty."""
    # NOTE(review): this instance is only used on the single-query path; it is
    # still built up front to preserve any constructor side effects.
    test_agent = TestAgent()
    if args.query:
        # Directly process a single query
        result = test_agent.process_query(args.query)
        print("\nAgent's response:")
        print(result)
        return

    # The reactor must be imported here as well: this branch previously
    # referenced ``reactor`` without importing it and failed with NameError.
    from twisted.internet import reactor

    reactor.callWhenRunning(run_test_session)
    reactor.run()


def main():
    """Parse the command line and dispatch to the selected task.

    Tasks:
      * ``train_llm_world``: run lightbulb's main with the training options.
      * ``train_agent``: run ``train_agent`` inside the Twisted reactor.
      * ``test_agent``: answer ``--query`` once, or start an interactive
        session when no query is given.
    """
    args = parse_main_args()

    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        _run_world_model(args)

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        _run_train_agent()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        _run_test_agent(args)

if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    sys.exit(main())