metadata
license: apache-2.0

Overview of the Main Menu

The main_menu.py script is the primary entry point for choosing and executing one of three tasks:

  1. Training the LLM and World Model: train_llm_world
  2. Training the Search Agent: train_agent
  3. Testing the Tree of Thought Search Agent: test_agent

Each task has its own functionality and configuration. The script uses command-line arguments to select the task and set additional options, so each run can be tailored to your needs.

Running the Main Menu

To run the main menu, use the following command in the terminal:

python main_menu.py --task <task_name> [additional arguments]

Replace <task_name> with one of the following:

  • train_llm_world - Train the LLM (Language Model) and World Model.
  • train_agent - Train the Search Agent with an interactive Twisted-based process.
  • test_agent - Test the Tree of Thought Search Agent, with the option of an interactive session or a single query.

General Arguments

The script supports a set of command-line arguments to customize each task. Here’s an overview of all possible arguments:

| Argument | Required | Description | Default |
|---|---|---|---|
| --task | Yes | Specifies the task to run. Choose from train_llm_world, train_agent, or test_agent. | None |
| --model_name | No | Pretrained model name for the LLM. Options include gpt2, bert, etc., or a custom model path. | gpt2 |
| --dataset_name | No | Name of the dataset from Hugging Face Datasets for training the LLM and World Model (e.g., wikitext). | wikitext |
| --dataset_config | No | Dataset configuration name for specifying different versions or configurations of the dataset. | wikitext-2-raw-v1 |
| --batch_size | No | Number of samples processed in a single forward/backward pass. Increasing the batch size can speed up training but requires more memory. | 4 |
| --num_epochs | No | Number of times to iterate over the training dataset. More epochs generally improve learning but can lead to overfitting. | 3 |
| --max_length | No | Maximum sequence length for training/inference. Sequences are truncated or padded to this length. | 128 |
| --mode | No | Mode for the LLM and World Model: train for training, inference for generating responses. | train |
| --query | No | Query input for test_agent when running a single query instead of an interactive session. | '' (empty) |
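The argument table maps directly onto Python's argparse. The sketch below is a minimal reconstruction that assumes only the arguments and defaults listed in the table; the real main_menu.py may define more options, and the dispatch lambdas are hypothetical placeholders for the project's actual task functions.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Arguments and defaults copied from the table; anything beyond them is not assumed
    parser = argparse.ArgumentParser(description="Main menu (sketch)")
    parser.add_argument("--task", required=True,
                        choices=["train_llm_world", "train_agent", "test_agent"])
    parser.add_argument("--model_name", default="gpt2")
    parser.add_argument("--dataset_name", default="wikitext")
    parser.add_argument("--dataset_config", default="wikitext-2-raw-v1")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mode", choices=["train", "inference"], default="train")
    parser.add_argument("--query", default="")
    return parser

def dispatch(args: argparse.Namespace) -> str:
    # Hypothetical handlers: the real script calls its own training/testing routines
    handlers = {
        "train_llm_world": lambda a: f"training {a.model_name} on {a.dataset_name}",
        "train_agent": lambda a: "training search agent",
        "test_agent": lambda a: f"query: {a.query}" if a.query else "interactive session",
    }
    return handlers[args.task](args)
```

In a real entry point you would call `dispatch(build_parser().parse_args())`; using `choices` makes argparse reject an unknown task name before any work starts.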

Task Details

1. Training the LLM and World Model (train_llm_world)

This task trains the LLM and the World Model on a chosen Hugging Face dataset, updating the model weights over the configured number of epochs. The aim is a model capable of handling long sequences and complex reasoning tasks.

Example Usage

python main_menu.py --task train_llm_world --model_name gpt2 --dataset_name wikitext --num_epochs 5 --batch_size 8 --max_length 256

Arguments Specific to train_llm_world

  • --model_name: Name of the pretrained model to use for language model training. You can specify a model name (like gpt2, bert, etc.) or a path to a custom model. This argument affects the model architecture and tokenization style.

  • --dataset_name: Specifies the dataset from Hugging Face’s Datasets library to train the model. Options include wikitext, imdb, squad, etc. You can also use a custom dataset by specifying its path.

  • --dataset_config: Defines the configuration of the dataset, which may be different versions or variations of the dataset. For example, wikitext includes configurations such as wikitext-2-raw-v1. The configuration will affect the format and content of the data.

  • --batch_size: The number of samples per batch. A larger batch size requires more memory but can improve training speed. You might need to reduce the batch size if memory is limited.

  • --num_epochs: The number of complete passes through the training dataset. More epochs can improve the model’s ability to learn but may lead to overfitting if too high.

  • --max_length: Maximum length of the input sequence. Longer sequences are truncated to this length, and shorter sequences are padded up to it. This affects both training and inference.

  • --mode: Defines the task to be performed. Choose train to start training the model. If set to inference, the model generates text based on the input.
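To make the --max_length behaviour concrete, here is a minimal sketch of truncation and right-padding on token IDs, plus the matching attention mask (1 for real tokens, 0 for padding). The pad ID of 0 is an assumption; Hugging Face tokenizers perform the same operation when called with `truncation=True, padding="max_length", max_length=...`.

```python
from typing import List

def truncate_or_pad(ids: List[int], max_length: int, pad_id: int = 0) -> List[int]:
    """Cut sequences longer than max_length; right-pad shorter ones with pad_id (assumed 0)."""
    if len(ids) >= max_length:
        return ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

def attention_mask(ids: List[int], max_length: int) -> List[int]:
    """1 marks a real token, 0 marks padding the model should ignore."""
    real = min(len(ids), max_length)
    return [1] * real + [0] * (max_length - real)
```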

2. Training the Search Agent (train_agent)

This section gives a detailed breakdown of the search agent, covering training, inference, and the functionality of each component. It also explains how the agent saves LLM training data, how its modules are structured, and the role each module plays.


Overview of the AutonomousWebAgent

The AutonomousWebAgent is a sophisticated, multi-component search and retrieval agent designed to navigate the web, gather relevant content, and perform summarization and generation based on user queries. This agent integrates reinforcement learning (RL), Monte Carlo Tree Search (MCTS), a Retrieval-Augmented Generation (RAG) Summarizer, and a Hierarchical Reinforcement Learning (HRL) architecture to select, execute, and optimize its actions based on past experiences.

Key Components

  1. Prioritized Experience Replay:

    • The agent uses a PrioritizedReplayMemory and a SumTree to prioritize and store experiences (transitions between states).
    • The SumTree structure maintains a binary tree where each parent node's value is the sum of its children, helping to efficiently store, update, and retrieve experiences based on priority.
    • These experiences are critical in training both high-level (manager) and low-level (worker) components through prioritized sampling during replay, allowing the model to focus on more significant transitions.
  2. Hierarchical Reinforcement Learning (HRL):

    • HRL is employed to allow a Manager (high-level) model to select options, which are then executed by a Worker (low-level) model. The ManagerModel selects tasks (such as searching, summarizing, or generating), while the WorkerModel determines specific actions to take.
    • The manager and worker use LSTM networks with fully connected layers, and each has its own replay memory and optimization process.
    • The Manager focuses on broad decisions and options, while the Worker operates on specific actions, enabling a layered approach to decision-making.
  3. RAGSummarizer:

    • The RAGSummarizer leverages a pre-trained language model (e.g., GPT-2) for summarizing, and a SentenceTransformer for embedding-based retrieval. This module breaks down the input content into chunks, retrieves relevant sections based on cosine similarity with the query, and generates a coherent summary.
    • Additionally, it implements a Least Recently Used (LRU) cache to avoid redundant computation and enhance efficiency, along with persistent storage for cache data.
    • Summarized results are stored, and this module contributes directly to the generation of LLM training data.
  4. WorldModel:

    • This module encapsulates an LSTM architecture with linear layers and a value_head to estimate state values, allowing the agent to anticipate the long-term value of its actions.
    • It is utilized in the HRL architecture, specifically by the Worker for evaluating actions and by the Manager in long-term decision-making.
  5. Knowledge Base:

    • The knowledge base acts as a repository for collected data, maintaining embeddings for efficient search and retrieval.
    • It supports saving and loading document embeddings, so the agent can retrieve relevant information for new queries from previously collected knowledge.
    • Adding and retrieving from the knowledge base enriches the agent’s context and allows it to store and use information from past experiences to inform current tasks.
  6. Monte Carlo Tree Search (MCTS):

    • The MCTS component guides the agent through complex decision trees to determine the most promising paths for query refinement.
    • Nodes in the tree represent states (possible query refinements), and child nodes represent possible expansions (e.g., related query variations).
    • MCTS utilizes a select, expand, simulate, and backpropagate strategy to iteratively refine queries, scoring them based on relevance and other metrics to converge on optimal searches.
    • It also integrates RL by backpropagating rewards based on the ranking score from retrieved results.
  7. Ranking Model:

    • The ranking model, built with a neural network and the SentenceTransformer, ranks search results based on various features such as cosine similarity with the query, content length, keyword overlap, and domain authority.
    • This model assigns scores to results, which are then used to guide the MCTS process by enhancing the combined reward with ranking scores.
  8. Tree of Thought (ToT) Search:

    • This module enhances the agent's capability to generate a series of interconnected thoughts, exploring different perspectives or angles on a given query.
    • ToTNode and ToTSearch classes enable the agent to generate thoughts, evaluate them, and navigate through them as a tree, considering various potential paths to best answer the query.
    • It combines MCTS and RAG to synthesize responses based on the generated thought paths.
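The SumTree described under Prioritized Experience Replay can be sketched in a few dozen lines. This is a generic array-based sum tree, not necessarily the project's own class: leaves hold priorities, each parent holds the sum of its children, and sampling draws one value from each of `batch_size` equal segments of the total priority. The `priority_from_td` helper and its `alpha`/`eps` values are assumptions based on common prioritized-replay practice.

```python
import random
from typing import Any, List, Tuple

class SumTree:
    """Binary tree stored in a list; each parent holds the sum of its children's priorities."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)  # internal nodes followed by leaves
        self.data: List[Any] = [None] * capacity
        self.write = 0  # next slot to overwrite (ring buffer)

    def total(self) -> float:
        return self.tree[0]

    def add(self, priority: float, item: Any) -> None:
        leaf = self.write + self.capacity - 1
        self.data[self.write] = item
        self.update(leaf, priority)
        self.write = (self.write + 1) % self.capacity

    def update(self, leaf: int, priority: float) -> None:
        change = priority - self.tree[leaf]
        self.tree[leaf] = priority
        while leaf != 0:  # propagate the change up to the root
            leaf = (leaf - 1) // 2
            self.tree[leaf] += change

    def get(self, value: float) -> Tuple[int, float, Any]:
        """Return (leaf index, priority, item) whose cumulative range contains value."""
        idx = 0
        while 2 * idx + 1 < len(self.tree):  # descend until a leaf is reached
            left = 2 * idx + 1
            if value <= self.tree[left]:
                idx = left
            else:
                value -= self.tree[left]
                idx = left + 1
        return idx, self.tree[idx], self.data[idx - self.capacity + 1]

def priority_from_td(td_error: float, alpha: float = 0.6, eps: float = 1e-5) -> float:
    # Assumed PER convention: larger TD-error means higher sampling priority
    return (abs(td_error) + eps) ** alpha

def sample(tree: SumTree, batch_size: int) -> List[Any]:
    """Stratified sampling: one draw per equal slice of the total priority."""
    segment = tree.total() / batch_size
    return [tree.get(random.uniform(segment * i, segment * (i + 1)))[2]
            for i in range(batch_size)]
```

Both update and lookup walk a single root-to-leaf path, so they cost O(log capacity), which is what makes priority-based sampling cheap enough to run on every training step.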

Training Process

The training process for the agent involves episodic learning, where it interacts with various queries from a predefined list. Each query initiates an episode, and the agent performs actions based on its learned policy:

  1. Search and Summarization:

    • The agent performs search operations, gathering relevant content from online sources using the MCTS and Ranking Model for prioritization.
    • Summarization is then carried out on the retrieved content, with relevant information stored in the LLM training data.
  2. Knowledge Base and LLM Training Data Storage:

    • Throughout the training process, the agent stores retrieved documents, query results, and summaries in its knowledge base and saves training data for future LLM fine-tuning.
    • The data is saved in JSONL format and includes metadata such as query terms, source links, and summaries, making it valuable for training language models.
  3. Experience Replay:

    • Both the manager and worker models engage in prioritized experience replay, sampling from the stored experiences in the SumTree based on TD-errors.
    • Replay is essential for reinforcing successful transitions and updating the models' policies over time.
  4. Reward Calculation and Backpropagation:

    • Rewards are calculated based on ranking scores, cosine similarity with the query, and other custom factors (e.g., query complexity, state length).
    • These rewards are backpropagated through the MCTS and used to update the models' decision-making processes, ensuring continuous learning and adaptation.
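The JSONL training-data format mentioned above stores one JSON object per line, so records can be appended after each episode and streamed later without loading the whole file. A minimal sketch follows; the field names `query`, `sources`, and `summary` are illustrative assumptions, not the agent's confirmed schema.

```python
import json
from pathlib import Path
from typing import Dict, List

def to_jsonl_line(query: str, sources: List[str], summary: str) -> str:
    """Serialize one training example as a single JSON line (field names are illustrative)."""
    record = {"query": query, "sources": sources, "summary": summary}
    return json.dumps(record, ensure_ascii=False) + "\n"

def append_training_record(path: Path, query: str, sources: List[str], summary: str) -> None:
    # Appending preserves earlier episodes; each line stays an independent JSON object
    with path.open("a", encoding="utf-8") as f:
        f.write(to_jsonl_line(query, sources, summary))

def parse_jsonl(text: str) -> List[Dict]:
    """Read records back, skipping blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```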

Inference Process

During inference:

  • The agent accepts a query, and the Manager model selects a high-level action based on its policy (e.g., search, summarize, or generate).
  • Once an option is chosen, the Worker model executes the corresponding low-level actions. For example, in a search operation, it leverages MCTS to refine the query, retrieves relevant web content, and processes it with the RAGSummarizer.
  • Each inference step is augmented by the agent's existing knowledge base, enabling it to produce more informed and contextually rich responses. Additionally, if Tree of Thought (ToT) is employed, the agent synthesizes a coherent and comprehensive answer based on the thought path.
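The manager/worker split can be illustrated with a toy two-level selection. In this sketch, greedy argmax over score dictionaries stands in for the ManagerModel and WorkerModel LSTMs, and the option-to-action mapping is hypothetical. The point it shows is structural: the worker may only choose among actions that belong to the option the manager picked.

```python
from typing import Dict, List, Tuple

# Hypothetical option -> low-level action mapping; the real agent's action set may differ.
OPTIONS: Dict[str, List[str]] = {
    "search": ["refine_query_with_mcts", "fetch_and_rank_results"],
    "summarize": ["rag_summarize"],
    "generate": ["tot_generate"],
}

def manager_select(option_scores: Dict[str, float]) -> str:
    """High level: greedy stand-in for the ManagerModel's option choice."""
    return max(option_scores, key=option_scores.__getitem__)

def worker_select(option: str, action_scores: Dict[str, float]) -> str:
    """Low level: best-scoring action, restricted to actions valid under the chosen option."""
    allowed = OPTIONS[option]
    return max(allowed, key=lambda a: action_scores.get(a, float("-inf")))

def hierarchical_step(option_scores: Dict[str, float],
                      action_scores: Dict[str, float]) -> Tuple[str, str]:
    option = manager_select(option_scores)
    return option, worker_select(option, action_scores)
```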

Model Saving

The agent incorporates a series of save functions to preserve the models:

  • save_worker_model and save_manager_model functions save the worker and manager models independently.
  • The save method preserves the overall state of the agent, which includes its knowledge base, replay memories, and models. This facilitates model reusability and persistent storage, enabling the agent to resume from saved states during training or deployment.

This modular setup lets the agent adjust its behavior dynamically, based on RL rewards, experience replay, and MCTS-guided decision-making. Because the agent also saves LLM training data, the collected data can be reused for further fine-tuning, for example to build language models specialized for particular domains or tasks.

The train_agent task uses Twisted to train the Autonomous Web Agent by having it interact with various queries in a simulated or real environment. The agent collects rewards based on how well it navigates and summarizes web content or performs other tasks.

Example Usage

python main_menu.py --task train_agent

Process Details

  • Training: During training, the agent will automatically sample a list of predefined queries, explore web pages, and use reinforcement learning to maximize its reward based on its actions. The training log provides insights into each episode's reward and the agent’s progress.

  • Logging: Logs are recorded to agent_training.log and provide information about each episode, such as the query, the total reward, and the episode duration. Errors are logged, and if an episode times out, a negative reward is given.
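The logging and timeout behaviour can be sketched with the standard library. The project itself uses Twisted; this sketch swaps in `concurrent.futures` purely for illustration, and the `TIMEOUT_PENALTY` value, the handling of errors with the same penalty, and the `episode_fn` callable are all assumptions.

```python
import concurrent.futures
import logging
import time
from typing import Callable

logger = logging.getLogger("agent_training")
TIMEOUT_PENALTY = -1.0  # hypothetical negative reward for a timed-out or failed episode

def configure_logging(path: str = "agent_training.log") -> None:
    logging.basicConfig(filename=path, level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(message)s")

def run_episode(episode_fn: Callable[[str], float], query: str, timeout_s: float) -> float:
    """Run one episode, log query/reward/duration, and penalize timeouts."""
    start = time.monotonic()
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(episode_fn, query)
    try:
        reward = future.result(timeout=timeout_s)
    except concurrent.futures.TimeoutError:
        logger.warning("Episode timed out: query=%r", query)
        reward = TIMEOUT_PENALTY
    except Exception:
        logger.exception("Episode failed: query=%r", query)
        reward = TIMEOUT_PENALTY  # assumption: errors are penalized like timeouts
    finally:
        pool.shutdown(wait=False)  # do not block on a hung episode
    logger.info("query=%r total_reward=%.3f duration=%.2fs",
                query, reward, time.monotonic() - start)
    return reward
```

Call `configure_logging()` once at startup to direct these records to agent_training.log. Note that a timed-out worker thread is abandoned rather than killed, which is one reason an event-driven framework such as Twisted suits this job better.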

3. Testing the Tree of Thought Search Agent (test_agent)

This task lets you test the Tree of Thought Search Agent either in an interactive mode or by specifying a single query. In interactive mode, the user can repeatedly enter queries, and the agent will process them sequentially, producing responses based on the Tree of Thought architecture.

Example Usage

Interactive Mode:

python main_menu.py --task test_agent

Single Query Mode:

python main_menu.py --task test_agent --query "What are the impacts of renewable energy on global sustainability?"

Arguments Specific to test_agent

  • --query: If provided, the agent will process this specific query and return a response. This is ideal for quick, one-off tests or evaluations. If not provided, the program will start an interactive session where you can repeatedly input queries and view the agent's response.

Interactive Mode Details

  • Input: In interactive mode, enter a query and press Enter. The agent will respond based on its training and the Tree of Thought methodology, traversing different thought paths to generate a response.

  • Exiting: To exit the interactive session, type quit and press Enter. The agent will then save any new knowledge it has gained and exit the program.
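The interactive loop reduces to a small function. In this sketch, `answer` and `save` are hypothetical stand-ins for the agent's query method and its knowledge-saving step, and `inputs` can be any iterable of lines (for example `sys.stdin` in a real session), which keeps the loop testable.

```python
from typing import Callable, Iterable, List

def interactive_session(inputs: Iterable[str], answer: Callable[[str], str],
                        save: Callable[[], None]) -> List[str]:
    """Answer queries until 'quit' (or end of input), then persist new knowledge."""
    responses: List[str] = []
    for line in inputs:
        query = line.strip()
        if query.lower() == "quit":
            break
        if query:  # ignore empty lines
            responses.append(answer(query))
    save()  # save gained knowledge before exiting
    return responses
```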

Additional Tips and Considerations

  • Adjusting Memory Constraints: The batch size and model architecture (number of layers, dimensions, etc.) affect memory usage. If you encounter memory errors, reduce the batch size or sequence length.

  • Training Time: Training the LLM and World Model may take considerable time, especially with large datasets or complex models. Use fewer epochs or a smaller dataset to speed up initial trials.

  • Model Save Path: The train_llm_world task saves the model at the end of each epoch. Ensure you have enough storage space and specify a save directory if desired.

  • Logging: Detailed logs for train_agent are saved to a file, which can help track progress, debug errors, and measure performance.

World Model with MCTS and Transformer Components

Model Overview

This model is a World Model that combines Transformers, Mixture of Experts (MoE) layers, Monte Carlo Tree Search (MCTS), and Proximal Policy Optimization (PPO) to simulate and optimize a state-based environment. Designed for complex tasks involving decision-making and action prediction, this model leverages powerful components to encode, predict, and enhance action sequences.

Key Components

  1. Transformer: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
  2. MCTS: The Monte Carlo Tree Search module iteratively simulates actions to select the best possible path based on exploration and exploitation.
  3. PPO Agent: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
  4. Custom Losses: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
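As a rough illustration of the select, expand, simulate, and backpropagate loop, here is a generic MCTS sketch using the UCT rule with exploration constant c (the World Model's argument list uses 1.414 by default). This is plain UCT with random rollouts; the World Model's MCTS instead expands nodes using policy logits from the Prediction Network, so treat this as a simplified stand-in. The toy environment in the usage note is hypothetical.

```python
import math
import random
from typing import Callable, Dict, List, Optional, Tuple

class Node:
    def __init__(self, state: Tuple, parent: Optional["Node"] = None):
        self.state = state
        self.parent = parent
        self.children: Dict[int, "Node"] = {}
        self.visits = 0
        self.value_sum = 0.0

    def uct(self, c: float) -> float:
        if self.visits == 0:
            return float("inf")  # unvisited children are tried first
        exploit = self.value_sum / self.visits
        explore = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploit + explore

def mcts(root_state: Tuple, actions: List[int], is_terminal: Callable[[Tuple], bool],
         reward: Callable[[Tuple], float], iterations: int = 5, c: float = 1.414,
         rng: Optional[random.Random] = None) -> int:
    rng = rng or random.Random(0)
    root = Node(root_state)
    for _ in range(iterations):
        node = root
        # 1. Selection: descend through fully expanded nodes by highest UCT score
        while not is_terminal(node.state) and len(node.children) == len(actions):
            node = max(node.children.values(), key=lambda n: n.uct(c))
        # 2. Expansion: add one untried action
        if not is_terminal(node.state):
            a = next(a for a in actions if a not in node.children)
            node.children[a] = Node(node.state + (a,), parent=node)
            node = node.children[a]
        # 3. Simulation: random rollout to a terminal state
        state = node.state
        while not is_terminal(state):
            state = state + (rng.choice(actions),)
        r = reward(state)
        # 4. Backpropagation: update statistics on the path back to the root
        while node is not None:
            node.visits += 1
            node.value_sum += r
            node = node.parent
    # Most-visited root action is the usual robust choice
    return max(root.children, key=lambda a: root.children[a].visits)
```

For a toy problem where only sequences starting with action 1 earn reward, `mcts((), [0, 1], lambda s: len(s) == 2, lambda s: float(s[0] == 1), iterations=50)` concentrates its visits on action 1 and returns it.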

Intended Use

This model is suitable for tasks that require complex decision-making and optimization based on action-state transitions. It can be applied in fields like game development, reinforcement learning environments, and AI simulation tasks where sequential decision-making and policy optimization are essential.

Model Architecture

The model is constructed with several primary components:

  1. Transformer: The transformer has encoder and decoder layers with rotary positional encoding and Mixture of Experts (MoE) to improve generalization and reduce computational cost by routing only parts of the data to certain experts. GELU and SwiGLU activation functions are alternated between the experts.
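The routing idea behind the MoE layers can be sketched as top-k gating: select the k experts with the largest gate logits, renormalize their softmax weights, and combine only those experts' outputs. In this sketch the gate logits are passed in directly (a real layer computes them with a learned projection), and the experts are toy elementwise functions alternating GELU and a SwiGLU-style gate, as the text describes. It illustrates the technique and is not the model's actual layer.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]

def gelu(x: float) -> float:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def silu(x: float) -> float:
    # x * sigmoid(x), written with tanh to avoid overflow for large |x|
    return x * 0.5 * (1.0 + math.tanh(0.5 * x))

def make_expert(scale: float, swiglu: bool) -> Callable[[Vector], Vector]:
    """Toy elementwise expert; real experts are feed-forward blocks with learned weights."""
    if swiglu:
        # SwiGLU-style gating: SiLU(gate projection) * value projection (toy scalar weights)
        return lambda h: [silu(scale * x) * (0.5 * x) for x in h]
    return lambda h: [gelu(scale * x) for x in h]

def moe_forward(h: Vector, gate_logits: List[float],
                experts: List[Callable[[Vector], Vector]],
                top_k: int = 2) -> Tuple[Vector, List[int]]:
    chosen = sorted(range(len(experts)), key=lambda i: gate_logits[i], reverse=True)[:top_k]
    m = max(gate_logits[i] for i in chosen)
    weights = {i: math.exp(gate_logits[i] - m) for i in chosen}  # softmax over chosen only
    z = sum(weights.values())
    out = [0.0] * len(h)
    for i in chosen:  # only selected experts run, which is where MoE saves compute
        y = experts[i](h)
        out = [o + (weights[i] / z) * yi for o, yi in zip(out, y)]
    return out, chosen

# Alternate GELU and SwiGLU experts
EXPERTS = [make_expert(1.0 + 0.5 * i, swiglu=(i % 2 == 1)) for i in range(4)]
```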

Multi-Token Prediction with Beam Search

Multi-token prediction in a language model means scoring multi-token continuations rather than greedily committing to one token at a time. This can improve the fluency and coherence of generated text by letting the model "look ahead" and compare several possible continuations at each step.

Beam Search is a popular decoding algorithm used for multi-token prediction that allows the model to explore multiple potential sequences and choose the most likely one based on the overall probability. Here's how it works:

  1. Initialization:

    • Start with a single "beam" (sequence) that contains the initial token, typically the beginning-of-sequence (<sos>) token.
  2. Expansion:

    • At each time step, the model generates a probability distribution over the vocabulary for each sequence in the beam.
    • For each sequence, it expands by predicting the next possible tokens, creating new sequences for each possible token.
  3. Scoring:

    • Calculate the score for each expanded sequence by taking the sum (or average) of log probabilities for all tokens in the sequence. Log probabilities are used to avoid underflow and ensure stable computation.
  4. Selection:

    • Keep only the top-k sequences with the highest scores, where k is the "beam width" (or "beam size"), and discard the rest. This bounds the number of sequences kept at each step to the most promising ones.
  5. Repeat:

    • Continue expanding and scoring until reaching the desired sequence length or the end-of-sequence (<eos>) token.
  6. Final Output:

    • After a set number of steps, or if all sequences end with <eos>, select the sequence with the highest score as the final output.

This process allows the model to generate more fluent and accurate sequences by considering multiple potential continuations at each step and selecting the best overall sequence.
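The six steps above can be sketched directly in Python. This minimal version scores sequences by summed log-probabilities (no length normalization) and uses a hypothetical next-token table in place of a real language model. The table is built so that greedy decoding commits to the locally likely token "a" while a beam of width 2 finds the higher-probability sequence through "b".

```python
import math
from typing import Callable, Dict, List, Tuple

Beam = Tuple[List[str], float]  # (tokens so far, summed log-probability)

def beam_search(next_log_probs: Callable[[List[str]], Dict[str, float]],
                beam_width: int, max_steps: int,
                sos: str = "<sos>", eos: str = "<eos>") -> Beam:
    beams: List[Beam] = [([sos], 0.0)]  # 1. start from a single <sos> beam
    for _ in range(max_steps):
        candidates: List[Beam] = []
        for tokens, score in beams:
            if tokens[-1] == eos:  # finished beams carry over unchanged
                candidates.append((tokens, score))
                continue
            # 2-3. expand with every next token, adding its log-probability
            for tok, lp in next_log_probs(tokens).items():
                candidates.append((tokens + [tok], score + lp))
        # 4. keep only the beam_width best-scoring sequences
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:beam_width]
        if all(t[-1] == eos for t, _ in beams):  # 5. stop once every beam has ended
            break
    return max(beams, key=lambda b: b[1])  # 6. best overall sequence

# Hypothetical next-token table keyed by the last token
TOY = {
    "<sos>": {"a": math.log(0.6), "b": math.log(0.4)},
    "a": {"c": math.log(0.4), "d": math.log(0.3), "<eos>": math.log(0.3)},
    "b": {"<eos>": math.log(0.9), "c": math.log(0.1)},
    "c": {"<eos>": 0.0},
    "d": {"<eos>": 0.0},
}

def toy_model(tokens: List[str]) -> Dict[str, float]:
    return TOY[tokens[-1]]
```

With `beam_width=2`, `beam_search(toy_model, 2, 5)` returns `<sos> b <eos>` (probability 0.36). With `beam_width=1` (greedy), it returns `<sos> a c <eos>` (probability 0.24), because it commits to "a" after the first step.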


Brief Overview of the Transformer Architecture

The Transformer architecture, introduced in the paper "Attention Is All You Need," is a powerful neural network design for handling sequential data, especially in natural language processing tasks. Transformers are known for their parallelism and their ability to capture long-range dependencies in data.

Key Components of the Transformer

  1. Embeddings and Positional Encoding:

    • The input tokens are embedded into dense vectors. Since Transformers do not inherently encode the sequence order (as opposed to RNNs), they require positional encodings. These encodings are added to the embeddings to provide information about the token positions in the sequence.
  2. Multi-Head Self-Attention:

    • Each token in a sequence attends to every other token, capturing dependencies regardless of distance. Multiple attention heads allow the model to focus on different parts of the sequence, extracting varied features.
    • In self-attention, the model computes query, key, and value vectors for each token. The output is a weighted sum of values, where the weights are determined by the similarity between the query and key vectors.
  3. Feedforward Neural Networks:

    • After self-attention, a position-wise feedforward neural network is applied to each token independently. This network consists of two linear layers with a ReLU or GELU activation function in between.
  4. Layer Normalization and Residual Connections:

    • To improve learning stability, layer normalization is applied. Residual connections help the model to learn effectively by adding the input of a layer to its output, allowing gradients to flow more easily during backpropagation.
  5. Stacking of Layers:

    • The Transformer consists of multiple encoder and decoder layers. Each encoder layer is identical and consists of self-attention and feedforward layers. The decoder layers include an additional cross-attention mechanism to attend to the encoder's output.
  6. Final Linear and Softmax Layer:

    • The final output of the decoder layer is passed through a linear layer, projecting it onto the vocabulary size. A softmax function then converts the output into a probability distribution over the vocabulary, from which the next token is selected or sampled.
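The self-attention computation described in component 2 can be written out in a few lines. This is a single-head sketch in pure Python: it omits the learned query/key/value projections, masking, and the multi-head split. Each query is compared with every key, the scores are scaled by the square root of the key dimension, passed through a softmax, and used to weight the value vectors.

```python
import math
from typing import List

Matrix = List[List[float]]

def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Single-head attention: each query row attends over all key/value rows."""
    d_k = len(K[0])
    out: Matrix = []
    for q in Q:
        # similarity between this query and every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # output is the attention-weighted sum of the value vectors
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out
```

When all keys are identical, the weights are uniform and the output is the mean of the values. A query that aligns strongly with one key pulls the output toward that key's value.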

Encoder-Decoder Structure

  • Encoder: The encoder processes the input sequence into a contextualized representation that captures relationships between tokens. It consists of multiple layers of self-attention and feedforward networks.
  • Decoder: The decoder generates the output sequence by attending to both the encoded input representation (using cross-attention) and previously generated tokens (using self-attention). The decoder's output is used to predict the next token in the sequence.
The remaining primary components of the World Model are:

  2. Representation Network: This module encodes the Transformer output into a state representation, reducing dimensionality and making it suitable for further processing.
  3. Dynamics Network: This module predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
  4. Prediction Network: Predicts both the policy logits and the value estimate for a given state. It outputs the probabilities of different actions as well as a single scalar value.
  5. MCTS: This module performs Monte Carlo Tree Search to evaluate the quality of actions over multiple iterations. It expands nodes based on the policy logits from the Prediction Network and backpropagates value estimates up the tree to score each action.
  6. PPO Agent: Uses the policy and value estimates to compute the PPO loss, which updates the policy while keeping it close to the previous policy (limiting the KL divergence between old and new policies).
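The PPO loss mentioned for the PPO Agent is most commonly the clipped surrogate objective, sketched below on per-sample log-probabilities and advantages. The clip range of 0.2 is a common default, not a value confirmed for this model. The `approx_kl` helper is the usual sample estimate of the KL divergence between the old and new policies, which can be monitored alongside, or added to, the loss.

```python
import math
from typing import List

def ppo_clip_loss(new_log_probs: List[float], old_log_probs: List[float],
                  advantages: List[float], clip_eps: float = 0.2) -> float:
    """Negated clipped surrogate objective (minimize this to maximize the surrogate)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        unclipped = ratio * adv
        clipped = max(min(ratio, 1.0 + clip_eps), 1.0 - clip_eps) * adv
        total += min(unclipped, clipped)  # pessimistic bound removes the incentive to move far
    return -total / len(advantages)

def approx_kl(new_log_probs: List[float], old_log_probs: List[float]) -> float:
    """Sample estimate of KL(old || new) for actions drawn from the old policy."""
    return sum(o - n for n, o in zip(new_log_probs, old_log_probs)) / len(new_log_probs)
```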

The Transformer uses beam search together with multi-token prediction to enrich the encoding produced by the representation network.

A generated sequence of tokens is an action. If t denotes a token, an action is:

$$a_1 = \{t_1, \ldots, t_N\}$$

A policy is then a sequence of actions:

$$P_1 = \{a_1, \ldots, a_N\}$$

MCTS and OOPS explore what we define as "thoughts", where a thought is a set of policies:

$$\text{thought}_1 = \{P_1, \ldots, P_N\}$$

The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each step of granularity.
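The token, action, policy, and thought hierarchy can be expressed as nested data types. This is an illustrative sketch of the definitions above (token IDs as integers is an assumption), not the project's internal representation.

```python
from dataclasses import dataclass
from typing import List

Token = int  # assumed: tokens are represented by integer IDs

@dataclass
class Action:
    tokens: List[Token]      # a_i = {t_1, ..., t_N}

@dataclass
class Policy:
    actions: List[Action]    # P_i = {a_1, ..., a_N}

@dataclass
class Thought:
    policies: List[Policy]   # thought_i = {P_1, ..., P_N}

    def flatten(self) -> List[Token]:
        """Recover the underlying token stream, the finest level of granularity."""
        return [t for p in self.policies for a in p.actions for t in a.tokens]
```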

Training Details

The model is trained with the following components and techniques:

Training Procedure

  • Data Loading: The data is tokenized and prepared with attention to padding and truncation. Text data is grouped into sequences of fixed length for efficient training.
  • Optimization: Training uses an AdamW optimizer with a CosineAnnealingLR scheduler for learning-rate adjustments. A gradient scaler (GradScaler) scales the loss during mixed-precision training so that small float16 gradients do not underflow to zero.
  • Gradient Accumulation: Because the model is computationally heavy, gradients are accumulated over several steps before each optimizer update. This gives a larger effective batch size without the memory cost of a larger per-step batch.
  • Loss Functions: The training process leverages a comprehensive set of custom loss functions:
    • InfoNCE Loss: A contrastive loss to encourage representation similarity between related pairs.
    • Covariance Regularization: Encourages diverse state representations by minimizing co-linearity in embeddings.
    • Dynamics Performance Loss: Combines MSE and variance losses to penalize incorrect state predictions.
    • Thought Consistency Loss: Encourages the model to output consistent states for similar actions.
    • Policy Value Joint Loss: A weighted combination of policy and value loss for the PPO agent.
    • Action Diversity Reward: Rewards diverse action embeddings to avoid mode collapse.
    • Exploration Regularization: Encourages exploration by penalizing high visitation counts.
    • KL Divergence Loss: Keeps the policy update close to the previous policy to stabilize training.
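Of the losses listed above, InfoNCE is the most self-contained to illustrate. In this sketch, an anchor embedding should be more similar (by cosine similarity) to its positive pair than to any negative. The loss is the cross-entropy of picking the positive among all candidates, with similarities divided by a temperature. The temperature value of 0.1 is an assumption; this sketch does not claim to match the model's exact formulation.

```python
import math
from typing import List, Sequence

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def info_nce(anchor: Sequence[float], positive: Sequence[float],
             negatives: List[Sequence[float]], temperature: float = 0.1) -> float:
    """-log softmax probability of the positive among [positive] + negatives."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    m = max(logits)  # log-sum-exp with max subtraction for stability
    log_sum = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum - logits[0]
```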

Evaluation

After each epoch, the model is evaluated on the validation set, computing the average loss over the dataset. The evaluation function utilizes the same loss functions as training but does not backpropagate, allowing it to be run in inference mode.

Checkpoints

At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.

Usage

To use this model, ensure you have the necessary libraries installed, including torch, transformers, and datasets (argparse is part of the Python standard library). The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here’s an example of how to start training:

To Train Language Model


python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"

To Train World Model


python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models'  --transformer_model_path 'path/to/transformer/model'

Language Model Args:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
args = parser.parse_args()

World Model Args:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS Iterations')
parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS')
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
args = parser.parse_args()

This script will train the model on the specified dataset for the defined number of epochs, using a batch size of 2, and loading a pretrained Transformer model from the specified path.

Model Hyperparameters

Here are the main parameters you can set:

  • --model_name: Name of the pretrained model for tokenization.
  • --dataset_name: Hugging Face dataset name.
  • --batch_size: Batch size for training.
  • --num_epochs: Number of epochs to train.
  • --max_length: Max sequence length.
  • --transformer_model_path: Path to the pretrained Transformer model.
  • --learning_rate: Learning rate for optimizer.
  • --save_dir: Directory to save model checkpoints.
  • --temperature, --alpha, --beta, --lambda_reg: Hyperparameters for regularization.

Expected Results

As training proceeds, the training and evaluation losses should decrease. Once trained, the model is intended to perform complex decision-making tasks by generating sequences of actions with MCTS and PPO optimization.

Requirements

This code requires:

  • Python 3.7+
  • torch>=1.7.1
  • transformers
  • datasets
  • argparse (included in the Python standard library)

Limitations

Due to the heavy computational nature of this model, training time may be significant, especially on a CPU. GPU support is recommended for efficient training. Additionally, the MCTS and PPO implementations here are designed for demonstration purposes and may need further tuning for specific use cases.

Citation

If you use this model in your research, please cite the author.


This model card provides an overview for anyone looking to understand, use, or modify the World Model with MCTS and Transformer components.