Update README.md
README.md CHANGED
@@ -206,10 +206,8 @@ This model is a **World Model** that combines **Transformers**, **Mixture of Exp
 1. **Transformer**: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
 2. **MCTS**: The Monte Carlo Tree Search module iteratively simulates actions to select the best possible path based on exploration and exploitation.
 3. **PPO Agent**: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
-4. **Custom
+4. **Custom Objective Functions**: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
 
-### Intended Use
-This model is suitable for tasks that require complex decision-making and optimization based on action-state transitions. It can be applied in fields like game development, reinforcement learning environments, and AI simulation tasks where sequential decision-making and policy optimization are essential.
 
 ## Model Architecture
 
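The Mixture of Experts layers mentioned in item 1 can be pictured with a small routing example. This is only a sketch with assumed dimensions and top-1 gating; the repository's actual MoE layer may be implemented differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEFeedForward(nn.Module):
    """Top-1 routed Mixture-of-Experts feed-forward block (illustrative only)."""

    def __init__(self, d_model: int = 512, d_ff: int = 2048, num_experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)  # router producing expert scores
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        gate_probs = F.softmax(self.gate(x), dim=-1)   # (batch, seq_len, num_experts)
        top_prob, top_idx = gate_probs.max(dim=-1)     # top-1 routing decision per token
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top_idx == i                        # tokens routed to expert i
            if mask.any():
                out[mask] = expert(x[mask]) * top_prob[mask].unsqueeze(-1)
        return out
```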
@@ -245,36 +243,7 @@ This process allows the model to generate more fluent and accurate sequences by
 
 ---
 
-
-
-The Transformer architecture, introduced in the paper "Attention is All You Need," is a powerful neural network design for handling sequential data, especially in natural language processing tasks. Transformers are known for their parallelism and ability to capture long-range dependencies in data.
-
-#### Key Components of the Transformer
-
-1. **Embeddings and Positional Encoding**:
-   - The input tokens are embedded into dense vectors. Since Transformers do not inherently encode the sequence order (as opposed to RNNs), they require **positional encodings**. These encodings are added to the embeddings to provide information about the token positions in the sequence.
-
-2. **Multi-Head Self-Attention**:
-   - Each token in a sequence attends to every other token, capturing dependencies regardless of distance. Multiple attention heads allow the model to focus on different parts of the sequence, extracting varied features.
-   - In self-attention, the model computes **query**, **key**, and **value** vectors for each token. The output is a weighted sum of values, where the weights are determined by the similarity between the query and key vectors.
-
-3. **Feedforward Neural Networks**:
-   - After self-attention, a position-wise feedforward neural network is applied to each token independently. This network consists of two linear layers with a ReLU or GELU activation function in between.
-
-4. **Layer Normalization and Residual Connections**:
-   - To improve learning stability, **layer normalization** is applied. Residual connections help the model to learn effectively by adding the input of a layer to its output, allowing gradients to flow more easily during backpropagation.
-
-5. **Stacking of Layers**:
-   - The Transformer consists of **multiple encoder and decoder layers**. Each encoder layer is identical and consists of self-attention and feedforward layers. The decoder layers include an additional cross-attention mechanism to attend to the encoder's output.
-
-6. **Final Linear and Softmax Layer**:
-   - The final output of the decoder layer is passed through a linear layer, projecting it onto the vocabulary size. A **softmax** function then converts the output into a probability distribution over the vocabulary, from which the next token is selected or sampled.
-
-#### Encoder-Decoder Structure
-
-- **Encoder**: The encoder processes the input sequence into a contextualized representation that captures relationships between tokens. It consists of multiple layers of self-attention and feedforward networks.
-- **Decoder**: The decoder generates the output sequence by attending to both the encoded input representation (using cross-attention) and previously generated tokens (using self-attention). The decoder's output is used to predict the next token in the sequence.
-
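The self-attention computation described above (queries, keys, values, and a weighted sum) reduces to a few lines. A rough sketch, assuming the inputs are already projected and split into heads; it is not the repository's exact code.

```python
import math
import torch

def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # query/key similarity
    weights = torch.softmax(scores, dim=-1)                   # attention distribution per token
    return weights @ v                                        # weighted sum of values
```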
+## World Model
 
 2. **Representation Network**: This module encodes the Transformer output to generate a state representation, reducing dimensionality and making it suitable for further processing.
 3. **Dynamics Network**: This module predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
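A minimal sketch of what such a Dynamics Network could look like, with layer normalization and GELU as described; the state and action sizes here are assumptions, not the model's actual configuration.

```python
import torch
import torch.nn as nn

class DynamicsNetwork(nn.Module):
    """Predicts the next state from (state, action); dimensions are illustrative."""

    def __init__(self, state_dim: int = 256, action_dim: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, state_dim),
            nn.LayerNorm(state_dim),   # layer normalization, as in the description
            nn.GELU(),                 # GELU activation, as in the description
            nn.Linear(state_dim, state_dim),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([state, action], dim=-1))  # predicted next state
```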
@@ -299,8 +268,6 @@ thought_1 = {P1, ... , PN}
 The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each level of granularity.
 
 
-
-
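The exploration/exploitation balance at each of these levels can be illustrated with the standard UCT selection rule used in MCTS. This is a sketch only, with `c` matching the script's default `--mcts_exploration_constant` of 1.414.

```python
import math

def uct_score(total_value: float, visits: int, parent_visits: int, c: float = 1.414) -> float:
    """Upper Confidence bound for Trees: average value plus an exploration bonus."""
    if visits == 0:
        return float("inf")                      # unvisited children are tried first
    exploitation = total_value / visits          # mean return observed for this child
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration
```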
 ## Training Details
 
 The model is trained with the following components and techniques:
@@ -325,77 +292,6 @@ After each epoch, the model is evaluated on the validation set, computing the av
 ### Checkpoints
 At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
 
-## Usage
-
-To use this model, ensure you have the necessary libraries installed, including `torch`, `transformers`, `datasets`, and `argparse`. The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here’s an example of how to start training:
-
-# To Train Language Model
-```bash
-
-python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"
-```
-
-# To Train World Model
-```bash
-
-python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models' --transformer_model_path 'path/to/transformer/model'
-```
-
-# Language Model Args:
-
-parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
-parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
-parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
-parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
-parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
-parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
-parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
-parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
-parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
-parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
-parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
-parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
-parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
-parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
-
-# World Model Args:
-
-parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
-parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
-parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
-parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
-parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
-parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
-parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS Iterations')
-parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='MCTS exploration constant')
-parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
-parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
-parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
-parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
-parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
-parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
-parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
-parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
-parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
-
-
-
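The argument definitions above are shown without their surrounding parser. A minimal sketch of how they would typically be assembled; `build_parser` is a hypothetical helper, and the actual training scripts may structure this differently.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    # Hypothetical wrapper around the World Model arguments listed above.
    parser = argparse.ArgumentParser(description='Train the World Model')
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
    parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
    # ...add the remaining arguments exactly as listed above...
    return parser

if __name__ == '__main__':
    args = build_parser().parse_args()
    print(args)
```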
-This script will train the model on the specified dataset for the defined number of epochs, using a batch size of 2, and loading a pretrained Transformer model from the specified path.
-
-### Model Hyperparameters
-Here are the main parameters you can set:
-- `--model_name`: Name of the pretrained model for tokenization.
-- `--dataset_name`: Hugging Face dataset name.
-- `--batch_size`: Batch size for training.
-- `--num_epochs`: Number of epochs to train.
-- `--max_length`: Max sequence length.
-- `--transformer_model_path`: Path to the pretrained Transformer model.
-- `--learning_rate`: Learning rate for optimizer.
-- `--save_dir`: Directory to save model checkpoints.
-- `--temperature`, `--alpha`, `--beta`, `--lambda_reg`: Hyperparameters for regularization.
-
-### Expected Results
-As training proceeds, you should see progressively lower training and evaluation losses. Upon completion, the model can perform complex decision-making tasks by generating sequences of actions with MCTS and PPO optimization.
 
 ## Requirements
 
@@ -406,14 +302,7 @@ This code requires:
 - `datasets`
 - `argparse`
 
-## Limitations
-
-Due to the heavy computational nature of this model, training time may be significant, especially on a CPU. GPU support is recommended for efficient training. Additionally, the MCTS and PPO implementations here are designed for demonstration purposes and may need further tuning for specific use cases.
 
 ## Citation
 
 If you use this model in your research, please cite the author.
-
----
-
-This model card should provide an overview for anyone looking to understand, utilize, or modify your World Model with MCTS and Transformer components.