---
license: apache-2.0
---
# Overview of the Main Menu

The `main_menu.py` script is the primary entry point for choosing and executing one of three tasks:

- Training the LLM and World Model: `train_llm_world`
- Training the Search Agent: `train_agent`
- Testing the Tree of Thought Search Agent: `test_agent`
Each task has unique functionalities and configurations. This script uses command-line arguments to specify the desired task and additional options, giving users the ability to tailor the execution according to their needs.
## Running the Main Menu

To run the main menu, use the following command in the terminal:

```bash
python main_menu.py --task <task_name> [additional arguments]
```
Replace `<task_name>` with one of the following:

- `train_llm_world` - Train the LLM (Language Model) and World Model.
- `train_agent` - Train the Search Agent with an interactive Twisted-based process.
- `test_agent` - Test the Tree of Thought Search Agent, with the option of an interactive session or a single query.
## General Arguments

The script supports a set of command-line arguments to customize each task. Here’s an overview of all possible arguments:

| Argument | Required | Description | Default |
|---|---|---|---|
| `--task` | Yes | Specifies the task to run. Choose from `train_llm_world`, `train_agent`, or `test_agent`. | None |
| `--model_name` | No | Pretrained model name for the LLM. Options include `gpt2`, `bert`, etc., or a custom model path. | `gpt2` |
| `--dataset_name` | No | Name of the dataset from Hugging Face Datasets for training the LLM and World Model (e.g., `wikitext`). | `wikitext` |
| `--dataset_config` | No | Dataset configuration name for specifying different versions or configurations of the dataset. | `wikitext-2-raw-v1` |
| `--batch_size` | No | Number of samples processed in a single forward/backward pass. Increasing the batch size can speed up training but requires more memory. | 4 |
| `--num_epochs` | No | Number of times to iterate over the training dataset during model training. More epochs generally improve learning but can lead to overfitting. | 3 |
| `--max_length` | No | Maximum sequence length for training/inference. Truncates or pads sequences to this length to maintain consistency in training. | 128 |
| `--mode` | No | Specifies the mode for the LLM and World Model. Use `train` for training and `inference` for generating responses. | `train` |
| `--query` | No | Query input for `test_agent` when running a single query instead of an interactive session. | `''` (empty) |
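For reference, the table above maps onto a straightforward `argparse` setup. The sketch below is illustrative and mirrors the flags and defaults listed in the table; the actual parser in `main_menu.py` may differ in details.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Illustrative parser mirroring the argument table above (not the exact source).
    parser = argparse.ArgumentParser(description="Main menu for LLM/World Model and agent tasks")
    parser.add_argument("--task", required=True,
                        choices=["train_llm_world", "train_agent", "test_agent"],
                        help="Task to run")
    parser.add_argument("--model_name", default="gpt2", help="Pretrained model name or path")
    parser.add_argument("--dataset_name", default="wikitext", help="Hugging Face dataset name")
    parser.add_argument("--dataset_config", default="wikitext-2-raw-v1", help="Dataset configuration")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mode", default="train", choices=["train", "inference"])
    parser.add_argument("--query", default="", help="Single query for test_agent")
    return parser

if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args)
```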
## Task Details

### 1. Training the LLM and World Model (`train_llm_world`)

This task trains the LLM and the World Model using a chosen dataset from Hugging Face. Training adjusts the model weights over several epochs to produce a model capable of handling long sequences and complex reasoning tasks.

#### Example Usage

```bash
python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256
```
#### Arguments Specific to `train_llm_world`

- `--model_name`: Name of the pretrained model to use for language model training. You can specify a model name (like `gpt2`, `bert`, etc.) or a path to a custom model. This argument affects the model architecture and tokenization style.
- `--dataset_name`: Specifies the dataset from Hugging Face’s Datasets library to train the model. Options include `wikitext`, `imdb`, `squad`, etc. You can also use a custom dataset by specifying its path.
- `--dataset_config`: Defines the configuration of the dataset, which may correspond to different versions or variations of the dataset. For example, `wikitext` includes configurations such as `wikitext-2-raw-v1`. The configuration affects the format and content of the data.
- `--batch_size`: The number of samples per batch. A larger batch size requires more memory but can improve training speed. You might need to reduce the batch size if memory is limited.
- `--num_epochs`: The number of complete passes through the training dataset. More epochs can improve the model’s ability to learn but may lead to overfitting if set too high.
- `--max_length`: Limits the maximum length of the input sequence. Longer sequences are truncated and shorter sequences are padded. This affects both training and inference.
- `--mode`: Defines the task to be performed. Choose `train` to start training the model. If set to `inference`, the model generates text based on the input.
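To make the role of these arguments concrete, here is a minimal, hedged sketch of an LLM training loop driven by them, using Hugging Face `transformers` and `datasets`. It is a simplified stand-in for the actual `train_llm_world` code, which also trains the World Model components.

```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

def train_llm(model_name="gpt2", dataset_name="wikitext",
              dataset_config="wikitext-2-raw-v1", batch_size=4,
              num_epochs=3, max_length=128,
              device="cuda" if torch.cuda.is_available() else "cpu"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    raw = load_dataset(dataset_name, dataset_config, split="train")

    def tokenize(batch):
        return tokenizer(batch["text"], truncation=True,
                         max_length=max_length, padding="max_length")

    ds = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
    ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    model.train()
    for epoch in range(num_epochs):
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # Causal LM loss; padding tokens are not masked here, for brevity.
            outputs = model(**batch, labels=batch["input_ids"])
            outputs.loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Checkpoint at the end of each epoch, as described above.
        model.save_pretrained(f"checkpoint-epoch{epoch + 1}")
```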
### 2. Training the Search Agent (`train_agent`)

This section gives a detailed breakdown of the search agent, covering training, inference, and the functionality of each component. It also highlights how the agent saves LLM training data, its modular structure, and the role of each module.

#### Overview of the AutonomousWebAgent

The `AutonomousWebAgent` is a sophisticated, multi-component search and retrieval agent designed to navigate the web, gather relevant content, and perform summarization and generation based on user queries. The agent integrates reinforcement learning (RL), Monte Carlo Tree Search (MCTS), a Retrieval-Augmented Generation (RAG) Summarizer, and a Hierarchical Reinforcement Learning (HRL) architecture to select, execute, and optimize its actions based on past experiences.
#### Key Components

**Prioritized Experience Replay:**

- The agent uses a `PrioritizedReplayMemory` and a `SumTree` to prioritize and store experiences (transitions between states).
- The `SumTree` structure maintains a binary tree where each parent node's value is the sum of its children, helping to efficiently store, update, and retrieve experiences based on priority (see the sketch below).
- These experiences are critical in training both high-level (manager) and low-level (worker) components through prioritized sampling during replay, allowing the model to focus on more significant transitions.
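A minimal sketch of the `SumTree` idea described above: priorities live in the leaves and each parent stores the sum of its children, so sampling proportionally to priority is a walk from the root. This is illustrative and will differ from the class in the repo.

```python
import random

class SumTree:
    """Binary tree where each parent stores the sum of its children's priorities."""
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)    # internal nodes + leaves
        self.data = [None] * capacity
        self.write = 0

    def add(self, priority: float, sample) -> None:
        leaf = self.write + self.capacity - 1
        self.data[self.write] = sample
        self.update(leaf, priority)
        self.write = (self.write + 1) % self.capacity   # overwrite the oldest entry

    def update(self, leaf: int, priority: float) -> None:
        change = priority - self.tree[leaf]
        self.tree[leaf] = priority
        while leaf != 0:                           # propagate the change up to the root
            leaf = (leaf - 1) // 2
            self.tree[leaf] += change

    def sample(self, value: float):
        """Walk down from the root, picking left/right by cumulative priority."""
        idx = 0
        while idx < self.capacity - 1:             # until we reach a leaf
            left, right = 2 * idx + 1, 2 * idx + 2
            idx = left if value <= self.tree[left] else right
            if idx == right:
                value -= self.tree[left]
        return self.data[idx - self.capacity + 1]

# Priorities (e.g. TD-errors) make important transitions more likely to be sampled.
tree = SumTree(capacity=4)
for i, p in enumerate([1.0, 0.5, 2.0, 0.1]):
    tree.add(p, f"transition-{i}")
print(tree.sample(random.uniform(0, tree.tree[0])))
```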
**Hierarchical Reinforcement Learning (HRL):**

- HRL is employed to allow a Manager (high-level) model to select options, which are then executed by a Worker (low-level) model. The `ManagerModel` selects tasks (such as searching, summarizing, or generating), while the `WorkerModel` determines specific actions to take (see the sketch below).
- The manager and worker use LSTM networks with fully connected layers, and each has its own replay memory and optimization process.
- The Manager focuses on broad decisions and options, while the Worker operates on specific actions, enabling a layered approach to decision-making.
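The manager/worker split described above can be pictured as two small LSTM policies operating at different levels. The sketch below is a schematic under assumed dimensions and option/action sets; it is not the repo's `ManagerModel`/`WorkerModel`.

```python
import torch
import torch.nn as nn

class OptionPolicy(nn.Module):
    """LSTM followed by a fully connected head, as described for both HRL levels."""
    def __init__(self, state_dim: int, hidden_dim: int, num_outputs: int):
        super().__init__()
        self.lstm = nn.LSTM(state_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, num_outputs)

    def forward(self, states):                      # states: (batch, seq, state_dim)
        out, _ = self.lstm(states)
        return self.head(out[:, -1])                # logits from the last timestep

# Assumed sizes and option set, for illustration only.
OPTIONS = ["search", "summarize", "generate"]       # high-level choices (manager)
manager = OptionPolicy(state_dim=64, hidden_dim=128, num_outputs=len(OPTIONS))
worker = OptionPolicy(state_dim=64 + len(OPTIONS), hidden_dim=128, num_outputs=10)

state = torch.randn(1, 5, 64)                       # a short history of state vectors
option = int(manager(state).argmax())               # manager picks a broad option
option_onehot = nn.functional.one_hot(torch.tensor([option]), len(OPTIONS)).float()
worker_input = torch.cat([state, option_onehot.expand(1, 5, -1)], dim=-1)
action_logits = worker(worker_input)                # worker picks a concrete action
print(OPTIONS[option], int(action_logits.argmax()))
```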
**RAGSummarizer:**

- The `RAGSummarizer` leverages a pre-trained language model (e.g., GPT-2) for summarizing, and a SentenceTransformer for embedding-based retrieval. This module breaks the input content into chunks, retrieves relevant sections based on cosine similarity with the query, and generates a coherent summary (see the sketch below).
- Additionally, it implements a Least Recently Used (LRU) cache to avoid redundant computation and enhance efficiency, along with persistent storage for cache data.
- Summarized results are stored, and this module contributes directly to the generation of LLM training data.
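A hedged sketch of the retrieve-then-summarize flow attributed to the `RAGSummarizer`: chunk the content, embed chunks with a SentenceTransformer, rank by cosine similarity to the query, and feed the top chunks to a generative model. The model names, chunk size, and prompt format here are assumptions (and `sentence-transformers` is an extra pip dependency).

```python
from functools import lru_cache
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

embedder = SentenceTransformer("all-MiniLM-L6-v2")        # assumed embedding model
generator = pipeline("text-generation", model="gpt2")      # GPT-2, as in the description

def chunk(text: str, size: int = 400):
    # Naive fixed-size chunking, for illustration.
    return [text[i:i + size] for i in range(0, len(text), size)]

@lru_cache(maxsize=128)                                    # LRU cache to skip repeat work
def summarize(query: str, content: str, top_k: int = 3) -> str:
    chunks = chunk(content)
    chunk_emb = embedder.encode(chunks, convert_to_tensor=True)
    query_emb = embedder.encode(query, convert_to_tensor=True)
    scores = util.cos_sim(query_emb, chunk_emb)[0]          # cosine similarity per chunk
    best = [chunks[i] for i in scores.topk(min(top_k, len(chunks))).indices]
    prompt = f"Summarize for the query '{query}':\n" + "\n".join(best) + "\nSummary:"
    out = generator(prompt, max_new_tokens=80, do_sample=False)
    return out[0]["generated_text"][len(prompt):]
```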
**WorldModel:**

- This module encapsulates an LSTM architecture with linear layers and a `value_head` to estimate state values, allowing the agent to anticipate the long-term value of its actions.
- It is utilized in the HRL architecture, specifically by the Worker for evaluating actions and by the Manager in long-term decision-making.
**Knowledge Base:**

- The knowledge base acts as a repository for collected data, maintaining embeddings for efficient search and retrieval.
- It supports saving and loading document embeddings, so the agent can retrieve relevant information for new queries from previously collected knowledge.
- Adding to and retrieving from the knowledge base enriches the agent’s context and allows it to store and use information from past experiences to inform current tasks.
**Monte Carlo Tree Search (MCTS):**

- The MCTS component guides the agent through complex decision trees to determine the most promising paths for query refinement.
- Nodes in the tree represent states (possible query refinements), and child nodes represent possible expansions (e.g., related query variations).
- MCTS utilizes a `select`, `expand`, `simulate`, and `backpropagate` strategy to iteratively refine queries, scoring them based on relevance and other metrics to converge on optimal searches (see the sketch below).
- It also integrates RL by backpropagating rewards based on the ranking score from retrieved results.
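The select / expand / simulate / backpropagate loop described above follows the standard MCTS pattern. The sketch below shows that skeleton over generic query refinements; the `refine` and `score` callables are placeholders, not functions from this repo.

```python
import math, random

class Node:
    def __init__(self, query, parent=None):
        self.query, self.parent = query, parent
        self.children, self.visits, self.value = [], 0, 0.0

    def ucb(self, c=1.4):
        if self.visits == 0:
            return float("inf")                           # always try unvisited children
        return self.value / self.visits + c * math.sqrt(math.log(self.parent.visits) / self.visits)

def mcts(root_query, refine, score, iterations=50):
    root = Node(root_query)
    for _ in range(iterations):
        node = root
        while node.children:                              # 1. select: follow the best UCB score
            node = max(node.children, key=Node.ucb)
        for variation in refine(node.query):              # 2. expand: add query refinements
            node.children.append(Node(variation, parent=node))
        leaf = random.choice(node.children) if node.children else node
        reward = score(leaf.query)                        # 3. simulate: score the refined query
        while leaf is not None:                           # 4. backpropagate the reward
            leaf.visits += 1
            leaf.value += reward
            leaf = leaf.parent
    return max(root.children, key=lambda n: n.visits).query

# Placeholder refine/score functions, for illustration.
best = mcts("renewable energy",
            refine=lambda q: [q + " benefits", q + " costs"],
            score=lambda q: random.random())
print(best)
```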
**Ranking Model:**

- The ranking model, built with a neural network and the `SentenceTransformer`, ranks search results based on various features such as cosine similarity with the query, content length, keyword overlap, and domain authority (see the sketch below).
- This model assigns scores to results, which are then used to guide the MCTS process by enhancing the combined reward with ranking scores.
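The features named above can be assembled into a simple feature vector and scored by a small network. This is an illustrative sketch, not the repo's ranking model; the similarity value stands in for a SentenceTransformer cosine score, and the normalization constants are assumptions.

```python
import torch
import torch.nn as nn

class RankingHead(nn.Module):
    def __init__(self, num_features: int = 4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(num_features, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, features):
        return self.net(features).squeeze(-1)               # one relevance score per result

def make_features(query: str, result_text: str, similarity: float, domain_authority: float):
    q_terms, r_terms = set(query.lower().split()), set(result_text.lower().split())
    keyword_overlap = len(q_terms & r_terms) / max(len(q_terms), 1)
    length = min(len(result_text) / 5000.0, 1.0)             # normalized content length
    return torch.tensor([similarity, length, keyword_overlap, domain_authority])

ranker = RankingHead()
features = make_features("renewable energy impacts", "Solar and wind energy reduce...", 0.72, 0.9)
print(float(ranker(features)))
```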
**Tree of Thought (ToT) Search:**

- This module enhances the agent's capability to generate a series of interconnected thoughts, exploring different perspectives or angles on a given query.
- The `ToTNode` and `ToTSearch` classes enable the agent to generate thoughts, evaluate them, and navigate through them as a tree, considering various potential paths to best answer the query.
- It combines MCTS and RAG to synthesize responses based on the generated thought paths.
#### Training Process
The training process for the agent involves episodic learning, where it interacts with various queries from a predefined list. Each query initiates an episode, and the agent performs actions based on its learned policy:
**Search and Summarization:**

- The agent performs search operations, gathering relevant content from online sources using the MCTS and Ranking Model for prioritization.
- Summarization is then carried out on the retrieved content, with relevant information stored in the LLM training data.
**Knowledge Base and LLM Training Data Storage:**

- Throughout the training process, the agent stores retrieved documents, query results, and summaries in its knowledge base and saves training data for future LLM fine-tuning.
- The data is saved in JSONL format and includes metadata such as query terms, source links, and summaries, making it valuable for training language models (see the sketch below).
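A minimal sketch of appending LLM training records as JSONL with the metadata mentioned above; the field names and file path are assumptions, not the repo's schema.

```python
import json
from datetime import datetime, timezone

def append_training_record(path, query, source_links, summary):
    """Append one JSON object per line (JSONL), as described above."""
    record = {
        "query": query,
        "sources": source_links,
        "summary": summary,
        "timestamp": datetime.now(timezone.utc).isoformat(),   # assumed extra metadata
    }
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

append_training_record("llm_training_data.jsonl",
                       "impacts of renewable energy",
                       ["https://example.com/article"],
                       "Renewable energy reduces emissions ...")
```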
**Experience Replay:**

- Both the manager and worker models engage in prioritized experience replay, sampling from the stored experiences in the SumTree based on TD-errors.
- Replay is essential for reinforcing successful transitions and updating the models' policies over time.
**Reward Calculation and Backpropagation:**

- Rewards are calculated based on ranking scores, cosine similarity with the query, and other custom factors (e.g., query complexity, state length).
- These rewards are backpropagated through the MCTS and used to update the models' decision-making processes, ensuring continuous learning and adaptation.
#### Inference Process
During inference:
- The agent accepts a query, and the Manager model selects a high-level action based on its policy (e.g., search, summarize, or generate).
- Once an option is chosen, the Worker model executes the corresponding low-level actions. For example, in a search operation, it leverages MCTS to refine the query, retrieves relevant web content, and processes it with the RAGSummarizer.
- Each inference step is augmented by the agent's existing knowledge base, enabling it to produce more informed and contextually rich responses. Additionally, if Tree of Thought (ToT) is employed, the agent synthesizes a coherent and comprehensive answer based on the thought path.
#### Model Saving

The agent incorporates a series of save functions to preserve the models:

- The `save_worker_model` and `save_manager_model` functions save the worker and manager models independently.
- The `save` method preserves the overall state of the agent, which includes its knowledge base, replay memories, and models. This facilitates model reusability and persistent storage, enabling the agent to resume from saved states during training or deployment.
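As a hedged sketch, save helpers like these are commonly implemented with `torch.save`; the actual methods in the agent may bundle additional state (replay memories, knowledge base metadata) as described above.

```python
import torch

def save_worker_model(worker, path="worker_model.pth"):
    # Persist only the worker's weights.
    torch.save(worker.state_dict(), path)

def save_manager_model(manager, path="manager_model.pth"):
    # Persist only the manager's weights.
    torch.save(manager.state_dict(), path)

def save_agent(agent_state: dict, path="agent_checkpoint.pth"):
    # agent_state is assumed to be a dict of everything needed to resume:
    # model weights, replay memories, and knowledge base metadata.
    torch.save(agent_state, path)
```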
This modular setup enhances flexibility, allowing the agent to dynamically adjust its behavior based on rewards from RL, improvements from experience replay, and efficient decision-making through MCTS. Additionally, by saving LLM training data, the agent becomes highly reusable for further fine-tuning, offering the opportunity to build specialized, data-driven language models optimized for specific domains or tasks.

The `train_agent` task uses Twisted to train the Autonomous Web Agent by interacting with various queries in a simulated or real environment. It collects rewards based on how well the agent navigates and summarizes web content or performs other tasks.
#### Example Usage

```bash
python main_menu.py --task train_agent
```
#### Process Details

- **Training:** During training, the agent will automatically sample a list of predefined queries, explore web pages, and use reinforcement learning to maximize its reward based on its actions. The training log provides insights into each episode's reward and the agent’s progress.
- **Logging:** Logs are recorded to `agent_training.log` and provide information about each episode, such as the query, the total reward, and the episode duration. Errors are logged, and if an episode times out, a negative reward is given.
### 3. Testing the Tree of Thought Search Agent (`test_agent`)

This task lets you test the Tree of Thought Search Agent either in an interactive mode or by specifying a single query. In interactive mode, the user can repeatedly enter queries, and the agent will process them sequentially, producing responses based on the Tree of Thought architecture.

#### Example Usage

Interactive Mode:

```bash
python main_menu.py --task test_agent
```

Single Query Mode:

```bash
python main_menu.py --task test_agent --query "What are the impacts of renewable energy on global sustainability?"
```
#### Arguments Specific to `test_agent`

- `--query`: If provided, the agent will process this specific query and return a response. This is ideal for quick, one-off tests or evaluations. If not provided, the program will start an interactive session where you can repeatedly input queries and view the agent's responses.
#### Interactive Mode Details

- **Input:** In interactive mode, enter a query and press Enter. The agent will respond based on its training and the Tree of Thought methodology, traversing different thought paths to generate a response.
- **Exiting:** To exit the interactive session, type `quit` and press Enter. The agent will then save any new knowledge it has gained and exit the program.
## Additional Tips and Considerations

- **Adjusting Memory Constraints:** The batch size and model architecture (number of layers, dimensions, etc.) affect memory usage. If you encounter memory errors, reduce the batch size or sequence length.
- **Training Time:** Training the LLM and World Model may take considerable time, especially with large datasets or complex models. Use fewer epochs or a smaller dataset to speed up initial trials.
- **Model Save Path:** The `train_llm_world` task saves the model at the end of each epoch. Ensure you have enough storage space and specify a save directory if desired.
- **Logging:** Detailed logs for `train_agent` are saved to a file, which can help track progress, debug errors, and measure performance.
# World Model with MCTS and Transformer Components

## Model Overview

This model is a World Model that combines Transformers, Mixture of Experts (MoE) layers, Monte Carlo Tree Search (MCTS), and Proximal Policy Optimization (PPO) to simulate and optimize a state-based environment. Designed for complex tasks involving decision-making and action prediction, the model leverages these components to encode, predict, and enhance action sequences.
## Key Components
- Transformer: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
- MCTS: The Monte Carlo Tree Search module iteratively simulates actions to select the best possible path based on exploration and exploitation.
- PPO Agent: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
- Custom Objective Functions: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
## Model Architecture

The model is constructed with several primary components:

- **Transformer:** The transformer has encoder and decoder layers with rotary positional encoding and Mixture of Experts (MoE) layers to improve generalization and reduce computational cost by routing only parts of the data to certain experts. GELU and SwiGLU activation functions are alternated between the experts.
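To make the MoE routing idea concrete, here is a small sketch of a gated Mixture of Experts layer that alternates GELU and SwiGLU experts, as described above. The top-1 routing, layer sizes, and expert count are illustrative assumptions, and rotary positional encoding is omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GELUExpert(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(dim, hidden), nn.Linear(hidden, dim)
    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

class SwiGLUExpert(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden)
        self.w_up = nn.Linear(dim, hidden)
        self.w_down = nn.Linear(hidden, dim)
    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))   # SwiGLU activation

class MoELayer(nn.Module):
    """Routes each token to a single expert (top-1 gating), alternating expert types."""
    def __init__(self, dim=64, hidden=128, num_experts=4):
        super().__init__()
        self.experts = nn.ModuleList(
            GELUExpert(dim, hidden) if i % 2 == 0 else SwiGLUExpert(dim, hidden)
            for i in range(num_experts)
        )
        self.router = nn.Linear(dim, num_experts)

    def forward(self, x):                           # x: (batch, seq, dim)
        gate = self.router(x).argmax(dim=-1)        # top-1 expert index per token
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = gate == i
            if mask.any():
                out[mask] = expert(x[mask])         # only routed tokens reach this expert
        return out

print(MoELayer()(torch.randn(2, 10, 64)).shape)     # torch.Size([2, 10, 64])
```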
## Multi-Token Prediction with Beam Search
Multi-token prediction in a language model involves generating multiple tokens in sequence, rather than one token at a time. This can improve the fluency and coherence of generated text by allowing the model to "look ahead" and consider multiple possible continuations at each step.
Beam Search is a popular decoding algorithm used for multi-token prediction that allows the model to explore multiple potential sequences and choose the most likely one based on the overall probability. Here's how it works:
1. **Initialization:** Start with a single "beam" (sequence) that contains the initial token, typically the beginning-of-sequence (`<sos>`) token.
2. **Expansion:** At each time step, the model generates a probability distribution over the vocabulary for each sequence in the beam. Each sequence expands by predicting the next possible tokens, creating new sequences for each possible token.
3. **Scoring:** Calculate the score for each expanded sequence by taking the sum (or average) of log probabilities for all tokens in the sequence. Log probabilities are used to avoid underflow and ensure stable computation.
4. **Selection:** Keep only the top-k sequences with the highest scores (k is known as the "beam width" or "beam size") and discard the rest. This limits the number of sequences kept at each step, focusing only on the most promising ones.
5. **Repeat:** Continue expanding and scoring until reaching the desired sequence length or the end-of-sequence (`<eos>`) token.
6. **Final Output:** After a set number of steps, or once all sequences end with `<eos>`, select the sequence with the highest score as the final output.
This process allows the model to generate more fluent and accurate sequences by considering multiple potential continuations at each step and selecting the best overall sequence.
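The steps above map directly onto a compact implementation. The sketch below is illustrative: `log_probs` is a placeholder callable standing in for the transformer's next-token distribution, not an API from this repo.

```python
import math

def beam_search(log_probs, sos, eos, beam_width=3, max_len=20):
    """log_probs(sequence) -> {token: log probability of that token coming next}."""
    beams = [([sos], 0.0)]                                    # (sequence, cumulative log prob)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos:                                # finished beams carry over
                candidates.append((seq, score))
                continue
            for token, lp in log_probs(seq).items():          # expand each beam
                candidates.append((seq + [token], score + lp))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        if all(seq[-1] == eos for seq, _ in beams):           # stop when every beam has ended
            break
    return beams[0][0]                                        # highest-scoring sequence

# Toy scorer that prefers ending quickly; real use would call the transformer here.
toy = lambda seq: {"<eos>": math.log(0.6), "token": math.log(0.4)}
print(beam_search(toy, sos="<sos>", eos="<eos>"))
```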
## World Model

- **Representation Network:** This module encodes the Transformer output to generate a state representation, reducing dimensionality and making it suitable for further processing.
- **Dynamics Network:** This module predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
- **Prediction Network:** Predicts both the policy logits and value estimates for a given state. It outputs the probabilities of different actions as well as a single scalar value.
- **MCTS:** This module performs Monte Carlo Tree Search to evaluate the quality of actions over multiple iterations. It expands nodes based on the policy logits from the Prediction Network and simulates the reward by backpropagating value estimates.
- **PPO Agent:** Uses policy and value estimates to calculate the PPO loss, which updates the policy while maintaining the constraint on the KL divergence between old and new policies.

The transformer strategically uses beam search and multi-token prediction to enrich the encoding produced by the representation network.
A generated sequence of tokens is an action. For example, if tokens are denoted t, then an action is:

a_1 = {t_1, ..., t_N}

A policy is then a sequence of actions:

P_1 = {a_1, ..., a_N}

The MCTS and PPO components explore what we define as 'thoughts', where a thought is a set of policies:

thought_1 = {P_1, ..., P_N}

The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each level of granularity.
## Training Details

The model is trained with the following components and techniques:

### Training Procedure
- Data Loading: The data is tokenized and prepared with attention to padding and truncation. Text data is grouped into sequences of fixed length for efficient training.
- Optimization: Training uses an AdamW optimizer with CosineAnnealingLR scheduler for learning rate adjustments. The Gradient Scaler helps prevent overflow when training with mixed precision.
- Gradient Accumulation: Since the model can be computationally heavy, gradients are accumulated over several steps to reduce memory usage (a sketch of this setup follows the loss list below).
- Loss Functions: The training process leverages a comprehensive set of custom loss functions:
- InfoNCE Loss: A contrastive loss to encourage representation similarity between related pairs.
- Covariance Regularization: Encourages diverse state representations by minimizing co-linearity in embeddings.
- Dynamics Performance Loss: Combines MSE and variance losses to penalize incorrect state predictions.
- Thought Consistency Loss: Encourages the model to output consistent states for similar actions.
- Policy Value Joint Loss: A weighted combination of policy and value loss for the PPO agent.
- Action Diversity Reward: Rewards diverse action embeddings to avoid mode collapse.
- Exploration Regularization: Encourages exploration by penalizing high visitation counts.
- KL Divergence Loss: Keeps the policy update close to the previous policy to stabilize training.
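As referenced in the Gradient Accumulation bullet above, here is a hedged sketch of how gradient accumulation, mixed precision (`autocast`/`GradScaler`), AdamW, and the CosineAnnealingLR scheduler typically fit together. It is a generic loop, not the exact training code, and it omits the custom losses listed above.

```python
import torch
from torch.cuda.amp import autocast, GradScaler

def train_epoch(model, loader, optimizer, scheduler, device, accumulation_steps=4):
    scaler = GradScaler()                                    # guards against fp16 underflow/overflow
    model.train()
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with autocast():                                     # mixed-precision forward pass
            loss = model(**batch, labels=batch["input_ids"]).loss
        scaler.scale(loss / accumulation_steps).backward()   # accumulate scaled gradients
        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)                           # unscale gradients + optimizer step
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()                                 # CosineAnnealingLR adjustment
```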
### Evaluation

After each epoch, the model is evaluated on the validation set, computing the average loss over the dataset. The evaluation function utilizes the same loss functions as training but does not backpropagate, allowing it to be run in inference mode.

### Checkpoints

At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
## Inference Details

**Input Processing:**
- The function takes a query (text input), world model components, a root thought node, and a tokenizer.
- The query is tokenized and encoded using the provided tokenizer.
**Inference Modes:** The function supports three inference modes:
a. 'without_world_model':
- This mode directly uses the transformer model to generate text.
- It doesn't utilize the world model components or the Tree of Thought.
- The transformer generates text autoregressively up to the specified max length.
b. 'world_model':
- This mode uses the world model components but doesn't use the Tree of Thought.
- It generates actions based on the prediction network's output.
c. 'world_model_tree_of_thought':
- This is the most comprehensive mode, using both the world model and the Tree of Thought.
**World Model Inference Process:** For the 'world_model' and 'world_model_tree_of_thought' modes:
a. Initial State:
- The query is passed through the transformer model.
- The representation network creates an initial state representation from the transformer output.
b. Action Selection:
For 'world_model':
- The prediction network generates policy logits from the state representation.
- Actions are selected based on the highest probabilities in the policy (a sketch of this loop appears at the end of this section).
For 'world_model_tree_of_thought':
- It uses Monte Carlo Tree Search (MCTS) to explore the Tree of Thought.
- For each MCTS iteration:
- Selection: Traverse the tree to find a leaf node.
- Expansion: Add child nodes to the leaf.
- Evaluation: Use the prediction network to estimate the value of the node.
- Backpropagation: Update the values and visit counts of nodes.
- The best action is chosen based on visit counts after MCTS.
c. State Transition:
- The selected action is applied to the current state using the dynamics network.
- This creates a new state representation for the next step.
d. Sequence Generation:
- The process repeats for the specified number of steps or until a termination condition is met.
- For the Tree of Thought approach, it continues until reaching a leaf node in the thought tree.
**Output:**
- For 'without_world_model', it returns the generated text.
- For 'world_model' and 'world_model_tree_of_thought', it returns a sequence of selected actions (thoughts).
The world model inference leverages the learned representations and dynamics to navigate the problem-solving process. The Tree of Thought approach adds structure to this process, guiding the model through a predefined hierarchy of problem-solving steps. This allows for a more structured and potentially more effective approach to complex problem-solving tasks.
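As promised above, here is a schematic sketch of the 'world_model' action-selection loop: build an initial state with the representation network, then repeatedly pick the highest-probability action from the prediction network and advance the state with the dynamics network. The tiny stand-in modules and dimensions below are assumptions for illustration; the real components are richer.

```python
import torch
import torch.nn as nn

class Representation(nn.Module):
    def __init__(self, hidden_dim=32, state_dim=16):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, state_dim)
    def forward(self, h):
        return self.fc(h.mean(dim=1))                     # pool the sequence into a state vector

class Dynamics(nn.Module):
    def __init__(self, state_dim=16, num_actions=8):
        super().__init__()
        self.fc = nn.Linear(state_dim + num_actions, state_dim)
        self.num_actions = num_actions
    def forward(self, state, action):
        onehot = nn.functional.one_hot(action, self.num_actions).float()
        return torch.tanh(self.fc(torch.cat([state, onehot], dim=-1)))

class Prediction(nn.Module):
    def __init__(self, state_dim=16, num_actions=8):
        super().__init__()
        self.policy = nn.Linear(state_dim, num_actions)
        self.value = nn.Linear(state_dim, 1)
    def forward(self, state):
        return self.policy(state), self.value(state)      # policy logits + scalar value

def world_model_inference(encoder_hidden, representation, dynamics, prediction, num_steps=5):
    state = representation(encoder_hidden)                # initial state from the encoder output
    actions = []
    for _ in range(num_steps):
        policy_logits, _ = prediction(state)
        action = policy_logits.argmax(dim=-1)             # greedy action selection
        actions.append(int(action))
        state = dynamics(state, action)                   # advance to the next state
    return actions                                        # the selected actions ("thoughts")

# Fake transformer output, for illustration: (batch=1, seq_len=10, hidden_dim=32).
print(world_model_inference(torch.randn(1, 10, 32), Representation(), Dynamics(), Prediction()))
```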
Here I am using Trees of Thought as a way to structure sets of policies and sequences of actions. These tree structures give the World Model a general thought structure and pattern, similar to how humans create thought patterns for solving certain problems (e.g. understand, describe, analyse, etc.).
Here are some example Trees of Thought:

```mermaid
graph TD
    A[Problem-Solving Process] --> B[Problem Identification]
    A --> C[Problem Analysis]
    A --> D[Solution Generation]
    A --> E[Implementation]
    A --> F[Evaluation and Adjustment]
    B --> B1[Define the Problem]
    B --> B2[Identify Stakeholders]
    B --> B3[Determine Constraints]
    B --> B4[Recognize Problem Type]
    B --> B5[Historical Context]
    C --> C1[Root Cause Analysis]
    C --> C2[System Mapping]
    C --> C3[Data Collection]
    C --> C4[Impact Assessment]
    C --> C5[Theoretical Framework]
    D --> D1[Creative Problem Solving]
    D --> D2[Analytical Approach]
    D --> D3[Mathematical Computation]
    D --> D4[Decision Making]
    E --> E1[Action Planning]
    E --> E2[Resource Allocation]
    E --> E3[Change Management]
    F --> F1[Verification]
    F --> F2[Performance Metrics]
    F --> F3[Feedback Loops]
    F --> F4[Continuous Improvement]
    C3 --> C3a[Quantitative Data]
    C3 --> C3b[Qualitative Data]
    C3 --> C3c[Data Validation]
    D1 --> D1a[Divergent Thinking]
    D1 --> D1b[Convergent Thinking]
    D1 --> D1c[Lateral Thinking]
    D2 --> D2a[Logical Reasoning]
    D2 --> D2b[Critical Analysis]
    D2 --> D2c[Systems Thinking]
    D3 --> D3a[Basic Operations]
    D3 --> D3b[Advanced Operations]
    D3 --> D3c[Computational Methods]
    D4 --> D4a[Decision Trees]
    D4 --> D4b[Multi-Criteria Analysis]
    D4 --> D4c[Probabilistic Reasoning]
    G[Cross-Cutting Considerations] --> G1[Ethical Framework]
    G --> G2[Stakeholder Management]
    G --> G3[Interdisciplinary Connections]
    G --> G4[Technological Integration]
    G --> G5[Emotional Intelligence]
    G --> G6[Collaborative Problem Solving]
    G1 --> G1a[Value-based Decision Making]
    G1 --> G1b[Long-term Consequences]
    G2 --> G2a[Direct Stakeholders]
    G2 --> G2b[Indirect Stakeholders]
    G2 --> G2c[Conflicting Interests]
    G3 --> G3a[Related Fields]
    G3 --> G3b[Cross-disciplinary Impact]
    G4 --> G4a[AI-assisted Problem Solving]
    G4 --> G4b[Data-driven Insights]
    G4 --> G4c[Digital Collaboration Tools]
    G5 --> G5a[Self-Awareness]
    G5 --> G5b[Empathy]
    G5 --> G5c[Stress Management]
    G6 --> G6a[Team Dynamics]
    G6 --> G6b[Communication Strategies]
    G6 --> G6c[Conflict Resolution]
    H[Computational Considerations] --> H1[CPU Operations]
    H --> H2[GPU Parallelization]
    H --> H3[Floating-Point Precision]
    I[Order of Operations] --> I1[Parentheses]
    I --> I2[Exponents]
    I --> I3[Multiplication and Division]
    I --> I4[Addition and Subtraction]
    J[Critical Thinking] --> J1[Assumptions Questioning]
    J --> J2[Bias Recognition]
    K[Future Perspective] --> K1[Short-term Projections]
    K --> K2[Long-term Scenarios]
    K --> K3[Potential Impacts]
    L[Learning and Adaptation] --> L1[Reflective Practice]
    L --> L2[Knowledge Transfer]
    L --> L3[Adaptive Problem Solving]
```
```mermaid
graph TD
    A[Meta-Cognitive Strategies] --> B[Creative Problem Solving]
    A --> C[Systems Thinking]
    A --> D[Decision Making]
    A --> E[Emotional Intelligence]
    A --> F[Collaborative Problem Solving]
    B --> B1[Divergent Thinking]
    B --> B2[Convergent Thinking]
    B --> B3[Lateral Thinking]
    C --> C1[Holistic Perspective]
    C --> C2[Feedback Loops]
    C --> C3[Emergent Properties]
    D --> D1[Decision Trees]
    D --> D2[Multi-Criteria Decision Analysis]
    D --> D3[Probabilistic Reasoning]
    E --> E1[Self-Awareness]
    E --> E2[Empathy]
    E --> E3[Stress Management]
    F --> F1[Team Dynamics]
    F --> F2[Communication Strategies]
    F --> F3[Conflict Resolution]
    G[Learning and Adaptation]
    A --> G
    G --> G1[Reflective Practice]
    G --> G2[Knowledge Transfer]
    G --> G3[Adaptive Problem Solving]
    H[Ethical Framework]
    A --> H
    H --> H1[Value-based Decision Making]
    H --> H2[Stakeholder Analysis]
    H --> H3[Long-term Consequences]
    I[Technological Integration]
    A --> I
    I --> I1[AI-assisted Problem Solving]
    I --> I2[Data-driven Insights]
    I --> I3[Digital Collaboration Tools]
```
## Requirements

This code requires:

- Python 3.7+
- `torch>=1.7.1`
- `transformers`
- `datasets`
- `argparse`
## Citation
If you use this model in your research, please cite the author.