RobbiePasquale committed (verified) · Commit 17d845e · Parent(s): 8e79db6

Update README.md

Files changed (1)
  1. README.md +2 -113
README.md CHANGED
@@ -206,10 +206,8 @@ This model is a **World Model** that combines **Transformers**, **Mixture of Exp
  1. **Transformer**: The model uses a custom Transformer with rotary positional encoding and Mixture of Experts (MoE) layers. It serves as both an encoder and decoder, enabling sequential processing of input and target data.
  2. **MCTS**: The Monte Carlo Tree Search module iteratively simulates actions to select the best possible path based on exploration and exploitation.
  3. **PPO Agent**: A Proximal Policy Optimization agent is employed to update the policy and value functions. PPO loss is combined with other regularization losses to improve model performance.
- 4. **Custom Losses**: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
+ 4. **Custom Objective Functions**: Several custom loss functions are implemented to help guide the model’s learning, including Covariance Regularization, Dynamics Performance Loss, Thought Consistency Loss, and more.
 
- ### Intended Use
- This model is suitable for tasks that require complex decision-making and optimization based on action-state transitions. It can be applied in fields like game development, reinforcement learning environments, and AI simulation tasks where sequential decision-making and policy optimization are essential.
 
  ## Model Architecture
 
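For reference, the objective that PPO agents such as the one in item 3 typically optimize is the clipped surrogate shown in the minimal sketch below. The function and variable names are illustrative, not taken from this repository, and how the value term and the custom regularization losses are actually combined here is model-specific.

```python
import torch
import torch.nn.functional as F

def ppo_loss(new_log_probs, old_log_probs, advantages, values, returns,
             clip_eps=0.2, value_coef=0.5, extra_reg=0.0):
    """Generic PPO clipped-surrogate loss plus a value term and an optional
    extra regularization scalar (e.g. covariance or consistency penalties)."""
    ratio = torch.exp(new_log_probs - old_log_probs)             # pi_new / pi_old
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()          # clipped surrogate
    value_loss = F.mse_loss(values, returns)                     # critic regression
    return policy_loss + value_coef * value_loss + extra_reg
```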
@@ -245,36 +243,7 @@ This process allows the model to generate more fluent and accurate sequences by
 
  ---
 
- ### Brief Overview of the Transformer Architecture
-
- The Transformer architecture, introduced in the paper "Attention Is All You Need," is a powerful neural network design for handling sequential data, especially in natural language processing tasks. Transformers are known for their parallelism and their ability to capture long-range dependencies in data.
-
- #### Key Components of the Transformer
-
- 1. **Embeddings and Positional Encoding**:
-    - The input tokens are embedded into dense vectors. Since Transformers do not inherently encode sequence order (unlike RNNs), they require **positional encodings**. These encodings are added to the embeddings to provide information about each token's position in the sequence.
-
- 2. **Multi-Head Self-Attention**:
-    - Each token in a sequence attends to every other token, capturing dependencies regardless of distance. Multiple attention heads allow the model to focus on different parts of the sequence, extracting varied features.
-    - In self-attention, the model computes **query**, **key**, and **value** vectors for each token. The output is a weighted sum of the values, where the weights are determined by the similarity between the query and key vectors.
-
- 3. **Feedforward Neural Networks**:
-    - After self-attention, a position-wise feedforward network is applied to each token independently. It consists of two linear layers with a ReLU or GELU activation in between.
-
- 4. **Layer Normalization and Residual Connections**:
-    - To improve training stability, **layer normalization** is applied. Residual connections add the input of a layer to its output, allowing gradients to flow more easily during backpropagation.
-
- 5. **Stacking of Layers**:
-    - The Transformer consists of **multiple encoder and decoder layers**. Each encoder layer is identical, consisting of self-attention and feedforward sublayers. The decoder layers include an additional cross-attention mechanism that attends to the encoder's output.
-
- 6. **Final Linear and Softmax Layer**:
-    - The final decoder output is passed through a linear layer that projects it onto the vocabulary size. A **softmax** then converts this into a probability distribution over the vocabulary, from which the next token is selected or sampled.
-
- #### Encoder-Decoder Structure
-
- - **Encoder**: The encoder processes the input sequence into a contextualized representation that captures relationships between tokens. It consists of multiple layers of self-attention and feedforward networks.
- - **Decoder**: The decoder generates the output sequence by attending to both the encoded input representation (via cross-attention) and previously generated tokens (via self-attention). The decoder's output is used to predict the next token in the sequence.
-
+ ## World Model
 
  2. **Representation Network**: This module encodes the Transformer output to generate a state representation, reducing dimensionality and making it suitable for further processing.
  3. **Dynamics Network**: This module predicts the next state given a current state and an action. It uses layer normalization and a GELU activation function.
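As a concrete companion to the query/key/value description in the removed overview above, scaled dot-product attention (the core operation inside multi-head self-attention) can be sketched as follows. This is a generic illustration, not code from this repository.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: tensors of shape (batch, heads, seq_len, head_dim).
    Returns a weighted sum of the values, with weights given by the
    softmax of the query-key similarity scores."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v
```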
@@ -299,8 +268,6 @@ thought_1 = {P1, ... , PN}
  The model explores and exploits thoughts, policies, actions, and tokens, and learning happens at each step of granularity.
 
 
-
-
  ## Training Details
 
  The model is trained with the following components and techniques:
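The exploration/exploitation trade-off that drives the MCTS module is commonly implemented with a UCT-style selection rule; the snippet below is a generic sketch (names and structure are illustrative, not this repository's implementation), using the same exploration constant of 1.414 that appears in the world-model arguments.

```python
import math

def uct_score(child_value_sum, child_visits, parent_visits, c=1.414):
    """UCT: average value (exploitation) plus an exploration bonus that
    shrinks as a child node is visited more often."""
    if child_visits == 0:
        return float("inf")  # always try unvisited children first
    exploitation = child_value_sum / child_visits
    exploration = c * math.sqrt(math.log(parent_visits) / child_visits)
    return exploitation + exploration
```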
@@ -325,77 +292,6 @@ After each epoch, the model is evaluated on the validation set, computing the av
  ### Checkpoints
  At the end of each epoch, the model saves checkpoints of all components, enabling easy resumption or further fine-tuning as needed.
 
- ## Usage
-
- To use this model, ensure you have the necessary libraries installed, including `torch`, `transformers`, `datasets`, and `argparse`. The model can be initialized with pre-trained weights for the Transformer, and custom paths for saving checkpoints can be specified. Here’s an example of how to start training:
-
- # To Train Language Model
- ```bash
-
- python your_script.py --model_name "gpt2" --dataset_name "wikitext" --dataset_config "wikitext-2-raw-v1" --batch_size 2 --num_epochs 3 --transformer_model_path "path/to/transformer/model"
- ```
-
- # To Train World Model
- ```bash
-
- python lightbulb_WM.py --model_name 'gpt2' --dataset_name 'wikitext' --dataset_config 'wikitext-2-raw-v1' --batch_size 2 --num_epochs 3 --max_length 128 --learning_rate 1e-4 --save_dir './models' --transformer_model_path 'path/to/transformer/model'
- ```
-
- # Language Model Args:
-
- parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
- parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
- parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
- parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
- parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
- parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
- parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
- parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
- parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
- parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
- parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
- parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
- parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
- parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
-
- # World Model Args:
-
- parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
- parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
- parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
- parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
- parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
- parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
- parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS iterations')
- parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='MCTS exploration constant')
- parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
- parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
- parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
- parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
- parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
- parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
- parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
- parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
- parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
-
-
-
- This script will train the model on the specified dataset for the defined number of epochs, using a batch size of 2 and loading a pretrained Transformer model from the specified path.
-
- ### Model Hyperparameters
- Here are the main parameters you can set:
- - `--model_name`: Name of the pretrained model for tokenization.
- - `--dataset_name`: Hugging Face dataset name.
- - `--batch_size`: Batch size for training.
- - `--num_epochs`: Number of epochs to train.
- - `--max_length`: Max sequence length.
- - `--transformer_model_path`: Path to the pretrained Transformer model.
- - `--learning_rate`: Learning rate for the optimizer.
- - `--save_dir`: Directory to save model checkpoints.
- - `--temperature`, `--alpha`, `--beta`, `--lambda_reg`: Hyperparameters for regularization.
-
- ### Expected Results
- As training proceeds, you should see progressively lower training and evaluation losses. Upon completion, the model can perform complex decision-making tasks by generating sequences of actions with MCTS and PPO optimization.
 
  ## Requirements
 
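The `parser.add_argument(...)` listings removed above assume a surrounding `argparse` setup that the snippet does not show. A minimal, self-contained sketch of how the world-model arguments would typically be wired together (argument names copied from the section above, everything else illustrative) is:

```python
import argparse

def parse_world_model_args():
    # Mirrors a subset of the world-model arguments listed in the removed Usage section.
    parser = argparse.ArgumentParser(description='Train the World Model')
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
    parser.add_argument('--mcts_iterations', type=int, default=5, help='Number of MCTS iterations')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='MCTS exploration constant')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
    parser.add_argument('--transformer_model_path', type=str, required=True, help='Path to the saved Transformer model')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_world_model_args()
    print(args)
```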
@@ -406,14 +302,7 @@ This code requires:
  - `datasets`
  - `argparse`
 
- ## Limitations
-
- Due to the heavy computational nature of this model, training time may be significant, especially on a CPU. GPU support is recommended for efficient training. Additionally, the MCTS and PPO implementations here are designed for demonstration purposes and may need further tuning for specific use cases.
 
  ## Citation
 
  If you use this model in your research, please cite the author.
-
- ---
-
- This model card should provide an overview for anyone looking to understand, utilize, or modify your World Model with MCTS and Transformer components.
 