RobbiePasquale committed
Commit e1392d6 · verified · 1 Parent(s): adebc30

Upload 20 files

ToTSearch.py ADDED
@@ -0,0 +1,219 @@
+ # ToTSearch.py
+ import random
+ from typing import List, Dict, Any, Generator
+ from sentence_transformers import SentenceTransformer, util
+ import torch
+ import numpy as np
+ from twisted.internet import defer
+ from agent import AutonomousWebAgent
+ from mcts import MCTS, MCTSNode
+ import logging
+ from twisted.internet.defer import Deferred
+
+ logger = logging.getLogger(__name__)
+
+ class ToTNode:
+     def __init__(self, thought, parent=None):
+         self.thought = thought
+         self.parent = parent
+         self.children = []
+         self.visits = 0
+         self.value = 0
+         self.search_results = []
+         self.mcts_node = None
+
+     def add_child(self, child_thought):
+         child = ToTNode(child_thought, self)
+         self.children.append(child)
+         return child
+
+     def update(self, reward):
+         self.visits += 1
+         self.value += reward
+
+ class ToTSearch:
+     def __init__(self, agent: AutonomousWebAgent, model='all-MiniLM-L6-v2', max_depth=3, num_thoughts=3, num_simulations=100):
+         self.agent = agent
+         self.model = SentenceTransformer(model)
+         self.max_depth = max_depth
+         self.num_thoughts = num_thoughts
+         self.num_simulations = num_simulations
+         self.mcts = MCTS(initial_state="", num_simulations=num_simulations)
+
+     def generate_thoughts(self, query: str) -> List[str]:
+         prompt = f"""Given the query "{query}", generate {self.num_thoughts} distinct thoughts or approaches to address it.
+         Each thought should be a complete sentence and offer a unique perspective or solution path."""
+         thoughts = self.agent.generate_text(prompt).split('\n')
+         return [thought.strip() for thought in thoughts if thought.strip()]
+
+     def expand_thought(self, thought: str) -> List[str]:
+         prompt = f"""Expand on the following thought: "{thought}"
+         Generate {self.num_thoughts} more specific sub-thoughts or considerations.
+         Each sub-thought should be a complete sentence and offer additional detail or a new angle."""
+         expansions = self.agent.generate_text(prompt).split('\n')
+         return [exp.strip() for exp in expansions if exp.strip()]
+
+     def evaluate_thought(self, thought: str, query: str) -> float:
+         thought_embedding = self.model.encode(thought)
+         query_embedding = self.model.encode(query)
+         return util.pytorch_cos_sim(thought_embedding, query_embedding).item()
+
+     @defer.inlineCallbacks
+     def search_and_augment(self, thought: str) -> Generator[Deferred, Any, List[Dict[str, Any]]]:
+         search_results = yield self.agent.retrieve_from_web(thought)
+         for result in search_results:
+             result['originating_thought'] = thought
+         defer.returnValue(search_results)
+
+     def select(self, node: ToTNode) -> ToTNode:
+         while node.children:
+             # Choose a node with zero visits or select based on the value/visits ratio
+             if any(child.visits == 0 for child in node.children):
+                 zero_visit_nodes = [child for child in node.children if child.visits == 0]
+                 selected_node = random.choice(zero_visit_nodes)
+                 logger.debug(f"Selected node with 0 visits: {selected_node.thought}")
+                 return selected_node
+             else:
+                 selected_node = max(node.children, key=lambda child: (child.value / child.visits) if child.visits > 0 else float('-inf'))
+                 logger.debug(f"Selected node based on value/visits ratio: {selected_node.thought}, value: {selected_node.value}, visits: {selected_node.visits}")
+                 return selected_node
+         return node
+
+     def expand(self, node: ToTNode, query: str) -> ToTNode:
+         if not node.children and len(node.thought.split()) > 2:
+             expansions = self.expand_thought(node.thought)
+             for expansion in expansions:
+                 node.add_child(expansion)
+         return random.choice(node.children) if node.children else node
+
+     @defer.inlineCallbacks
+     def simulate(self, node: ToTNode, query: str):
+         current_node = node
+         depth = 0
+         while depth < self.max_depth:
+             if not current_node.children:
+                 break
+             current_node = random.choice(current_node.children)
+             depth += 1
+
+         logger.debug(f"Simulating for thought: {current_node.thought}")
+
+         search_results = yield self.search_and_augment(current_node.thought)
+         current_node.search_results = search_results
+
+         logger.debug(f"Search results count: {len(search_results)}")
+
+         ranked_results = self.agent.calculate_reward(current_node.thought, query)
+         logger.debug(f"Ranked results: {ranked_results}")
+
+         mcts_node = MCTSNode(current_node.thought)
+         current_node.mcts_node = mcts_node
+         mcts_total_reward = 0
+
+         for _ in range(self.num_simulations):
+             mcts_reward = yield self.mcts.simulate(mcts_node)
+             mcts_total_reward += mcts_reward
+             self.mcts.backpropagate(mcts_node, mcts_reward)
+
+         logger.debug(f"MCTS node visits: {mcts_node.visits}, total reward: {mcts_total_reward}")
+
+         if mcts_node.visits > 0:
+             mcts_value = mcts_total_reward / mcts_node.visits
+             logger.debug(f"MCTS value: {mcts_value}")
+         else:
+             # Avoid division by zero when the MCTS node was never visited
+             mcts_value = 0
+             logger.warning("MCTS node has 0 visits, assigning value 0")
+
+         combined_reward = (ranked_results + mcts_value) / 2
+         logger.debug(f"Combined reward: {combined_reward}")
+
+         defer.returnValue(combined_reward)
+
+     def backpropagate(self, node: ToTNode, reward: float):
+         while node:
+             node.update(reward)
+             node = node.parent
+
+     @defer.inlineCallbacks
+     def tot_search(self, query: str) -> Generator[Deferred, Any, ToTNode]:
+         root = ToTNode(query)
+         for _ in range(self.num_simulations):
+             node = self.select(root)
+             node = self.expand(node, query)
+             reward = yield self.simulate(node, query)
+             self.backpropagate(node, reward)
+
+             # Update agent's experience replay
+             state = self.agent.extract_features(node.thought, query)
+             next_state = self.agent.extract_features(node.children[0].thought if node.children else node.thought, query)
+             self.agent.remember_worker(state, 0, reward, next_state, False)
+
+             # Perform agent's replay to update RL models
+             self.agent.replay_worker()
+             self.agent.replay_manager()
+
+         defer.returnValue(root)
+
+     def get_best_path(self, root: ToTNode) -> List[str]:
+         path = [root.thought]
+         current = root
+         while current.children:
+             current = max(current.children, key=lambda child: child.value / child.visits if child.visits > 0 else float('-inf'))
+             path.append(current.thought)
+         return path
+
+     @defer.inlineCallbacks
+     def synthesize_results(self, root: ToTNode, query: str) -> Generator[Deferred, Any, str]:
+         best_path = self.get_best_path(root)
+         all_results = []
+
+         def collect_results(node):
+             all_results.extend(node.search_results)
+             for child in node.children:
+                 collect_results(child)
+
+         collect_results(root)
+
+         # Sort results by relevance
+         all_results.sort(key=lambda x: self.evaluate_thought(x['content'], query), reverse=True)
+
+         # Generate a summary of the top results
+         top_results = all_results[:5]  # Adjust the number as needed
+         summary_prompt = f"Synthesize the following information into a coherent answer for the query '{query}':\n\n"
+         summary_prompt += f"Thought path: {' -> '.join(best_path)}\n\n"
+         for result in top_results:
+             summary_prompt += f"- {result['content'][:200]}...\n"
+
+         # Use the agent's RAG capabilities for final answer generation
+         final_answer = yield self.agent.generate_rag_response(query, top_results)
+
+         # Save the generated answer and thought path to the agent's knowledge base
+         self.agent.add_document_to_kb(
+             title=f"ToT Search Result: {query}",
+             content=final_answer,
+             metadata={"thought_path": best_path}
+         )
+
+         defer.returnValue(final_answer)
+
+     @defer.inlineCallbacks
+     def search(self, query: str) -> Generator[Deferred, Any, str]:
+         logger.info(f"Starting ToT search for query: {query}")
+         root = yield self.tot_search(query)
+         final_answer = yield self.synthesize_results(root, query)
+         logger.info(f"ToT search completed for query: {query}")
+         defer.returnValue(final_answer)
+
+ # Usage example:
+ # tot_search = ToTSearch(agent)
+ # final_answer = yield tot_search.search("What are the latest advancements in renewable energy?")
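For context, a minimal way to drive `ToTSearch.search` from a Twisted reactor might look like the sketch below. The `AutonomousWebAgent` constructor arguments shown (`state_size=7`, `action_size=3`, `num_options=3`) are assumptions based on the seven features returned by `extract_features` and the three Manager options in agent.py, not values fixed by this commit:

    # usage_sketch.py -- illustrative only, not part of the upload
    from twisted.internet import reactor, defer
    from agent import AutonomousWebAgent
    from ToTSearch import ToTSearch

    @defer.inlineCallbacks
    def main():
        # state_size matches the 7-element feature vector from extract_features;
        # action_size/num_options match the 3 Manager options (assumed values)
        agent = AutonomousWebAgent(state_size=7, action_size=3, num_options=3)
        tot = ToTSearch(agent)
        answer = yield tot.search("What are the latest advancements in renewable energy?")
        print(answer)
        reactor.stop()

    reactor.callWhenRunning(main)
    reactor.run()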
agent.py ADDED
@@ -0,0 +1,1082 @@
+ # agent.py
+ import numpy as np
+ from mcts import MCTS
+ from ranking import train_ranking_model
+ from bs4 import BeautifulSoup
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from collections import deque, OrderedDict
+ import random
+ from sklearn.metrics.pairwise import cosine_similarity
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ from sentence_transformers import SentenceTransformer
+ import hashlib
+ from twisted.internet import defer
+ import logging
+ import json
+ import os
+ from urllib.parse import urlparse
+
+ logger = logging.getLogger(__name__)
+
+ # ==========================
+ # Prioritized Experience Replay
+ # ==========================
+
+ class SumTree:
+     """
+     SumTree data structure where the parent's value is the sum of its children.
+     Leaf nodes contain the priorities of experiences.
+     """
+     def __init__(self, capacity):
+         self.capacity = capacity
+         self.tree = np.zeros(2 * capacity - 1)
+         self.data = np.zeros(capacity, dtype=object)
+         self.write = 0
+         self.n_entries = 0
+
+     def _propagate(self, idx, change):
+         parent = (idx - 1) // 2
+         self.tree[parent] += change
+         if parent != 0:
+             self._propagate(parent, change)
+
+     def _retrieve(self, idx, s):
+         left = 2 * idx + 1
+         right = left + 1
+
+         if left >= len(self.tree):
+             return idx
+
+         if s <= self.tree[left]:
+             return self._retrieve(left, s)
+         else:
+             return self._retrieve(right, s - self.tree[left])
+
+     def total(self):
+         return self.tree[0]
+
+     def add(self, p, data):
+         idx = self.write + self.capacity - 1
+
+         self.data[self.write] = data
+         self.update(idx, p)
+
+         self.write += 1
+         if self.write >= self.capacity:
+             self.write = 0
+
+         if self.n_entries < self.capacity:
+             self.n_entries += 1
+
+     def update(self, idx, p):
+         change = p - self.tree[idx]
+         self.tree[idx] = p
+         self._propagate(idx, change)
+
+     def get(self, s):
+         idx = self._retrieve(0, s)
+         data_idx = idx - self.capacity + 1
+
+         return (idx, self.tree[idx], self.data[data_idx])
+
+ class PrioritizedReplayMemory:
+     def __init__(self, capacity, alpha=0.6):
+         self.tree = SumTree(capacity)
+         self.alpha = alpha  # [0,1] converts the magnitude of the TD error into a priority
+         self.epsilon = 1e-6  # small amount to avoid zero priority
+
+     def add(self, error, sample):
+         p = (np.abs(error) + self.epsilon) ** self.alpha
+         self.tree.add(p, sample)
+
+     def sample(self, batch_size, beta=0.4):
+         batch = []
+         idxs = []
+         segment = self.tree.total() / batch_size
+         priorities = []
+
+         for i in range(batch_size):
+             a = segment * i
+             b = segment * (i + 1)
+             s = random.uniform(a, b)
+             idx, p, data = self.tree.get(s)
+             batch.append(data)
+             idxs.append(idx)
+             priorities.append(p)
+
+         total = self.tree.total()
+         probs = np.array(priorities) / total  # element-wise; a plain Python list would not support this division
+         weights = (self.tree.n_entries * probs) ** (-beta)
+         weights /= weights.max()
+         return batch, idxs, weights
+
+     def update(self, idx, error):
+         p = (np.abs(error) + self.epsilon) ** self.alpha
+         self.tree.update(idx, p)
+
+ # ==========================
+ # Hierarchical Reinforcement Learning (HRL)
+ # ==========================
+
+ class ManagerModel(nn.Module):
+     """
+     High-level policy model (Manager) that decides which option to execute.
+     """
+     def __init__(self, input_size, hidden_size, num_options):
+         super(ManagerModel, self).__init__()
+         self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
+         self.fc = nn.Linear(hidden_size, num_options)
+         self.layer_norm = nn.LayerNorm(hidden_size)
+
+     def forward(self, x, hidden=None):
+         if x.dim() == 2:
+             x = x.unsqueeze(1)  # Add a time dimension
+         out, hidden = self.lstm(x, hidden)
+         last_output = out[:, -1, :]
+         last_output = self.layer_norm(last_output)
+         option_scores = self.fc(last_output)
+         return option_scores, hidden
+
+ class WorkerModel(nn.Module):
+     """
+     Low-level policy model (Worker) that executes actions based on the selected option.
+     """
+     def __init__(self, input_size, hidden_size, action_size):
+         super(WorkerModel, self).__init__()
+         self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
+         self.fc = nn.Linear(hidden_size, action_size)
+         self.layer_norm = nn.LayerNorm(hidden_size)
+         self.action_size = action_size  # Store action_size for reference
+
+     def forward(self, x, hidden=None):
+         if x.dim() == 2:
+             x = x.unsqueeze(1)  # Add a time dimension
+         out, hidden = self.lstm(x, hidden)
+         last_output = out[:, -1, :]
+         last_output = self.layer_norm(last_output)
+         action_scores = self.fc(last_output)
+         return action_scores, hidden
+
+     def act(self, state, epsilon=0.1):
+         """
+         Selects an action using an epsilon-greedy policy.
+         """
+         if random.random() < epsilon:
+             action = random.randint(0, self.action_size - 1)
+             return action
+         state = torch.FloatTensor(state).unsqueeze(0).to(next(self.parameters()).device)
+         with torch.no_grad():
+             action_scores, _ = self(state)
+         action = torch.argmax(action_scores, dim=1).item()
+         return action
+
+ # ==========================
+ # RAGSummarizer Class
+ # ==========================
+
+ class RAGSummarizer:
+     def __init__(self, model_name='gpt2', embedding_model='all-MiniLM-L6-v2',
+                  max_length=150, cache_capacity=100, persistent_cache_path='rag_cache.json'):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+         self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
+         # Explicitly set the device for SentenceTransformer
+         self.embedding_model = SentenceTransformer(embedding_model, device=self.device)
+         self.max_length = max_length
+         self.cache = LRUCache(cache_capacity)
+         self.persistent_cache_path = persistent_cache_path
+         self.load_persistent_cache()
+
+     def load_persistent_cache(self):
+         if os.path.exists(self.persistent_cache_path):
+             with open(self.persistent_cache_path, 'r', encoding='utf-8') as f:
+                 try:
+                     persistent_data = json.load(f)
+                     for key, value in persistent_data.items():
+                         self.cache.put(key, value)
+                     logger.info(f"Loaded persistent cache with {len(persistent_data)} entries.")
+                 except json.JSONDecodeError:
+                     logger.warning("Persistent cache file is corrupted. Initializing empty cache.")
+         else:
+             logger.info("No persistent cache found. Starting with empty cache.")
+
+     def save_persistent_cache(self):
+         with open(self.persistent_cache_path, 'w', encoding='utf-8') as f:
+             json.dump(self.cache.cache, f, indent=2)
+         logger.info(f"Saved persistent cache with {len(self.cache.cache)} entries.")
+
+     def save_rag_data(self, query, chunks, embeddings):
+         data = {
+             "query": query,
+             "chunks": chunks,
+             "embeddings": embeddings.tolist()
+         }
+
+         os.makedirs("rag_data", exist_ok=True)
+
+         filename = f"rag_data/{hash(query)}.json"
+         with open(filename, 'w') as f:
+             json.dump(data, f, indent=2)
+
+         logger.info(f"Saved RAG data to {filename}")
+
+     def split_into_chunks(self, text, chunk_size=200):
+         words = text.split()
+         return [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
+
+     def retrieve_relevant_chunks(self, query, chunks, embeddings, top_k=3):
+         if embeddings.size(0) == 0:
+             logger.warning("Embeddings are empty. Cannot retrieve relevant chunks.")
+             return []
+         query_embedding = self.embedding_model.encode([query], convert_to_tensor=True)
+         cosine_scores = cosine_similarity(query_embedding.cpu().numpy(), embeddings.cpu().numpy())[0]
+         top_indices = cosine_scores.argsort()[-top_k:][::-1]
+         # Ensure indices are within bounds
+         top_indices = [idx for idx in top_indices if idx < len(chunks)]
+         return [chunks[i] for i in top_indices]
+
+     def get_embeddings(self, chunks):
+         # Encode in batches to bound memory use
+         batch_size = 32
+         embeddings = []
+         for i in range(0, len(chunks), batch_size):
+             batch = chunks[i:i+batch_size]
+             batch_embeddings = self.embedding_model.encode(batch, convert_to_tensor=True)
+             embeddings.append(batch_embeddings)
+         if embeddings:
+             return torch.cat(embeddings, dim=0)
+         else:
+             return torch.tensor([])
+
+     def generate_summary(self, query, relevant_chunks):
+         cache_key = hashlib.md5((query + ''.join(relevant_chunks)).encode()).hexdigest()
+         cached_summary = self.cache.get(cache_key)
+         if cached_summary:
+             return cached_summary
+
+         context = " ".join(relevant_chunks)
+         prompt = f"Summarize the following content in relation to '{query}': {context}\n\nSummary:"
+
+         input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
+
+         try:
+             output = self.model.generate(
+                 input_ids,
+                 max_length=input_ids.shape[1] + self.max_length,
+                 num_return_sequences=1,
+                 no_repeat_ngram_size=2,
+                 top_k=50,
+                 top_p=0.95,
+                 temperature=0.7,
+                 early_stopping=True
+             )
+         except Exception as e:
+             logger.error(f"Error during summary generation: {str(e)}")
+             return "Summary generation failed."
+
+         self.save_rag_data(query, relevant_chunks, self.get_embeddings(relevant_chunks))
+
+         summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
+         summary = summary.split("Summary:")[-1].strip()
+
+         self.cache.put(cache_key, summary)
+         self.save_persistent_cache()
+
+         return summary
+
+ # ==========================
+ # WorldModel Class
+ # ==========================
+
+ class WorldModel(nn.Module):
+     def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.3):
+         super(WorldModel, self).__init__()
+         self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
+                             batch_first=True, dropout=dropout)
+         self.fc = nn.Linear(hidden_size, output_size)
+         self.value_head = nn.Linear(hidden_size, 1)
+         self.layer_norm = nn.LayerNorm(hidden_size)
+
+     def forward(self, x, hidden=None):
+         if x.dim() == 2:
+             x = x.unsqueeze(1)  # Add a time dimension
+         out, hidden = self.lstm(x, hidden)
+         last_output = out[:, -1, :]
+         last_output = self.layer_norm(last_output)
+         action_scores = self.fc(last_output)
+         state_value = self.value_head(last_output)
+         return action_scores, state_value, hidden
+
+ # ==========================
+ # Manager and Worker Classes for HRL
+ # ==========================
+
+ class Manager:
+     def __init__(self, state_size, num_options, hidden_size=128, learning_rate=0.001, gamma=0.99,
+                  epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01, memory_capacity=1000, device=torch.device("cpu")):
+         self.state_size = state_size
+         self.num_options = num_options
+         self.gamma = gamma
+         self.epsilon = epsilon
+         self.epsilon_decay = epsilon_decay
+         self.epsilon_min = epsilon_min
+         self.device = device
+
+         self.model = ManagerModel(state_size, hidden_size, num_options).to(self.device)
+         self.target_model = ManagerModel(state_size, hidden_size, num_options).to(self.device)
+         self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
+         self.loss_fn = nn.MSELoss()
+         self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=5, factor=0.5, verbose=True)
+
+         self.memory = PrioritizedReplayMemory(capacity=memory_capacity, alpha=0.6)
+
+         self.update_target_model()
+
+     def update_target_model(self):
+         self.target_model.load_state_dict(self.model.state_dict())
+
+     def remember(self, state, option, reward, next_state, done, td_error):
+         sample = (state, option, reward, next_state, done)
+         self.memory.add(td_error, sample)
+
+     def act(self, state):
+         if random.random() < self.epsilon:
+             option = random.randint(0, self.num_options - 1)
+             return option
+         # nn.LSTM has no `.weight` attribute; query the device via the model's parameters
+         state = torch.FloatTensor(state).unsqueeze(0).to(next(self.model.parameters()).device)
+         with torch.no_grad():
+             option_scores, _ = self.model(state)
+         option = torch.argmax(option_scores).item()
+         return option
+
+     def replay(self, batch_size, beta=0.4):
+         if self.memory.tree.n_entries < batch_size:
+             return
+         batch, idxs, weights = self.memory.sample(batch_size, beta)
+         states, options, rewards, next_states, dones = zip(*batch)
+
+         device = next(self.model.parameters()).device
+         states = torch.FloatTensor(np.array(states)).to(device)
+         next_states = torch.FloatTensor(np.array(next_states)).to(device)
+         options = torch.LongTensor(options).unsqueeze(1).to(device)
+         rewards = torch.FloatTensor(rewards).unsqueeze(1).to(device)
+         dones = torch.FloatTensor(dones).unsqueeze(1).to(device)
+         weights = torch.FloatTensor(weights).unsqueeze(1).to(device)
+
+         # Current Q values
+         current_q_values, _ = self.model(states)
+         current_q_values = current_q_values.gather(1, options)
+
+         # Target Q values
+         with torch.no_grad():
+             next_q_values, _ = self.target_model(next_states)
+             max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
+             target_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))
+
+         # Compute TD errors
+         td_errors = target_q_values - current_q_values
+
+         # Compute loss with importance-sampling weights
+         loss = (td_errors.pow(2) * weights).mean()
+
+         # Optimize the model
+         self.optimizer.zero_grad()
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+         self.optimizer.step()
+         self.scheduler.step(loss.item())
+
+         # Update priorities
+         td_errors_np = td_errors.detach().cpu().numpy().squeeze()
+         for idx, td_error in zip(idxs, td_errors_np):
+             self.memory.update(idx, np.abs(td_error))
+
+         # Decay epsilon
+         if self.epsilon > self.epsilon_min:
+             self.epsilon *= self.epsilon_decay
+
+ # ==========================
+ # AutonomousWebAgent Class
+ # ==========================
+
+ def truncate_text(text, max_length=1024):
+     tokens = text.split()
+     if len(tokens) > max_length:
+         return ' '.join(tokens[:max_length])
+     return text
+
+ class AutonomousWebAgent:
+     def __init__(self, state_size, action_size, num_options, hidden_size=64, learning_rate=0.001,
+                  gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
+                  knowledge_base_path='knowledge_base.json'):
+         self.state_size = state_size
+         self.action_size = action_size
+         self.num_options = num_options  # Number of high-level options for HRL
+         self.gamma = gamma
+         self.epsilon = epsilon
+         self.epsilon_decay = epsilon_decay
+         self.epsilon_min = epsilon_min
+
+         # Initialize RAGSummarizer first to get the device
+         self.summarizer = RAGSummarizer()
+         self.device = self.summarizer.device
+
+         # Initialize SentenceTransformer with the correct device
+         self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
+
+         # Low-level (Worker) Model
+         self.worker_model = WorldModel(state_size, hidden_size, action_size).to(self.device)
+         self.worker_target_model = WorldModel(state_size, hidden_size, action_size).to(self.device)
+         self.worker_optimizer = optim.AdamW(self.worker_model.parameters(), lr=learning_rate, weight_decay=1e-5)
+         self.worker_loss_fn = nn.MSELoss()
+         self.worker_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.worker_optimizer, 'min', patience=5, factor=0.5, verbose=True)
+         self.worker_memory = PrioritizedReplayMemory(capacity=2000, alpha=0.6)
+         self.update_worker_target_model()
+
+         # High-level (Manager) Model
+         self.manager = Manager(state_size, num_options, hidden_size=128, learning_rate=learning_rate,
+                                gamma=gamma, epsilon=epsilon, epsilon_decay=epsilon_decay,
+                                epsilon_min=epsilon_min, memory_capacity=1000, device=self.device)
+
+         self.mcts = MCTS(initial_state="")
+         logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")
+
+         self.site_performance = {}  # {(site, query): performance_score}
+
+         # List of all search sites (base URLs without the query)
+         self.all_search_sites = [
+             "https://en.wikibooks.org/w/index.php?search=",
+             "https://en.wikiversity.org/w/index.php?search=",
+             "https://commons.wikimedia.org/w/index.php?search=",
+             "https://stackexchange.com/search?q=",
+             "https://arxiv.org/search/?query=",
+             "https://www.ncbi.nlm.nih.gov/pmc/?term=",
+             "https://www.gutenberg.org/ebooks/search/?query=",
+             "https://openlibrary.org/search?q=",
+             "https://doaj.org/search/articles?ref=homepage&q=",
+             "https://www.ted.com/search?q=",
+             "https://en.citizendium.org/wiki?search=",
+             "https://www.jstor.org/action/doBasicSearch?Query=",
+             "https://archive.org/search.php?query=",
+             "https://search.scielo.org/?q=",
+             "https://paperswithcode.com/search?q=",
+             "https://www.reddit.com/search/?q=",
+             "https://huggingface.co/models?search=",
+             "https://huggingface.co/datasets?search=",
+             "https://machinelearningmastery.com/?s=",
+             "https://www.kaggle.com/search?q=",
+             "https://towardsdatascience.com/search?q=",
+             "https://github.com/search?q=",
+             "https://stackoverflow.com/search?q=",
+             "https://www.youtube.com/results?search_query=",
+             "https://www.slideshare.net/search/slideshow?searchfrom=header&q="
+         ]
+
+         # Initialize Knowledge Base
+         self.knowledge_base_path = knowledge_base_path
+         self.knowledge_base = []
+         self.kb_embeddings = None
+         self.load_knowledge_base()
+
+         # Additional Features for State Representation
+         self.additional_features = ['image_count', 'script_count', 'css_count']
+
+     def save(self, filename):
+         """Save the entire agent state."""
+         state = {
+             'worker_model': self.worker_model.state_dict(),
+             'manager_model': self.manager.model.state_dict(),
+             'worker_optimizer': self.worker_optimizer.state_dict(),
+             'manager_optimizer': self.manager.optimizer.state_dict(),
+             'epsilon': self.epsilon
+         }
+         torch.save(state, filename)
+         logger.info(f"Saved agent state to {filename}")
+
+     def load(self, filename):
+         """Load the entire agent state."""
+         state = torch.load(filename, map_location=self.device)
+         self.worker_model.load_state_dict(state['worker_model'])
+         self.manager.model.load_state_dict(state['manager_model'])
+         self.worker_optimizer.load_state_dict(state['worker_optimizer'])
+         self.manager.optimizer.load_state_dict(state['manager_optimizer'])
+         self.epsilon = state['epsilon']
+         logger.info(f"Loaded agent state from {filename}")
+
+     # ==========================
+     # Text Generation
+     # ==========================
+
+     def generate_text(self, prompt):
+         # Use the RAGSummarizer to generate text
+         chunks = self.summarizer.split_into_chunks(prompt)
+         embeddings = self.summarizer.get_embeddings(chunks)
+         relevant_chunks = self.summarizer.retrieve_relevant_chunks(query=prompt, chunks=chunks, embeddings=embeddings)
+         generated_text = self.summarizer.generate_summary(prompt, relevant_chunks)
+         return generated_text
+
+     # ==========================
+     # Knowledge Base Management
+     # ==========================
+
+     def load_knowledge_base(self):
+         if not os.path.exists(self.knowledge_base_path):
+             logger.warning(f"Knowledge base file {self.knowledge_base_path} does not exist. Initializing empty KB.")
+             self.knowledge_base = []
+             self.kb_embeddings = torch.tensor([]).to(self.device)
+             return
+
+         with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
+             self.knowledge_base = json.load(f)
+
+         if self.knowledge_base:
+             texts = [doc['content'] for doc in self.knowledge_base]
+             self.kb_embeddings = self.embedding_model.encode(texts, convert_to_tensor=True)
+             logger.info(f"Loaded {len(self.knowledge_base)} documents into the knowledge base.")
+         else:
+             self.kb_embeddings = torch.tensor([]).to(self.device)
+             logger.info("Knowledge base is empty.")
+
+     def save_knowledge_base(self):
+         with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
+             json.dump(self.knowledge_base, f, indent=2)
+         logger.info(f"Knowledge base saved with {len(self.knowledge_base)} documents.")
+
+     def add_document_to_kb(self, title, content, metadata=None):
+         document = {
+             "title": title,
+             "content": content,
+             "metadata": metadata or {}
+         }
+         self.knowledge_base.append(document)
+         # Update embeddings
+         new_embedding = self.embedding_model.encode([content], convert_to_tensor=True).to(self.device)
+         if self.kb_embeddings is None or self.kb_embeddings.numel() == 0:
+             self.kb_embeddings = new_embedding
+         else:
+             self.kb_embeddings = torch.cat([self.kb_embeddings, new_embedding], dim=0)
+         # Save to knowledge base
+         self.save_knowledge_base()
+         logger.info(f"Added new document to knowledge base: {title}")
+
+     def retrieve_from_kb(self, query, top_k=5):
+         if not self.knowledge_base:
+             logger.warning("Knowledge base is empty. No documents to retrieve.")
+             return []
+
+         query_embedding = self.embedding_model.encode([query], convert_to_tensor=True).to(self.device)
+
+         if self.kb_embeddings is None or self.kb_embeddings.numel() == 0:
+             logger.warning("Knowledge base embeddings are empty. No documents to retrieve.")
+             return []
+
+         if query_embedding.size(1) != self.kb_embeddings.size(1):
+             logger.error("Dimension mismatch between query embedding and KB embeddings.")
+             return []
+
+         cosine_scores = cosine_similarity(query_embedding.cpu().numpy(), self.kb_embeddings.cpu().numpy())[0]
+         top_indices = cosine_scores.argsort()[-top_k:][::-1]
+
+         # Ensure indices are within the knowledge_base length
+         top_indices = [idx for idx in top_indices if idx < len(self.knowledge_base)]
+
+         retrieved_docs = []
+         for idx in top_indices:
+             doc = self.knowledge_base[idx]
+             # Cast to a plain float so the KB stays JSON-serializable when saved later
+             doc['score'] = float(cosine_scores[idx])
+             retrieved_docs.append(doc)
+
+         logger.info(f"Retrieved top {len(retrieved_docs)} documents from Knowledge Base for the query.")
+         return retrieved_docs
+
+     # ==========================
+     # RAG Integration
+     # ==========================
+
+     @defer.inlineCallbacks
+     def retrieve_from_web(self, query, top_k=5):
+         logger.info(f"Performing web search for query: {query}")
+         mcts_iterations = self.calculate_mcts_iterations(np.zeros(self.state_size, dtype=np.float32))
+         self.mcts = MCTS(initial_state=query, num_simulations=mcts_iterations)
+
+         try:
+             new_query = yield self.mcts.run()
+             logger.debug(f"New query from MCTS: {new_query}")
+             # Select search sites
+             search_sites = self.select_search_sites(new_query)
+             results = yield self.mcts.web_search(new_query, search_sites)
+             logger.debug(f"Web search completed. Found {len(results)} results")
+             defer.returnValue(results[:top_k] if results else [])
+         except Exception as e:
+             logger.error(f"Error during MCTS or web search: {str(e)}", exc_info=True)
+             defer.returnValue([])
+
+     def combine_documents(self, kb_docs, web_docs):
+         combined = kb_docs + web_docs
+         logger.info(f"Combined {len(kb_docs)} KB documents and {len(web_docs)} Web documents.")
+         return combined
+
+     def save_llm_training_data(self, query, content, summary=None, link=None, title=None):
+         data = {
+             "query": query,
+             "search_result": {
+                 "link": link,
+                 "title": title
+             },
+             "content": content,
+             "description": summary
+         }
+
+         os.makedirs("llm_training_data", exist_ok=True)
+         file_path = "llm_training_data/llm_training_data.jsonl"
+
+         # Append the new data as a new line in the JSONL file
+         with open(file_path, 'a', encoding='utf-8') as f:
+             json.dump(data, f)
+             f.write('\n')
+
+         logger.info(f"Appended LLM training data to {file_path}")
+
+     # ==========================
+     # Hierarchical RL Integration
+     # ==========================
+
+     def remember_manager(self, state, option, reward, next_state, done, td_error):
+         self.manager.remember(state, option, reward, next_state, done, td_error)
+
+     def remember_worker(self, state, action, reward, next_state, done):
+         self.worker_memory.add(reward, (state, action, reward, next_state, done))
+
+     # ==========================
+     # Action Selection and Execution
+     # ==========================
+
+     def act_manager(self, state):
+         option = self.manager.act(state)
+         return option
+
+     def act_worker(self, state):
+         # WorldModel has no act() helper, so implement epsilon-greedy selection here
+         if random.random() < self.epsilon:
+             return random.randint(0, self.action_size - 1)
+         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             action_scores, _, _ = self.worker_model(state_tensor)
+         return torch.argmax(action_scores, dim=1).item()
+
+     # ==========================
+     # Replay Methods
+     # ==========================
+
+     def replay_manager(self, batch_size=32, beta=0.4):
+         self.manager.replay(batch_size, beta)
+
+     def replay_worker(self, batch_size=32, beta=0.4):
+         # PrioritizedReplayMemory has no replay() method; guard on the number of
+         # stored entries and sample directly
+         if self.worker_memory.tree.n_entries < batch_size:
+             return
+         batch, idxs, weights = self.worker_memory.sample(batch_size, beta)
+         states, actions, rewards, next_states, dones = zip(*batch)
+
+         device = next(self.worker_model.parameters()).device
+         states = torch.FloatTensor(np.array(states)).to(device)
+         next_states = torch.FloatTensor(np.array(next_states)).to(device)
+         actions = torch.LongTensor(actions).unsqueeze(1).to(device)
+         rewards = torch.FloatTensor(rewards).unsqueeze(1).to(device)
+         dones = torch.FloatTensor(dones).unsqueeze(1).to(device)
+         weights = torch.FloatTensor(weights).unsqueeze(1).to(device)
+
+         # Current Q values (WorldModel returns action scores, state value, hidden)
+         current_q_values, _, _ = self.worker_model(states)
+         current_q_values = current_q_values.gather(1, actions)
+
+         # Target Q values
+         with torch.no_grad():
+             next_q_values, _, _ = self.worker_target_model(next_states)
+             max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
+             target_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))
+
+         # Compute TD errors
+         td_errors = target_q_values - current_q_values
+
+         # Compute loss with importance-sampling weights
+         loss = (td_errors.pow(2) * weights).mean()
+
+         # Optimize the model
+         self.worker_optimizer.zero_grad()
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(self.worker_model.parameters(), max_norm=1.0)
+         self.worker_optimizer.step()
+         self.worker_scheduler.step(loss.item())
+
+         # Update priorities
+         td_errors_np = td_errors.detach().cpu().numpy().squeeze()
+         for idx, td_error in zip(idxs, td_errors_np):
+             self.worker_memory.update(idx, np.abs(td_error))
+
+         # Decay epsilon
+         if self.epsilon > self.epsilon_min:
+             self.epsilon *= self.epsilon_decay
+             logger.debug(f"Updated epsilon to: {self.epsilon}")
+
+     # ==========================
+     # Load and Save Models
+     # ==========================
+
+     def load_worker_model(self, name):
+         self.worker_model.load_state_dict(torch.load(name, map_location=self.device))
+         logger.info(f"Loaded worker model weights from {name}")
+
+     def save_worker_model(self, name):
+         torch.save(self.worker_model.state_dict(), name)
+         logger.info(f"Saved worker model weights to {name}")
+
+     def load_manager_model(self, name):
+         self.manager.model.load_state_dict(torch.load(name, map_location=self.device))
+         self.manager.update_target_model()
+         logger.info(f"Loaded manager model weights from {name}")
+
+     def save_manager_model(self, name):
+         torch.save(self.manager.model.state_dict(), name)
+         logger.info(f"Saved manager model weights to {name}")
+
+     # ==========================
+     # Update Target Models
+     # ==========================
+
+     def update_worker_target_model(self):
+         self.worker_target_model.load_state_dict(self.worker_model.state_dict())
+         logger.info("Updated worker target model with current model weights")
+
+     def update_manager_target_model(self):
+         self.manager.update_target_model()
+         logger.info("Updated manager target model with current model weights")
+
+     # ==========================
+     # Feature Extraction
+     # ==========================
+
+     def extract_features(self, content, query):
+         content = truncate_text(content)
+         query = truncate_text(query)
+         soup = BeautifulSoup(content, 'html.parser')
+         text = soup.get_text()
+         word_count = len(text.split())
+         link_count = len(soup.find_all('a'))
+         header_count = len(soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']))
+
+         # Calculate semantic similarity
+         text_embedding = self.embedding_model.encode([text], convert_to_tensor=True).to(self.device)
+         query_embedding = self.embedding_model.encode([query], convert_to_tensor=True).to(self.device)
+         semantic_similarity = cosine_similarity(text_embedding.cpu().numpy(), query_embedding.cpu().numpy())[0][0]
+
+         # Additional Features
+         image_count = len(soup.find_all('img'))
+         script_count = len(soup.find_all('script'))
+         css_count = len(soup.find_all('link', rel='stylesheet'))
+
+         return np.array([word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count])
+
+     # ==========================
+     # Reward Calculation
+     # ==========================
+
+     def calculate_reward(self, content, query):
+         try:
+             ranked_results = train_ranking_model(query, [{'content': content}])
+             logger.debug(f"Ranked results: {ranked_results}")
+             if ranked_results and isinstance(ranked_results[0], dict) and 'predicted_score' in ranked_results[0]:
+                 reward = max(ranked_results[0]['predicted_score'], 0)
+                 logger.debug(f"Calculated reward: {reward}")
+                 return reward
+             else:
+                 logger.warning(f"Invalid ranked results: {ranked_results}")
+                 return 0
+         except Exception as e:
+             logger.error(f"Error in calculate_reward: {str(e)}", exc_info=True)
+             return 0
+
+     # ==========================
+     # Search Site Selection
+     # ==========================
+
+     def select_search_sites(self, query, num_sites=5):
+         # Select top sites based on past performance for this query
+         site_scores = {}
+         for (site, q), score in self.site_performance.items():
+             if q == query:
+                 site_scores[site] = site_scores.get(site, 0) + score
+         if site_scores:
+             sorted_sites = sorted(site_scores.items(), key=lambda x: x[1], reverse=True)
+             top_sites = [site for site, score in sorted_sites[:num_sites]]
+         else:
+             # If no past data, select random sites
+             top_sites = random.sample(self.all_search_sites, num_sites)
+         # Construct full URLs with query
+         search_sites = [site + query for site in top_sites]
+         return search_sites
+
+     # ==========================
+     # Search Method with HRL
+     # ==========================
+
+     @defer.inlineCallbacks
+     def search(self, query, max_steps=2):
+         logger.info(f"Starting search for query: {query}")
+         state = np.zeros(self.state_size, dtype=np.float32)
+         total_reward = 0
+         content = ""
+         done = False
+         results = None
+
+         try:
+             # High-Level: Manager selects an option
+             option = self.act_manager(state)
+             logger.debug(f"Manager selected option: {option}")
+
+             # Execute the selected option
+             if option == 0:  # Search Option
+                 logger.debug("Executing Search Option")
+                 results = yield self.retrieve_from_web(query)
+                 if results:
+                     content = results[0]['content']
+                     site = urlparse(results[0]['link']).netloc
+                     self.save_llm_training_data(
+                         query,
+                         content,
+                         summary=results[0].get('summary'),
+                         link=results[0].get('link'),
+                         title=results[0].get('title')
+                     )
+                     self.add_document_to_kb(title=results[0].get('title', 'No Title'), content=content, metadata=results[0].get('meta', {}))
+                     next_state = self.extract_features(content, query)
+                     reward = self.calculate_reward(content, query)
+                     logger.debug(f"Extracted features: {next_state}, Reward: {reward}")
+                     # Update site performance
+                     key = (site, query)
+                     self.site_performance[key] = self.site_performance.get(key, 0) + reward
+
+                     # Remember Manager's experience
+                     self.remember_manager(state, option, reward, next_state, done, td_error=reward)
+
+                     # Remember Worker's experience
+                     self.remember_worker(state, 0, reward, next_state, done)
+
+                     state = next_state.astype(np.float32)
+                     total_reward += reward
+
+                 else:
+                     reward = -1
+                     logger.warning(f"No results for query: {query}")
+                     # Remember Manager's experience
+                     self.remember_manager(state, option, reward, state, True, td_error=reward)
+
+             elif option == 1:  # Summarize Option
+                 logger.debug("Executing Summarize Option")
+                 if content:
+                     # Route through the agent's summarize() helper, which chunks the
+                     # content before calling generate_summary(query, relevant_chunks)
+                     summary = self.summarize(content, query)
+                     self.save_llm_training_data(
+                         query,
+                         content,
+                         summary=summary,
+                         link=results[0].get('link') if results else None,
+                         title=results[0].get('title') if results else None
+                     )
+                     reward = self.calculate_reward(summary, query)
+                     next_state = self.extract_features(summary, query)
+                     logger.info(f"Summary:\n{summary}")
+                     logger.info(f"Summarized content. Reward: {reward}")
+
+                     # Remember Manager's experience
+                     self.remember_manager(state, option, reward, next_state, done, td_error=reward)
+
+                     # Remember Worker's experience
+                     self.remember_worker(state, 1, reward, next_state, done)
+
+                     state = next_state.astype(np.float32)
+                     total_reward += reward
+                 else:
+                     reward = -1
+                     logger.warning("No content to summarize")
+                     # Remember Manager's experience
+                     self.remember_manager(state, option, reward, state, True, td_error=reward)
+
+             elif option == 2:  # RAG-based Generation Option
+                 logger.debug("Executing RAG-based Generation Option")
+                 kb_docs = self.retrieve_from_kb(query, top_k=5)
+                 web_docs = []  # Assuming web_docs are already retrieved
+                 combined_docs = self.combine_documents(kb_docs, web_docs)
+                 generated_output = self.generate_rag_response(query, combined_docs)
+                 logger.info(f"Generated Output:\n{generated_output}")
+                 self.save_llm_training_data(
+                     query,
+                     generated_output,
+                     summary=None,
+                     link=None,
+                     title="RAG-generated response"
+                 )
+                 reward = self.calculate_reward(generated_output, query)
+                 next_state = self.extract_features(generated_output, query)
+
+                 # Remember Manager's experience
+                 self.remember_manager(state, option, reward, next_state, done, td_error=reward)
+
+                 # Remember Worker's experience
+                 self.remember_worker(state, 2, reward, next_state, done)
+
+                 state = next_state.astype(np.float32)
+                 total_reward += reward
+
+             else:
+                 logger.warning(f"Unknown option selected by Manager: {option}")
+
+             # Perform replay for both Manager and Worker
+             self.replay_manager(batch_size=32, beta=0.4)
+             self.replay_worker(batch_size=32, beta=0.4)
+
+             # Update target models periodically
+             self.update_worker_target_model()
+             self.update_manager_target_model()
+
+             logger.info(f"Search completed. Total reward: {total_reward}")
+             defer.returnValue(total_reward)
+         except Exception as e:
+             logger.error(f"Error during search: {str(e)}", exc_info=True)
+             defer.returnValue(-1)  # Return a negative reward on error
+
+     # ==========================
+     # Summarization Method
+     # ==========================
+
+     def summarize(self, content, query):
+         chunks = self.summarizer.split_into_chunks(content)
+         embeddings = self.summarizer.get_embeddings(chunks)
+         relevant_chunks = self.summarizer.retrieve_relevant_chunks(query, chunks, embeddings)
+         summary = self.summarizer.generate_summary(query, relevant_chunks)
+
+         # Save RAG data
+         self.summarizer.save_rag_data(query, chunks, embeddings)
+
+         return summary
+
+     # ==========================
+     # MCTS Iterations Calculation
+     # ==========================
+
+     def calculate_mcts_iterations(self, state):
+         # Calculate MCTS iterations based on state complexity
+         base_iterations = 2
+         complexity_factor = np.mean(state) / 100  # Normalize state values
+         iterations = int(base_iterations * (1 + complexity_factor))
+         max_iterations = 5  # Set a reasonable maximum
+         return min(iterations, max_iterations)
+
+     # ==========================
+     # RAG-based Response Generation
+     # ==========================
+
+     def generate_rag_response(self, query, combined_docs):
+         if not combined_docs:
+             logger.warning("No documents available for RAG-based generation.")
+             return "I'm sorry, I couldn't find any relevant information."
+
+         # Prepare context for the generator
+         context = "\n\n".join([f"Title: {doc.get('title', 'No Title')}\nContent: {doc.get('content', '')}" for doc in combined_docs])
+         prompt = f"Query: {query}\n\nContext:\n{context}\n\nAnswer:"
+
+         # Check cache first
+         cache_key = hashlib.md5(prompt.encode()).hexdigest()
+         cached_response = self.summarizer.cache.get(cache_key)
+         if cached_response:
+             logger.debug("Using cached RAG response.")
+             return cached_response
+
+         # Generate response
+         input_ids = self.summarizer.tokenizer.encode(prompt, return_tensors='pt').to(self.summarizer.device)
+         try:
+             output = self.summarizer.model.generate(
+                 input_ids,
+                 max_length=input_ids.shape[1] + self.summarizer.max_length,
+                 num_return_sequences=1,
+                 no_repeat_ngram_size=2,
+                 top_k=50,
+                 top_p=0.95,
+                 temperature=0.7,
+                 early_stopping=True
+             )
+         except Exception as e:
+             logger.error(f"Error during RAG response generation: {str(e)}")
+             return "RAG response generation failed."
+
+         response = self.summarizer.tokenizer.decode(output[0], skip_special_tokens=True)
+         answer = response.split("Answer:")[-1].strip()
+
+         # Cache the response
+         self.summarizer.cache.put(cache_key, answer)
+         self.summarizer.save_persistent_cache()
+         return answer
+
+     # ==========================
+     # Manager and Worker Interaction
+     # ==========================
+
+     def select_option(self, option):
+         """
+         Define the mapping of options to their corresponding actions.
+         """
+         # This can be expanded based on the number of options
+         option_actions = {
+             0: self.perform_search,
+             1: self.perform_summarization,
+             2: self.perform_rag_generation
+         }
+         action = option_actions.get(option, None)
+         if action:
+             return action
+         else:
+             logger.error(f"No action defined for option: {option}")
+             return None
+
+     def perform_search(self, query):
+         """
+         Perform the search action.
+         """
+         # Implementation is handled in the 'search' method
+         pass
+
+     def perform_summarization(self, content, query):
+         """
+         Perform the summarization action.
+         """
+         # Implementation is handled in the 'summarize' method
+         pass
+
+     def perform_rag_generation(self, query, combined_docs):
+         """
+         Perform the RAG-based generation action.
+         """
+         # Implementation is handled in the 'generate_rag_response' method
+         pass
+
+ # ==========================
+ # LRUCache Class
+ # ==========================
+
+ class LRUCache:
+     def __init__(self, capacity):
+         self.cache = OrderedDict()
+         self.capacity = capacity
+
+     def get(self, key):
+         if key not in self.cache:
+             return None
+         self.cache.move_to_end(key)
+         return self.cache[key]
+
+     def put(self, key, value):
+         if key in self.cache:
+             self.cache.move_to_end(key)
+         self.cache[key] = value
+         if len(self.cache) > self.capacity:
+             self.cache.popitem(last=False)
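The SumTree makes prioritized sampling run in O(log n): `sample` splits the total priority mass into `batch_size` segments and draws one leaf per segment, so transitions with larger TD errors are drawn proportionally more often, while `(n_entries * probs) ** (-beta)` produces the importance-sampling weights that correct for that bias. A minimal standalone check of this behavior (illustrative only; it uses `PrioritizedReplayMemory` exactly as defined above):

    # per_sketch.py -- illustrative only, not part of the upload
    from agent import PrioritizedReplayMemory

    memory = PrioritizedReplayMemory(capacity=8, alpha=0.6)
    # Transitions with larger TD errors receive proportionally higher priority
    for i, td_error in enumerate([0.1, 0.5, 2.0, 0.05]):
        memory.add(td_error, ("state", i, "reward"))

    batch, idxs, weights = memory.sample(batch_size=2, beta=0.4)
    print(batch)    # the high-error transition tends to be sampled most often
    print(weights)  # importance-sampling weights, normalized to a max of 1.0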
lightbulb.py ADDED
@@ -0,0 +1,1696 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader
9
+ import copy
10
+ from torch.optim.lr_scheduler import CosineAnnealingLR
11
+ from torch.cuda.amp import autocast, GradScaler
12
+ from datasets import load_dataset
13
+ from transformers import AutoTokenizer
14
+
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ def parse_args():
18
+ parser = argparse.ArgumentParser(description='Train or Inference with World Model and Tree of Thought.')
19
+ parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
20
+ parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
21
+ parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
22
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
23
+ parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
24
+ parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
25
+ parser.add_argument('--mcts_iterations', type=int, default=3, help='Number of MCTS Iterations')
26
+ parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS')
27
+ parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
28
+ parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
29
+ parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
30
+ parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
31
+ parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
32
+ parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
33
+ parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
34
+ parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
35
+ parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Mode: train or inference')
36
+ parser.add_argument('--inference_mode', type=str, choices=['world_model', 'without_world_model', 'world_model_tree_of_thought'], default='world_model_tree_of_thought', help='Inference mode')
37
+ parser.add_argument('--query', type=str, default='', help='Input query for inference')
38
+ parser.add_argument('--train_mode', type=str, choices=['world_model', 'language_model'], default='world_model', help='Train world model or language model only')
39
+
40
+ # Use parse_known_args to ignore unknown arguments
41
+ args, unknown = parser.parse_known_args()
42
+ return args
43
+
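+ # Example invocation (a sketch; the script name and flag values are illustrative,
+ # drawn from the defaults above, not prescriptive):
+ #   python train.py --mode train --train_mode world_model \
+ #       --model_name gpt2 --dataset_name wikitext --dataset_config wikitext-2-raw-v1 \
+ #       --batch_size 4 --num_epochs 3 --learning_rate 1e-4
+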
44
+ def load_data(args, tokenizer):
45
+ # Load the dataset
46
+ dataset = load_dataset(args.dataset_name, args.dataset_config)
47
+
48
+ # Ensure the tokenizer has a padding token
49
+ if tokenizer.pad_token is None:
50
+ tokenizer.pad_token = tokenizer.eos_token
51
+
52
+ def tokenize_function(examples):
53
+ return tokenizer(examples['text'], truncation=True, max_length=args.max_length)
54
+
55
+ tokenized_datasets = dataset.map(
56
+ tokenize_function,
57
+ batched=True,
58
+ num_proc=4,
59
+ remove_columns=dataset['train'].column_names,
60
+ )
61
+
62
+ # Build inputs and labels for language modeling
63
+ block_size = args.max_length
64
+
65
+ def group_texts(examples):
66
+ # Concatenate all texts
67
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
68
+ total_length = len(concatenated_examples['input_ids'])
69
+ # We drop the small remainder
70
+ total_length = (total_length // block_size) * block_size
71
+ # Split by chunks of block_size
72
+ result = {
73
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
74
+ for k, t in concatenated_examples.items()
75
+ }
76
+ result['labels'] = result['input_ids'].copy()
77
+ return result
78
+
79
+ lm_datasets = tokenized_datasets.map(
80
+ group_texts,
81
+ batched=True,
82
+ num_proc=4,
83
+ )
84
+
85
+ # Create DataLoader
86
+ train_dataset = lm_datasets['train']
87
+ eval_dataset = lm_datasets['validation'] if 'validation' in lm_datasets else lm_datasets['test']
88
+
89
+ def data_collator(data):
90
+ return {
91
+ 'input_ids': torch.tensor([f['input_ids'] for f in data], dtype=torch.long),
92
+ 'labels': torch.tensor([f['labels'] for f in data], dtype=torch.long)
93
+ }
94
+
95
+ train_loader = DataLoader(
96
+ train_dataset,
97
+ shuffle=True,
98
+ batch_size=args.batch_size,
99
+ collate_fn=data_collator,
100
+ pin_memory=True, # Speeds up transfer to GPU
101
+ num_workers=4
102
+ )
103
+ eval_loader = DataLoader(
104
+ eval_dataset,
105
+ shuffle=False,
106
+ batch_size=args.batch_size,
107
+ collate_fn=data_collator,
108
+ pin_memory=True,
109
+ num_workers=4
110
+ )
111
+
112
+ return train_loader, eval_loader
113
+
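+ # Usage sketch (the tokenizer is assumed to be created from args.model_name,
+ # as elsewhere in this file):
+ #   tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+ #   train_loader, eval_loader = load_data(args, tokenizer)
+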
114
+ def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
115
+ """
116
+ Save all models to the specified directory.
117
+
118
+ Args:
119
+ transformer_model (nn.Module): Transformer model.
120
+ representation_network (nn.Module): Representation network.
121
+ dynamics_network (nn.Module): Dynamics network.
122
+ prediction_network (nn.Module): Prediction network.
123
+ action_encoder (nn.Module): Action encoder.
124
+ save_dir (str): Directory to save the models.
125
+ epoch (int): Current epoch number.
126
+ """
127
+ os.makedirs(save_dir, exist_ok=True)
128
+
129
+ torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
130
+ torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
131
+ torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
132
+ torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
133
+ torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))
134
+
135
+ print(f"All models saved for epoch {epoch}.")
136
+
137
+ class RotaryPositionalEncoding(nn.Module):
138
+ def __init__(self, d_model):
139
+ super(RotaryPositionalEncoding, self).__init__()
140
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
141
+ self.register_buffer('inv_freq', inv_freq)
142
+
143
+ def forward(self, x):
144
+ seq_len, batch_size, _ = x.size()
145
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
146
+ sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
147
+ sin = sinusoid_inp.sin().unsqueeze(1) # (seq_len, 1, d_model/2)
148
+ cos = sinusoid_inp.cos().unsqueeze(1) # (seq_len, 1, d_model/2)
149
+
150
+ x1 = x[..., 0::2]
151
+ x2 = x[..., 1::2]
152
+
153
+ # Apply rotation
154
+ x_rotated = torch.zeros_like(x)
155
+ x_rotated[..., 0::2] = x1 * cos - x2 * sin
156
+ x_rotated[..., 1::2] = x1 * sin + x2 * cos
157
+
158
+ return x_rotated
159
+
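+ # Shape check (illustrative): the rotary encoding is shape-preserving on
+ # (seq_len, batch_size, d_model) inputs with even d_model, e.g.
+ #   rope = RotaryPositionalEncoding(16)
+ #   assert rope(torch.randn(8, 2, 16)).shape == (8, 2, 16)
+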
160
+ class MultiHeadAttention(nn.Module):
161
+ def __init__(self, d_model, num_heads):
162
+ super(MultiHeadAttention, self).__init__()
163
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
164
+ self.d_k = d_model // num_heads
165
+ self.num_heads = num_heads
166
+ self.linear_q = nn.Linear(d_model, d_model)
167
+ self.linear_k = nn.Linear(d_model, d_model)
168
+ self.linear_v = nn.Linear(d_model, d_model)
169
+ self.linear_out = nn.Linear(d_model, d_model)
170
+
171
+ def forward(self, query, key, value, mask=None):
172
+ batch_size = query.size(0)
173
+ query = self.linear_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
174
+ key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
175
+ value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
176
+
177
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
178
+ if mask is not None:
179
+ scores = scores.masked_fill(mask == 0, -1e4)
180
+ attn = F.softmax(scores, dim=-1)
181
+ output = torch.matmul(attn, value)
182
+
183
+ output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
184
+ return self.linear_out(output)
185
+
186
+ class MoE(nn.Module):
187
+ def __init__(self, d_model, num_experts, d_ff, top_k=2, dropout=0.1):
188
+ super(MoE, self).__init__()
189
+ self.num_experts = num_experts
190
+ self.top_k = top_k
191
+ self.experts = nn.ModuleList([
192
+ nn.Sequential(
193
+ nn.Linear(d_model, d_ff),
194
+ nn.GELU() if i % 2 == 0 else nn.SiLU(),
195
+ nn.Linear(d_ff, d_model)
196
+ )
197
+ for i in range(num_experts)
198
+ ])
199
+ self.gate = nn.Linear(d_model, num_experts)
200
+ self.dropout = nn.Dropout(dropout)
201
+
202
+ def forward(self, x):
203
+ batch_size, seq_len, d_model = x.size()
204
+ # Compute gating scores
205
+ gate_scores = self.gate(x) # (batch_size, seq_len, num_experts)
206
+ top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1) # (batch_size, seq_len, top_k)
207
+ top_k_scores = F.softmax(top_k_scores, dim=-1) # (batch_size, seq_len, top_k)
208
+
209
+ # Initialize output
210
+ output = torch.zeros_like(x)
211
+
212
+ # Flatten batch and sequence dimensions
213
+ x_flat = x.view(-1, d_model) # (batch_size * seq_len, d_model)
214
+ output_flat = output.view(-1, d_model)
215
+ top_k_indices_flat = top_k_indices.view(-1, self.top_k) # (batch_size * seq_len, top_k)
216
+ top_k_scores_flat = top_k_scores.view(-1, self.top_k) # (batch_size * seq_len, top_k)
217
+
218
+ for k in range(self.top_k):
219
+ expert_idx_flat = top_k_indices_flat[:, k] # (batch_size * seq_len)
220
+ expert_scores_flat = top_k_scores_flat[:, k] # (batch_size * seq_len)
221
+ for e in range(self.num_experts):
222
+ mask = (expert_idx_flat == e) # Boolean mask
223
+ if mask.any():
224
+ x_masked = x_flat[mask] # Select tokens for expert e
225
+ expert_output = self.experts[e](x_masked) # Apply expert e
226
+ output_flat[mask] += expert_scores_flat[mask].unsqueeze(-1) * expert_output
227
+
228
+ output = output_flat.view(batch_size, seq_len, d_model)
229
+ return self.dropout(output)
230
+
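+ # Gating sketch (illustrative values): with num_experts=4 and top_k=2, each token is
+ # routed to its two highest-scoring experts, whose outputs are mixed by softmax weights:
+ #   moe = MoE(d_model=32, num_experts=4, d_ff=64, top_k=2)
+ #   out = moe(torch.randn(2, 10, 32))   # -> (2, 10, 32)
+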
231
+ class TransformerBlock(nn.Module):
232
+ def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1, top_k=2):
233
+ super(TransformerBlock, self).__init__()
234
+ self.self_attention = MultiHeadAttention(d_model, num_heads)
235
+ self.norm1 = nn.LayerNorm(d_model)
236
+ self.cross_attention = MultiHeadAttention(d_model, num_heads)
237
+ self.norm2 = nn.LayerNorm(d_model)
238
+ self.moe = MoE(d_model, num_experts, d_ff, top_k, dropout)
239
+ self.norm3 = nn.LayerNorm(d_model)
240
+
241
+ def forward(self, x, mask=None, enc_output=None, enc_mask=None):
242
+ # Self-attention
243
+ attn_output = self.self_attention(x, x, x, mask)
244
+ x = self.norm1(x + attn_output)
245
+ # Cross-attention (only in decoder)
246
+ if enc_output is not None:
247
+ cross_attn_output = self.cross_attention(x, enc_output, enc_output, enc_mask)
248
+ x = self.norm2(x + cross_attn_output)
249
+ # Feedforward/MoE
250
+ moe_output = self.moe(x)
251
+ return self.norm3(x + moe_output)
252
+
253
+ class Transformer(nn.Module):
254
+ def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout=0.1, top_k=2):
255
+ super(Transformer, self).__init__()
256
+ self.embedding = nn.Embedding(input_dim, d_model, padding_idx=input_dim - 1)
257
+ self.rotary_positional_encoding = RotaryPositionalEncoding(d_model)
258
+ self.encoder_layers = nn.ModuleList(
259
+ [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
260
+ )
261
+ self.decoder_layers = nn.ModuleList(
262
+ [TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
263
+ )
264
+ self.output_layer = nn.Linear(d_model, output_dim)
265
+ self.d_model = d_model
266
+
267
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None):
268
+ # Encoder
269
+ src = self.embedding(src) * math.sqrt(self.d_model)
270
+ src = src.transpose(0, 1) # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
271
+ src = self.rotary_positional_encoding(src)
272
+ src = src.transpose(0, 1) # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
273
+ for layer in self.encoder_layers:
274
+ src = layer(src, src_mask)
275
+
276
+ # Decoder
277
+ tgt = self.embedding(tgt) * math.sqrt(self.d_model)
278
+ tgt = tgt.transpose(0, 1)
279
+ tgt = self.rotary_positional_encoding(tgt)
280
+ tgt = tgt.transpose(0, 1)
281
+ for layer in self.decoder_layers:
282
+ tgt = layer(tgt, tgt_mask, src, src_mask)
283
+ output = self.output_layer(tgt)
284
+ return output
285
+
286
+ def generate(self, src, tokenizer, max_length=20, temperature=1.0):
287
+ """
288
+ Generate sequences using Gumbel-Softmax sampling, with hard argmax token selection at each step.
289
+
290
+ Args:
291
+ src (torch.Tensor): Source input tensor of shape (batch_size, seq_len)
292
+ tokenizer (transformers.PreTrainedTokenizer): Tokenizer to access special tokens
293
+ max_length (int): Maximum length of the generated sequence
294
+ temperature (float): Temperature parameter for Gumbel-Softmax
295
+
296
+ Returns:
297
+ torch.Tensor: Generated sequences of shape (batch_size, max_length)
298
+ torch.Tensor: Entropy values for each time step
299
+ torch.Tensor: Variance values for each time step
300
+ """
301
+ batch_size = src.size(0)
302
+
303
+ # Encode the source
304
+ src_enc = self.embedding(src) * math.sqrt(self.d_model)
305
+ src_enc = src_enc.transpose(0, 1)
306
+ src_enc = self.rotary_positional_encoding(src_enc)
307
+ src_enc = src_enc.transpose(0, 1)
308
+ for layer in self.encoder_layers:
309
+ src_enc = layer(src_enc)
310
+
311
+ # Initialize decoder input with <sos> tokens
312
+ tgt_seq = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=src.device)
313
+ entropies = []
314
+ variances = []
315
+
316
+ for _ in range(max_length):
317
+ tgt_emb = self.embedding(tgt_seq) * math.sqrt(self.d_model)
318
+ tgt_emb = tgt_emb.transpose(0, 1)
319
+ tgt_emb = self.rotary_positional_encoding(tgt_emb)
320
+ tgt_emb = tgt_emb.transpose(0, 1)
321
+ tgt_dec = tgt_emb
322
+ for layer in self.decoder_layers:
323
+ tgt_dec = layer(tgt_dec, None, src_enc, None)
324
+ output = self.output_layer(tgt_dec) # (batch_size, seq_len, vocab_size)
325
+ logits = output[:, -1, :] # Get logits for the last time step
326
+
327
+ # Compute token probabilities
328
+ probs = F.softmax(logits / temperature, dim=-1) # (batch_size, vocab_size)
329
+
330
+ # Compute entropy
331
+ entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1) # (batch_size)
332
+ entropies.append(entropy)
333
+
334
+ # Sample token using Gumbel-Softmax
335
+ gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + 1e-9) + 1e-9)
336
+ y = (logits + gumbel_noise) / temperature
337
+ y = F.softmax(y, dim=-1) # (batch_size, vocab_size)
338
+
339
+ # Compute variance
340
+ variance = torch.var(y, dim=-1) # (batch_size)
341
+ variances.append(variance)
342
+
343
+ # Get token indices (argmax for hard selection)
344
+ next_tokens = torch.argmax(y, dim=-1, keepdim=True) # (batch_size, 1)
345
+ tgt_seq = torch.cat([tgt_seq, next_tokens], dim=1)
346
+
347
+ # Stack entropies and variances
348
+ entropies = torch.stack(entropies, dim=1) # (batch_size, max_length)
349
+ variances = torch.stack(variances, dim=1) # (batch_size, max_length)
350
+
351
+ return tgt_seq[:, 1:], entropies, variances # Exclude the initial <sos> token
352
+
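+ # Generation sketch (mirrors the call in infer() below; the prompt is illustrative):
+ #   src = tokenizer.encode("example prompt", return_tensors='pt').to(device)
+ #   ids, entropies, variances = model_transformer.generate(src, tokenizer, max_length=20, temperature=1.0)
+ #   text = tokenizer.decode(ids[0], skip_special_tokens=True)
+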
353
+ # Objective Functions
354
+
355
+ class InfoNCE_Loss(nn.Module):
356
+ def __init__(self, temperature=0.07):
357
+ super(InfoNCE_Loss, self).__init__()
358
+ self.temperature = temperature
359
+ self.cross_entropy = nn.CrossEntropyLoss()
360
+
361
+ def forward(self, z_i, z_j):
362
+ """
363
+ Args:
364
+ z_i (torch.Tensor): Flattened representations from view i, shape (2n, embed_dim)
365
+ z_j (torch.Tensor): Flattened representations from view j, shape (2n, embed_dim)
366
+
367
+ Returns:
368
+ torch.Tensor: InfoNCE loss
369
+ """
370
+ n = z_i.size(0)
371
+ z = torch.cat([z_i, z_j], dim=0) # Shape: (2n, embed_dim)
372
+
373
+ z = F.normalize(z, dim=1)
374
+ similarity_matrix = torch.matmul(z, z.T) # Shape: (2n, 2n)
375
+
376
+ # Create a mask to exclude self-similarity
377
+ mask = torch.eye(2 * n, device=z.device, dtype=torch.bool)
378
+ similarity_matrix = similarity_matrix.masked_fill(mask, -1e4) # Use a manageable negative value
379
+
380
+ # Create labels for contrastive learning
381
+ labels = torch.arange(n, device=z.device)
382
+ labels = torch.cat([labels + n, labels], dim=0) # Shape: (2n,)
383
+
384
+ # Apply temperature scaling
385
+ similarity_matrix /= self.temperature
386
+
387
+ # Compute cross-entropy loss
388
+ loss = self.cross_entropy(similarity_matrix, labels)
389
+ return loss
390
+
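+ # Usage sketch: z_j is a second "view" of z_i; here a dropout-perturbed copy, exactly
+ # as in train_epoch_world_model below:
+ #   z_i = state_representation.view(-1, state_dim)
+ #   z_j = F.dropout(z_i, p=0.1, training=True)
+ #   loss = InfoNCE_Loss(temperature=0.07)(z_i, z_j)
+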
391
+ class CovarianceRegularization(nn.Module):
392
+ def __init__(self, lambda_reg=1e-3):
393
+ super(CovarianceRegularization, self).__init__()
394
+ self.lambda_reg = lambda_reg
395
+
396
+ def forward(self, embeddings):
397
+ """
398
+ Args:
399
+ embeddings (torch.Tensor): Embedding tensor, shape (batch_size, embed_dim)
400
+
401
+ Returns:
402
+ torch.Tensor: Covariance regularization loss
403
+ """
404
+ batch_size, embed_dim = embeddings.size()
405
+ mean = embeddings.mean(dim=0)
406
+ embeddings_centered = embeddings - mean
407
+ cov = (embeddings_centered.T @ embeddings_centered) / (batch_size - 1)
408
+ cov_loss = torch.sum(cov ** 2) - torch.sum(torch.diag(cov) ** 2)
409
+ return self.lambda_reg * cov_loss
410
+
411
+ class DynamicsPerformanceLoss(nn.Module):
412
+ def __init__(self, lambda_var=1e-3):
413
+ super(DynamicsPerformanceLoss, self).__init__()
414
+ self.lambda_var = lambda_var
415
+
416
+ def forward(self, true_next_state, predicted_next_state):
417
+ """
418
+ Args:
419
+ true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
420
+ predicted_next_state (torch.Tensor): Predicted next state, shape (batch_size, state_dim)
421
+
422
+ Returns:
423
+ torch.Tensor: Dynamics performance loss
424
+ """
425
+ mse_loss = F.mse_loss(predicted_next_state, true_next_state)
426
+ variance_loss = torch.var(predicted_next_state, dim=0).mean()
427
+ return mse_loss + self.lambda_var * variance_loss
428
+
429
+ class ThoughtConsistencyLoss(nn.Module):
430
+ def __init__(self):
431
+ super(ThoughtConsistencyLoss, self).__init__()
432
+
433
+ def forward(self, true_next_state, perturbed_next_state):
434
+ """
435
+ Args:
436
+ true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
437
+ perturbed_next_state (torch.Tensor): Perturbed next state, shape (batch_size, state_dim)
438
+
439
+ Returns:
440
+ torch.Tensor: Thought-consistency loss
441
+ """
442
+ return F.mse_loss(true_next_state, perturbed_next_state)
443
+
444
+ class PolicyValueJointLoss(nn.Module):
445
+ def __init__(self, lambda_value=0.5):
446
+ super(PolicyValueJointLoss, self).__init__()
447
+ self.lambda_value = lambda_value
448
+ self.cross_entropy = nn.CrossEntropyLoss()
449
+ self.mse_loss = nn.MSELoss()
450
+
451
+ def forward(self, policy_logits, true_policy, value_pred, true_value):
452
+ """
453
+ Args:
454
+ policy_logits (torch.Tensor): Logits from the policy network, shape (batch_size * seq_len, num_actions)
455
+ true_policy (torch.Tensor): Ground truth policy, shape (batch_size * seq_len, num_actions)
456
+ value_pred (torch.Tensor): Predicted values, shape (batch_size * seq_len)
457
+ true_value (torch.Tensor): Ground truth values, shape (batch_size * seq_len)
458
+
459
+ Returns:
460
+ torch.Tensor: Combined policy and value loss
461
+ """
462
+ policy_logits = policy_logits.view(-1, policy_logits.size(-1))
463
+ true_policy = true_policy.view(-1, true_policy.size(-1))
464
+ value_pred = value_pred.view(-1)
465
+ true_value = true_value.view(-1)
466
+
467
+ policy_loss = self.cross_entropy(policy_logits, true_policy.argmax(dim=1))
468
+ value_loss = self.mse_loss(value_pred, true_value)
469
+ return policy_loss + self.lambda_value * value_loss
470
+
471
+ class ActionDiversityReward(nn.Module):
472
+ def __init__(self, lambda_div=1e-3):
473
+ super(ActionDiversityReward, self).__init__()
474
+ self.lambda_div = lambda_div
475
+
476
+ def forward(self, action_embeddings):
477
+ """
478
+ Args:
479
+ action_embeddings (torch.Tensor): Embeddings of actions, shape (batch_size, embed_dim)
480
+
481
+ Returns:
482
+ torch.Tensor: Action diversity loss
483
+ """
484
+ similarity_matrix = F.cosine_similarity(action_embeddings.unsqueeze(1), action_embeddings.unsqueeze(0), dim=2)
485
+ # Zero out self-similarity
486
+ similarity_matrix = similarity_matrix - torch.eye(similarity_matrix.size(0)).to(action_embeddings.device)
487
+ diversity_loss = torch.sum(similarity_matrix ** 2)
488
+ return self.lambda_div * diversity_loss
489
+
490
+ class ExpectedThoughtValueLoss(nn.Module):
491
+ def __init__(self):
492
+ super(ExpectedThoughtValueLoss, self).__init__()
493
+
494
+ def forward(self, mcts_best_values):
495
+ """
496
+ Args:
497
+ mcts_best_values (torch.Tensor): Best values from MCTS, shape (batch_size)
498
+
499
+ Returns:
500
+ torch.Tensor: ETV loss
501
+ """
502
+ return -mcts_best_values.mean()
503
+
504
+ class ExplorationRegularization(nn.Module):
505
+ def __init__(self, lambda_expl=1e-3):
506
+ super(ExplorationRegularization, self).__init__()
507
+ self.lambda_expl = lambda_expl
508
+
509
+ def forward(self, visit_counts):
510
+ """
511
+ Args:
512
+ visit_counts (torch.Tensor): Visit counts for actions, shape (batch_size, num_actions)
513
+
514
+ Returns:
515
+ torch.Tensor: Exploration regularization loss
516
+ """
517
+ reward = torch.sum(1.0 / (visit_counts + 1), dim=-1)
518
+ return self.lambda_expl * reward.mean()
519
+
520
+ class KL_DivergenceLoss(nn.Module):
521
+ def __init__(self):
522
+ super(KL_DivergenceLoss, self).__init__()
523
+
524
+ def forward(self, old_policy, new_policy):
525
+ """
526
+ Args:
527
+ old_policy (torch.Tensor): Old policy probabilities, shape (batch_size, num_actions)
528
+ new_policy (torch.Tensor): New policy probabilities, shape (batch_size, num_actions)
529
+
530
+ Returns:
531
+ torch.Tensor: KL divergence loss
532
+ """
533
+ kl_div = F.kl_div(new_policy.log(), old_policy, reduction='batchmean')
534
+ return kl_div
535
+
536
+ # MuZero Components
537
+
538
+ class ActionEncoder(nn.Module):
539
+ def __init__(self, action_vocab_size, embed_dim):
540
+ super(ActionEncoder, self).__init__()
541
+ self.embedding = nn.Embedding(action_vocab_size, embed_dim)
542
+
543
+ def forward(self, action_indices):
544
+ """
545
+ Args:
546
+ action_indices (torch.Tensor): Tensor of shape (batch_size, seq_len)
547
+
548
+ Returns:
549
+ torch.Tensor: Encoded actions of shape (batch_size, seq_len, embed_dim)
550
+ """
551
+ return self.embedding(action_indices)
552
+
553
+ class RepresentationNetwork(nn.Module):
554
+ def __init__(self, vocab_dim, d_model, state_dim):
555
+ super(RepresentationNetwork, self).__init__()
556
+ self.proj = nn.Linear(vocab_dim, d_model) # Project from vocab_dim to d_model
557
+ self.linear = nn.Linear(d_model, state_dim) # Project from d_model to state_dim
558
+ self.norm = nn.LayerNorm(state_dim)
559
+
560
+ def forward(self, transformer_output):
561
+ """
562
+ Args:
563
+ transformer_output (torch.Tensor): Shape (batch_size, seq_len, vocab_dim)
564
+
565
+ Returns:
566
+ torch.Tensor: Encoded state of shape (batch_size, seq_len, state_dim)
567
+ """
568
+ # First project down from vocab_dim to d_model
569
+ projected_output = self.proj(transformer_output)
570
+ # Then project down from d_model to state_dim
571
+ state = self.linear(projected_output)
572
+ state = self.norm(state)
573
+ return state
574
+
575
+ class DynamicsNetwork(nn.Module):
576
+ def __init__(self, state_dim, action_dim, hidden_dim):
577
+ super(DynamicsNetwork, self).__init__()
578
+ self.rms_norm = nn.LayerNorm(state_dim)
579
+ self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
580
+ self.activation = nn.GELU()
581
+ self.fc2 = nn.Linear(hidden_dim, state_dim)
582
+
583
+ def forward(self, state, action):
584
+ """
585
+ Args:
586
+ state (torch.Tensor): Current state, shape (batch_size, seq_len, state_dim)
587
+ action (torch.Tensor): Action embedding, shape (batch_size, seq_len, action_dim)
588
+
589
+ Returns:
590
+ torch.Tensor: Predicted next state, shape (batch_size, seq_len, state_dim)
591
+ """
592
+ norm_state = self.rms_norm(state)
593
+ combined = torch.cat([norm_state, action], dim=-1)
594
+ hidden = self.activation(self.fc1(combined))
595
+ next_state = self.fc2(hidden)
596
+ return next_state
597
+
598
+ class PredictionNetwork(nn.Module):
599
+ def __init__(self, state_dim, action_vocab_size, value_dim):
600
+ super(PredictionNetwork, self).__init__()
601
+ self.state_dim = state_dim
602
+ self.rms_norm = nn.LayerNorm(state_dim)
603
+ self.policy_head = nn.Linear(state_dim, action_vocab_size) # Output size is action_vocab_size
604
+ self.value_head = nn.Linear(state_dim, value_dim)
605
+
606
+ def forward(self, state):
607
+ """
608
+ Args:
609
+ state (torch.Tensor): State representation, shape (batch_size, seq_len, state_dim)
610
+ Returns:
611
+ Tuple[torch.Tensor, torch.Tensor]: Policy logits and value estimates
612
+ """
613
+ norm_state = self.rms_norm(state)
614
+ policy_logits = self.policy_head(norm_state) # Shape: (batch_size, seq_len, action_vocab_size)
615
+ value_estimates = self.value_head(norm_state).squeeze(-1) # Shape: (batch_size, seq_len)
616
+ return policy_logits, value_estimates
617
+
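+ # Wiring sketch: the three MuZero-style networks compose as in train_epoch_world_model:
+ #   state  = representation_network(transformer_output)          # (B, T, state_dim)
+ #   next_s = dynamics_network(state, action_encoder(actions))    # (B, T, state_dim)
+ #   logits, values = prediction_network(next_s)                  # (B, T, A), (B, T)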
618
+
619
+ # Tree of Thought Components
620
+
621
+ class ThoughtNode:
622
+ def __init__(self, name):
623
+ self.name = name
624
+ self.children = []
625
+ self.parent = None
626
+
627
+ def add_child(self, child_node):
628
+ child_node.parent = self
629
+ self.children.append(child_node)
630
+
631
+ # Function to build the Tree of Thought from the detailed problem-solving structure below
632
+ def build_tree_of_thought():
633
+ # Create the root node
634
+ root = ThoughtNode('Problem-Solving Process')
635
+
636
+ # Level 1 nodes
637
+ problem_identification = ThoughtNode('Problem Identification')
638
+ problem_analysis = ThoughtNode('Problem Analysis')
639
+ solution_generation = ThoughtNode('Solution Generation')
640
+ implementation = ThoughtNode('Implementation')
641
+ evaluation_adjustment = ThoughtNode('Evaluation and Adjustment')
642
+
643
+ root.add_child(problem_identification)
644
+ root.add_child(problem_analysis)
645
+ root.add_child(solution_generation)
646
+ root.add_child(implementation)
647
+ root.add_child(evaluation_adjustment)
648
+
649
+ # Problem Identification children
650
+ B1 = ThoughtNode('Define the Problem')
651
+ B2 = ThoughtNode('Identify Stakeholders')
652
+ B3 = ThoughtNode('Determine Constraints')
653
+ B4 = ThoughtNode('Recognize Problem Type')
654
+ B5 = ThoughtNode('Historical Context')
655
+ problem_identification.add_child(B1)
656
+ problem_identification.add_child(B2)
657
+ problem_identification.add_child(B3)
658
+ problem_identification.add_child(B4)
659
+ problem_identification.add_child(B5)
660
+
661
+ # Define the Problem children
662
+ B1a = ThoughtNode('Problem Statement Formulation')
663
+ B1b = ThoughtNode('Scope Definition')
664
+ B1c = ThoughtNode('Objective Setting')
665
+ B1.add_child(B1a)
666
+ B1.add_child(B1b)
667
+ B1.add_child(B1c)
668
+
669
+ # Identify Stakeholders children
670
+ B2a = ThoughtNode('Stakeholder Mapping')
671
+ B2b = ThoughtNode('Interest and Influence Analysis')
672
+ B2c = ThoughtNode('Engagement Strategy')
673
+ B2.add_child(B2a)
674
+ B2.add_child(B2b)
675
+ B2.add_child(B2c)
676
+
677
+ # Determine Constraints children
678
+ B3a = ThoughtNode('Resource Limitations')
679
+ B3b = ThoughtNode('Time Constraints')
680
+ B3c = ThoughtNode('Legal and Regulatory Constraints')
681
+ B3.add_child(B3a)
682
+ B3.add_child(B3b)
683
+ B3.add_child(B3c)
684
+
685
+ # Recognize Problem Type children
686
+ B4a = ThoughtNode('Simple vs Complex')
687
+ B4b = ThoughtNode('Known vs Unknown')
688
+ B4c = ThoughtNode('Tame vs Wicked Problems')
689
+ B4.add_child(B4a)
690
+ B4.add_child(B4b)
691
+ B4.add_child(B4c)
692
+
693
+ # Historical Context children
694
+ B5a = ThoughtNode('Previous Attempts')
695
+ B5b = ThoughtNode('Lessons Learned')
696
+ B5c = ThoughtNode('Environmental Factors')
697
+ B5.add_child(B5a)
698
+ B5.add_child(B5b)
699
+ B5.add_child(B5c)
700
+
701
+ # Problem Analysis children
702
+ C1 = ThoughtNode('Root Cause Analysis')
703
+ C2 = ThoughtNode('System Mapping')
704
+ C3 = ThoughtNode('Data Collection')
705
+ C4 = ThoughtNode('Impact Assessment')
706
+ C5 = ThoughtNode('Theoretical Framework')
707
+ problem_analysis.add_child(C1)
708
+ problem_analysis.add_child(C2)
709
+ problem_analysis.add_child(C3)
710
+ problem_analysis.add_child(C4)
711
+ problem_analysis.add_child(C5)
712
+
713
+ # Root Cause Analysis children
714
+ C1a = ThoughtNode('5 Whys Technique')
715
+ C1b = ThoughtNode('Fishbone Diagram')
716
+ C1c = ThoughtNode('Pareto Analysis')
717
+ C1.add_child(C1a)
718
+ C1.add_child(C1b)
719
+ C1.add_child(C1c)
720
+
721
+ # System Mapping children
722
+ C2a = ThoughtNode('Causal Loop Diagrams')
723
+ C2b = ThoughtNode('Stock and Flow Models')
724
+ C2c = ThoughtNode('Network Analysis')
725
+ C2.add_child(C2a)
726
+ C2.add_child(C2b)
727
+ C2.add_child(C2c)
728
+
729
+ # Data Collection children
730
+ C3a = ThoughtNode('Quantitative Data')
731
+ C3b = ThoughtNode('Qualitative Data')
732
+ C3c = ThoughtNode('Data Validation')
733
+ C3.add_child(C3a)
734
+ C3.add_child(C3b)
735
+ C3.add_child(C3c)
736
+
737
+ # Quantitative Data children
738
+ C3a1 = ThoughtNode('Surveys and Questionnaires')
739
+ C3a2 = ThoughtNode('Experimental Data')
740
+ C3a3 = ThoughtNode('Big Data Analytics')
741
+ C3a.add_child(C3a1)
742
+ C3a.add_child(C3a2)
743
+ C3a.add_child(C3a3)
744
+
745
+ # Qualitative Data children
746
+ C3b1 = ThoughtNode('Interviews')
747
+ C3b2 = ThoughtNode('Focus Groups')
748
+ C3b3 = ThoughtNode('Observational Studies')
749
+ C3b.add_child(C3b1)
750
+ C3b.add_child(C3b2)
751
+ C3b.add_child(C3b3)
752
+
753
+ # Data Validation children
754
+ C3c1 = ThoughtNode('Statistical Validation')
755
+ C3c2 = ThoughtNode('Cross-Validation')
756
+ C3c3 = ThoughtNode('Expert Review')
757
+ C3c.add_child(C3c1)
758
+ C3c.add_child(C3c2)
759
+ C3c.add_child(C3c3)
760
+
761
+ # Impact Assessment children
762
+ C4a = ThoughtNode('Environmental Impact')
763
+ C4b = ThoughtNode('Social Impact')
764
+ C4c = ThoughtNode('Economic Impact')
765
+ C4.add_child(C4a)
766
+ C4.add_child(C4b)
767
+ C4.add_child(C4c)
768
+
769
+ # Theoretical Framework children
770
+ C5a = ThoughtNode('Literature Review')
771
+ C5b = ThoughtNode('Conceptual Modeling')
772
+ C5c = ThoughtNode('Hypothesis Formation')
773
+ C5.add_child(C5a)
774
+ C5.add_child(C5b)
775
+ C5.add_child(C5c)
776
+
777
+ # Solution Generation children
778
+ D1 = ThoughtNode('Creative Problem Solving')
779
+ D2 = ThoughtNode('Analytical Approach')
780
+ D3 = ThoughtNode('Mathematical Computation')
781
+ D4 = ThoughtNode('Decision Making')
782
+ solution_generation.add_child(D1)
783
+ solution_generation.add_child(D2)
784
+ solution_generation.add_child(D3)
785
+ solution_generation.add_child(D4)
786
+
787
+ # Action Planning, Resource Allocation, Change Management children (implementation phase)
788
+ E1 = ThoughtNode('Action Planning')
789
+ E2 = ThoughtNode('Resource Allocation')
790
+ E3 = ThoughtNode('Change Management')
791
+ implementation.add_child(E1)
792
+ implementation.add_child(E2)
793
+ implementation.add_child(E3)
794
+
795
+ # Verification, Performance Metrics, Feedback Loops, Continuous Improvement children (evaluation phase)
796
+ F1 = ThoughtNode('Verification')
797
+ F2 = ThoughtNode('Performance Metrics')
798
+ F3 = ThoughtNode('Feedback Loops')
799
+ F4 = ThoughtNode('Continuous Improvement')
800
+ evaluation_adjustment.add_child(F1)
801
+ evaluation_adjustment.add_child(F2)
802
+ evaluation_adjustment.add_child(F3)
803
+ evaluation_adjustment.add_child(F4)
804
+
805
+ # Cross-Cutting Considerations node
806
+ G = ThoughtNode('Cross-Cutting Considerations')
807
+ root.add_child(G)
808
+
809
+ # Cross-Cutting Considerations children
810
+ G1 = ThoughtNode('Ethical Framework')
811
+ G2 = ThoughtNode('Stakeholder Management')
812
+ G3 = ThoughtNode('Interdisciplinary Connections')
813
+ G4 = ThoughtNode('Technological Integration')
814
+ G5 = ThoughtNode('Emotional Intelligence')
815
+ G6 = ThoughtNode('Collaborative Problem Solving')
816
+ G7 = ThoughtNode('Computational Considerations') # Assuming H was intended as G7
817
+ G8 = ThoughtNode('Order of Operations') # Assuming I was intended as G8
818
+ G9 = ThoughtNode('Critical Thinking') # Assuming J was intended as G9
819
+ G10 = ThoughtNode('Future Perspective') # Assuming K was intended as G10
820
+ G11 = ThoughtNode('Learning and Adaptation') # Assuming L was intended as G11
821
+ G.add_child(G1)
822
+ G.add_child(G2)
823
+ G.add_child(G3)
824
+ G.add_child(G4)
825
+ G.add_child(G5)
826
+ G.add_child(G6)
827
+ G.add_child(G7)
828
+ G.add_child(G8)
829
+ G.add_child(G9)
830
+ G.add_child(G10)
831
+ G.add_child(G11)
832
+
833
+ # Ethical Framework children
834
+ G1a = ThoughtNode('Value-based Decision Making')
835
+ G1b = ThoughtNode('Long-term Consequences')
836
+ G1.add_child(G1a)
837
+ G1.add_child(G1b)
838
+
839
+ # Value-based Decision Making children
840
+ G1a1 = ThoughtNode('Ethical Theories Application')
841
+ G1a2 = ThoughtNode('Moral Dilemma Resolution')
842
+ G1a.add_child(G1a1)
843
+ G1a.add_child(G1a2)
844
+
845
+ # Long-term Consequences children
846
+ G1b1 = ThoughtNode('Sustainability Assessment')
847
+ G1b2 = ThoughtNode('Intergenerational Impact')
848
+ G1b.add_child(G1b1)
849
+ G1b.add_child(G1b2)
850
+
851
+ # Stakeholder Management children
852
+ G2a = ThoughtNode('Direct Stakeholders')
853
+ G2b = ThoughtNode('Indirect Stakeholders')
854
+ G2c = ThoughtNode('Conflicting Interests')
855
+ G2.add_child(G2a)
856
+ G2.add_child(G2b)
857
+ G2.add_child(G2c)
858
+
859
+ # Conflicting Interests children
860
+ G2c1 = ThoughtNode('Negotiation Strategies')
861
+ G2c2 = ThoughtNode('Conflict Resolution Techniques')
862
+ G2c.add_child(G2c1)
863
+ G2c.add_child(G2c2)
864
+
865
+ # Interdisciplinary Connections children
866
+ G3a = ThoughtNode('Related Fields')
867
+ G3b = ThoughtNode('Cross-disciplinary Impact')
868
+ G3.add_child(G3a)
869
+ G3.add_child(G3b)
870
+
871
+ # Related Fields children
872
+ G3a1 = ThoughtNode('Cross-domain Knowledge Transfer')
873
+ G3a2 = ThoughtNode('Interdisciplinary Collaboration')
874
+ G3a.add_child(G3a1)
875
+ G3a.add_child(G3a2)
876
+
877
+ # Cross-disciplinary Impact children
878
+ G3b1 = ThoughtNode('Synergy Identification')
879
+ G3b2 = ThoughtNode('Holistic Impact Assessment')
880
+ G3b.add_child(G3b1)
881
+ G3b.add_child(G3b2)
882
+
883
+ # Technological Integration children
884
+ G4a = ThoughtNode('AI-assisted Problem Solving')
885
+ G4b = ThoughtNode('Data-driven Insights')
886
+ G4c = ThoughtNode('Digital Collaboration Tools')
887
+ G4.add_child(G4a)
888
+ G4.add_child(G4b)
889
+ G4.add_child(G4c)
890
+
891
+ # AI-assisted Problem Solving children
892
+ G4a1 = ThoughtNode('Machine Learning Models')
893
+ G4a2 = ThoughtNode('Natural Language Processing')
894
+ G4a.add_child(G4a1)
895
+ G4a.add_child(G4a2)
896
+
897
+ # Data-driven Insights children
898
+ G4b1 = ThoughtNode('Big Data Analytics')
899
+ G4b2 = ThoughtNode('Predictive Modeling')
900
+ G4b.add_child(G4b1)
901
+ G4b.add_child(G4b2)
902
+
903
+ # Digital Collaboration Tools children
904
+ G4c1 = ThoughtNode('Project Management Platforms')
905
+ G4c2 = ThoughtNode('Virtual Reality Collaboration')
906
+ G4c.add_child(G4c1)
907
+ G4c.add_child(G4c2)
908
+
909
+ # Emotional Intelligence children
910
+ G5a = ThoughtNode('Self-Awareness')
911
+ G5b = ThoughtNode('Empathy')
912
+ G5c = ThoughtNode('Stress Management')
913
+ G5.add_child(G5a)
914
+ G5.add_child(G5b)
915
+ G5.add_child(G5c)
916
+
917
+ # Self-Awareness children
918
+ G5a1 = ThoughtNode('Emotional Recognition')
919
+ G5a2 = ThoughtNode('Personal Bias Identification')
920
+ G5a.add_child(G5a1)
921
+ G5a.add_child(G5a2)
922
+
923
+ # Empathy children
924
+ G5b1 = ThoughtNode('Perspective Taking')
925
+ G5b2 = ThoughtNode('Active Listening')
926
+ G5b.add_child(G5b1)
927
+ G5b.add_child(G5b2)
928
+
929
+ # Stress Management children
930
+ G5c1 = ThoughtNode('Mindfulness Techniques')
931
+ G5c2 = ThoughtNode('Resilience Building')
932
+ G5c.add_child(G5c1)
933
+ G5c.add_child(G5c2)
934
+
935
+ # Collaborative Problem Solving children
936
+ G6a = ThoughtNode('Team Dynamics')
937
+ G6b = ThoughtNode('Communication Strategies')
938
+ G6c = ThoughtNode('Conflict Resolution')
939
+ G6.add_child(G6a)
940
+ G6.add_child(G6b)
941
+ G6.add_child(G6c)
942
+
943
+ # Team Dynamics children
944
+ G6a1 = ThoughtNode('Team Formation Strategies')
945
+ G6a2 = ThoughtNode('Role Assignment')
946
+ G6a.add_child(G6a1)
947
+ G6a.add_child(G6a2)
948
+
949
+ # Communication Strategies children
950
+ G6b1 = ThoughtNode('Clear Messaging')
951
+ G6b2 = ThoughtNode('Feedback Mechanisms')
952
+ G6b.add_child(G6b1)
953
+ G6b.add_child(G6b2)
954
+
955
+ # Conflict Resolution children
956
+ G6c1 = ThoughtNode('Mediation Techniques')
957
+ G6c2 = ThoughtNode('Consensus Building')
958
+ G6c.add_child(G6c1)
959
+ G6c.add_child(G6c2)
960
+
961
+ # Computational Considerations children
962
+ G7a = ThoughtNode('CPU Operations')
963
+ G7b = ThoughtNode('GPU Parallelization')
964
+ G7c = ThoughtNode('Floating-Point Precision')
965
+ G7.add_child(G7a)
966
+ G7.add_child(G7b)
967
+ G7.add_child(G7c)
968
+
969
+ # CPU Operations children
970
+ G7a1 = ThoughtNode('Instruction Set Architecture')
971
+ G7a2 = ThoughtNode('Pipelining and Parallelism')
972
+ G7a.add_child(G7a1)
973
+ G7a.add_child(G7a2)
974
+
975
+ # GPU Parallelization children
976
+ G7b1 = ThoughtNode('CUDA Programming')
977
+ G7b2 = ThoughtNode('OpenCL Framework')
978
+ G7b.add_child(G7b1)
979
+ G7b.add_child(G7b2)
980
+
981
+ # Floating-Point Precision children
982
+ G7c1 = ThoughtNode('IEEE 754 Standard')
983
+ G7c2 = ThoughtNode('Error Propagation Analysis')
984
+ G7c.add_child(G7c1)
985
+ G7c.add_child(G7c2)
986
+
987
+ # Order of Operations children
988
+ G8a = ThoughtNode('Parentheses')
989
+ G8b = ThoughtNode('Exponents')
990
+ G8c = ThoughtNode('Multiplication and Division')
991
+ G8d = ThoughtNode('Addition and Subtraction')
992
+ G8.add_child(G8a)
993
+ G8.add_child(G8b)
994
+ G8.add_child(G8c)
995
+ G8.add_child(G8d)
996
+
997
+ # Critical Thinking children
998
+ G9a = ThoughtNode('Assumptions Questioning')
999
+ G9b = ThoughtNode('Bias Recognition')
1000
+ G9.add_child(G9a)
1001
+ G9.add_child(G9b)
1002
+
1003
+ # Assumptions Questioning children
1004
+ G9a1 = ThoughtNode('Socratic Questioning')
1005
+ G9a2 = ThoughtNode('Devil\'s Advocate Approach')
1006
+ G9a.add_child(G9a1)
1007
+ G9a.add_child(G9a2)
1008
+
1009
+ # Bias Recognition children
1010
+ G9b1 = ThoughtNode('Cognitive Bias Identification')
1011
+ G9b2 = ThoughtNode('Debiasing Techniques')
1012
+ G9b.add_child(G9b1)
1013
+ G9b.add_child(G9b2)
1014
+
1015
+ # Future Perspective children
1016
+ G10a = ThoughtNode('Short-term Projections')
1017
+ G10b = ThoughtNode('Long-term Scenarios')
1018
+ G10c = ThoughtNode('Potential Impacts')
1019
+ G10.add_child(G10a)
1020
+ G10.add_child(G10b)
1021
+ G10.add_child(G10c)
1022
+
1023
+ # Short-term Projections children
1024
+ G10a1 = ThoughtNode('Trend Analysis')
1025
+ G10a2 = ThoughtNode('Scenario Planning')
1026
+ G10a.add_child(G10a1)
1027
+ G10a.add_child(G10a2)
1028
+
1029
+ # Long-term Scenarios children
1030
+ G10b1 = ThoughtNode('Futures Wheel')
1031
+ G10b2 = ThoughtNode('Backcasting')
1032
+ G10b.add_child(G10b1)
1033
+ G10b.add_child(G10b2)
1034
+
1035
+ # Potential Impacts children
1036
+ G10c1 = ThoughtNode('Risk Assessment')
1037
+ G10c2 = ThoughtNode('Opportunity Identification')
1038
+ G10c.add_child(G10c1)
1039
+ G10c.add_child(G10c2)
1040
+
1041
+ # Learning and Adaptation children
1042
+ G11a = ThoughtNode('Reflective Practice')
1043
+ G11b = ThoughtNode('Knowledge Transfer')
1044
+ G11c = ThoughtNode('Adaptive Problem Solving')
1045
+ G11.add_child(G11a)
1046
+ G11.add_child(G11b)
1047
+ G11.add_child(G11c)
1048
+
1049
+ # Reflective Practice children
1050
+ G11a1 = ThoughtNode('After Action Review')
1051
+ G11a2 = ThoughtNode('Learning Journals')
1052
+ G11a.add_child(G11a1)
1053
+ G11a.add_child(G11a2)
1054
+
1055
+ # Knowledge Transfer children
1056
+ G11b1 = ThoughtNode('Best Practice Documentation')
1057
+ G11b2 = ThoughtNode('Mentoring Programs')
1058
+ G11b.add_child(G11b1)
1059
+ G11b.add_child(G11b2)
1060
+
1061
+ # Adaptive Problem Solving children
1062
+ G11c1 = ThoughtNode('Iterative Approaches')
1063
+ G11c2 = ThoughtNode('Flexibility in Methodology')
1064
+ G11c.add_child(G11c1)
1065
+ G11c.add_child(G11c2)
1066
+
1067
+ return root
1068
+
1069
+ def traverse_tree(node, action_list):
1070
+ if node.name not in action_list:
1071
+ action_list.append(node.name)
1072
+ for child in node.children:
1073
+ traverse_tree(child, action_list)
1074
+
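+ # Usage sketch: flatten the tree into the global action vocabulary. The
+ # action_to_index / index_to_action maps referenced by MCTS and State below are
+ # assumed to be built this way elsewhere in this file:
+ #   root_thought_node = build_tree_of_thought()
+ #   action_list = []
+ #   traverse_tree(root_thought_node, action_list)
+ #   action_to_index = {name: idx for idx, name in enumerate(action_list)}
+ #   index_to_action = {idx: name for idx, name in enumerate(action_list)}
+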
1075
+ class MCTSNode:
1076
+ __slots__ = [
1077
+ 'state',
1078
+ 'parent',
1079
+ 'action',
1080
+ 'children',
1081
+ 'visit_count',
1082
+ 'value_sum',
1083
+ 'prior',
1084
+ 'cached_policy',
1085
+ 'cached_value',
1086
+ 'thought_node' # Added to keep track of the current thought node
1087
+ ]
1088
+
1089
+ def __init__(self, state, thought_node, parent=None, action=None):
1090
+ self.state = state
1091
+ self.thought_node = thought_node # Reference to the ThoughtNode
1092
+ self.parent = parent
1093
+ self.action = action
1094
+ self.children = {}
1095
+ self.visit_count = 0
1096
+ self.value_sum = 0.0
1097
+ self.prior = 0.0
1098
+ self.cached_policy = None
1099
+ self.cached_value = None
1100
+
1101
+ def expand(self, priors):
1102
+ """
1103
+ Expand the node by adding all valid child nodes from the thought tree.
1104
+
1105
+ Args:
1106
+ priors (dict): A dictionary mapping action names to prior probabilities.
1107
+ """
1108
+ for child_thought_node in self.thought_node.children:
1109
+ action = child_thought_node.name # Action name
1110
+ if action not in self.children:
1111
+ # Assume batch size of 1 for individual nodes
1112
+ child_state = self.state.apply_action(action)
1113
+ child_node = MCTSNode(
1114
+ state=child_state,
1115
+ thought_node=child_thought_node,
1116
+ parent=self,
1117
+ action=action
1118
+ )
1119
+ child_node.prior = priors.get(action, 1.0 / len(self.thought_node.children)) # Default prior if not provided
1120
+ self.children[action] = child_node
1121
+
1122
+ def is_leaf(self):
1123
+ return len(self.children) == 0
1124
+
1125
+ def ucb_score(self, total_visits, exploration_constant=math.sqrt(2)):
1126
+ if self.visit_count == 0:
1127
+ return float('inf')
1128
+ avg_value = self.value_sum / self.visit_count
1129
+ exploration_term = exploration_constant * self.prior * math.sqrt(total_visits) / (1 + self.visit_count)
1130
+ return avg_value + exploration_term
1131
+
1132
+ class MCTS:
1133
+ def __init__(self, prediction_network, dynamics_network, action_encoder, num_iterations=10, exploration_constant=math.sqrt(2)):
1134
+ self.prediction_network = prediction_network
1135
+ self.dynamics_network = dynamics_network
1136
+ self.action_encoder = action_encoder
1137
+ self.num_iterations = num_iterations
1138
+ self.exploration_constant = exploration_constant
1139
+ self.cache = {}
1140
+
1141
+ def search(self, root_state):
1142
+ """
1143
+ Perform MCTS starting from the root state.
1144
+
1145
+ Args:
1146
+ root_state (State): The root state from which to start the search.
1147
+
1148
+ Returns:
1149
+ str: The best action to take from the root state.
1150
+ """
1151
+ root_node = MCTSNode(state=root_state, thought_node=root_state.thought_node)
1152
+
1153
+ for _ in range(self.num_iterations):
1154
+ node = self.select(root_node)
1155
+ value = self.evaluate(node)
1156
+ self.backpropagate(node, value)
1157
+
1158
+ best_action = self.best_action(root_node)
1159
+ return best_action
1160
+
1161
+ def select(self, node):
1162
+ while not node.is_leaf():
1163
+ total_visits = sum(child.visit_count for child in node.children.values())
1164
+ _, node = max(
1165
+ node.children.items(),
1166
+ key=lambda item: item[1].ucb_score(total_visits, self.exploration_constant)
1167
+ )
1168
+ return node
1169
+
1170
+ def evaluate(self, node):
1171
+ # Use the prediction network to get policy and value estimates
1172
+ state_representation = node.state.representation # Shape: (batch_size=1, seq_len, state_dim)
1173
+ policy_logits, value_estimate = self.prediction_network(state_representation)
1174
+ value_estimate = value_estimate[:, -1].item() # Use the last time step's value; .item() on the full (1, seq_len) tensor would raise
1175
+
1176
+ # Convert policy logits to probabilities
1177
+ policy_probs = F.softmax(policy_logits, dim=-1).squeeze(0) # Shape: (seq_len, action_vocab_size)
1178
+ # For simplicity, use the last time step's policy
1179
+ policy_probs = policy_probs[-1] # Shape: (action_vocab_size,)
1180
+
1181
+ # Map policy probabilities to the actions available from the current thought node
1182
+ priors = {}
1183
+ for child in node.thought_node.children:
1184
+ action_name = child.name
1185
+ action_idx = action_to_index.get(action_name, None)
1186
+ if action_idx is not None and action_idx < policy_probs.size(0):
1187
+ priors[action_name] = policy_probs[action_idx].item()
1188
+ else:
1189
+ priors[action_name] = 1.0 / len(node.thought_node.children) # Uniform prior if not found
1190
+
1191
+ # Expand the node
1192
+ node.expand(priors)
1193
+
1194
+ return value_estimate
1195
+
1196
+ def backpropagate(self, node, value):
1197
+ while node is not None:
1198
+ node.visit_count += 1
1199
+ node.value_sum += value
1200
+ node = node.parent
1201
+
1202
+ def best_action(self, root_node):
1203
+ # Select the child with the highest visit count
1204
+ best_child = max(root_node.children.values(), key=lambda n: n.visit_count)
1205
+ return best_child.action
1206
+
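+ # Search sketch (see infer() below for the full loop): one call returns the single
+ # best next thought from the current State:
+ #   mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=10)
+ #   best_action = mcts.search(current_state)
+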
1207
+ class State:
1208
+ def __init__(self, representation, dynamics_network, action_encoder, thought_node):
1209
+ """
1210
+ Args:
1211
+ representation (torch.Tensor): Encoded state representation, shape (batch_size, seq_len, state_dim)
1212
+ dynamics_network (nn.Module): The Dynamics Network to predict next states
1213
+ action_encoder (nn.Module): The Action Encoder to encode actions
1214
+ thought_node (ThoughtNode): The current node in the Tree of Thought
1215
+ """
1216
+ self.representation = representation # Shape: (batch_size, seq_len, state_dim)
1217
+ self.dynamics_network = dynamics_network
1218
+ self.action_encoder = action_encoder
1219
+ self.thought_node = thought_node # Current position in the Tree of Thought
1220
+
1221
+ def apply_action(self, action):
1222
+ """
1223
+ Apply an action to the current state to get a new state.
1224
+
1225
+ Args:
1226
+ action (str): The action to apply (the name of the ThoughtNode)
1227
+
1228
+ Returns:
1229
+ State: The new state after applying the action
1230
+ """
1231
+ # Find the corresponding child node in the thought tree
1232
+ next_thought_node = None
1233
+ for child in self.thought_node.children:
1234
+ if child.name == action:
1235
+ next_thought_node = child
1236
+ break
1237
+ if next_thought_node is None:
1238
+ raise ValueError(f"Action '{action}' is not valid from the current thought node.")
1239
+
1240
+ # Encode action
1241
+ action_index = torch.tensor([[action_to_index[action]]], device=self.representation.device)
1242
+ action_embedding = self.action_encoder(action_index)
1243
+
1244
+ # Predict the next state using the Dynamics Network
1245
+ next_state_representation = self.dynamics_network(self.representation, action_embedding)
1246
+
1247
+ return State(
1248
+ representation=next_state_representation,
1249
+ dynamics_network=self.dynamics_network,
1250
+ action_encoder=self.action_encoder,
1251
+ thought_node=next_thought_node
1252
+ )
1253
+
1254
+ class PPOAgent:
1255
+ def __init__(self, policy_network, optimizer, clip_epsilon=0.2, entropy_coef=0.01, value_coef=0.5):
1256
+ self.policy_network = policy_network
1257
+ self.optimizer = optimizer
1258
+ self.clip_epsilon = clip_epsilon
1259
+ self.entropy_coef = entropy_coef
1260
+ self.value_coef = value_coef
1261
+
1262
+ def compute_loss(self, states, old_log_probs, actions, returns, advantages):
1263
+ # Get policy logits and value estimates
1264
+ policy_logits, value_estimates = self.policy_network(states)
1265
+ batch_size, seq_len, num_actions = policy_logits.size()
1266
+
1267
+ # Flatten tensors using reshape
1268
+ policy_logits = policy_logits.reshape(-1, num_actions) # Shape: (batch_size * seq_len, num_actions)
1269
+ value_estimates = value_estimates.view(-1)
1270
+ actions = actions.reshape(-1) # Shape: (batch_size * seq_len)
1271
+ old_log_probs = old_log_probs.reshape(-1) # Shape: (batch_size * seq_len)
1272
+ returns = returns.view(-1)
1273
+ advantages = advantages.reshape(-1) # Shape: (batch_size * seq_len)
1274
+
1275
+ # Ensure value_estimates and returns are the same size
1276
+ if value_estimates.size() != returns.size():
1277
+ print(f"Shape mismatch: value_estimates shape: {value_estimates.size()}, returns shape: {returns.size()}")
1278
+ value_estimates = value_estimates[:returns.size(0)]
1279
+
1280
+ # Compute new log probabilities
1281
+ new_log_probs_all = F.log_softmax(policy_logits, dim=-1) # Shape: (batch_size * seq_len, num_actions)
1282
+ new_log_probs = new_log_probs_all.gather(1, actions.unsqueeze(-1)).squeeze(-1) # Shape: (batch_size * seq_len)
1283
+
1284
+ # Compute ratios
1285
+ ratios = torch.exp(new_log_probs - old_log_probs)
1286
+
1287
+ # PPO surrogate loss
1288
+ surr1 = ratios * advantages
1289
+ surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
1290
+ policy_loss = -torch.min(surr1, surr2).mean()
1291
+
1292
+ # Value loss
1293
+ value_loss = F.mse_loss(value_estimates, returns)
1294
+
1295
+ # Entropy loss
1296
+ entropy = -(new_log_probs_all.exp() * new_log_probs_all).sum(dim=-1).mean() # Entropy of the full action distribution, not only the chosen actions
1297
+
1298
+ # Total loss
1299
+ total_loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
1300
+ return total_loss
1301
+
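+ # The clipped surrogate above follows the standard PPO objective,
+ #   L = E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ],  r_t = exp(log pi_new - log pi_old),
+ # combined with an MSE value loss and an entropy bonus.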
1302
+
1303
+
1304
+ def infer(query, world_model_components, root_thought_node, tokenizer, max_length=20, inference_mode='world_model'):
1305
+ """
1306
+ Perform inference given a query, utilizing the Tree of Thought and MCTS.
1307
+
1308
+ Args:
1309
+ query (str): The input query or prompt.
1310
+ world_model_components (tuple): Tuple containing the model components.
1311
+ root_thought_node (ThoughtNode): The root node of the Tree of Thought.
1312
+ tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
1313
+ max_length (int): Maximum length for the generated sequence.
1314
+ inference_mode (str): Inference mode ('world_model', 'without_world_model', 'world_model_tree_of_thought')
1315
+
1316
+ Returns:
1317
+ List[str] or str: The sequence of actions (thoughts) selected or generated text.
1318
+ """
1319
+ representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
1320
+
1321
+ # Tokenize and encode the query
1322
+ input_ids = tokenizer.encode(query, return_tensors='pt').to(device)
1323
+ attention_mask = (input_ids != tokenizer.pad_token_id).long() # Computed for completeness; the custom Transformer forward does not consume it
1324
+
1325
+ if inference_mode == 'without_world_model':
1326
+ # Directly use the transformer model to generate text
1327
+ with torch.no_grad():
1328
+ generated_ids, entropies, variances = model_transformer.generate(src=input_ids, tokenizer=tokenizer, max_length=max_length, temperature=args.temperature)
1329
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
1330
+ return generated_text
1331
+
1332
+ else:
1333
+ # Use the world model components
1334
+ with torch.no_grad():
1335
+ transformer_output = model_transformer(input_ids, input_ids)
1336
+ # Get the initial state representation
1337
+ initial_representation = representation_network(transformer_output) # Shape: (batch_size=1, seq_len, state_dim)
1338
+ initial_state = State(
1339
+ representation=initial_representation,
1340
+ dynamics_network=dynamics_network,
1341
+ action_encoder=action_encoder,
1342
+ thought_node=root_thought_node
1343
+ )
1344
+
1345
+ if inference_mode == 'world_model_tree_of_thought':
1346
+ # Use MCTS with Tree of Thought
1347
+ mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=args.mcts_iterations, exploration_constant=args.mcts_exploration_constant)
1348
+ current_state = initial_state
1349
+ thought_sequence = []
1350
+
1351
+ for _ in range(max_length):
1352
+ best_action = mcts.search(current_state)
1353
+ thought_sequence.append(best_action)
1354
+
1355
+ # Apply the best action to get the next state
1356
+ current_state = current_state.apply_action(best_action)
1357
+
1358
+ # Check if we've reached a leaf node (no further actions)
1359
+ if len(current_state.thought_node.children) == 0:
1360
+ break
1361
+
1362
+ return thought_sequence
1363
+ else:
1364
+ # Use the world model without Tree of Thought
1365
+ # For simplicity, we will generate actions based on the prediction network
1366
+ policy_logits, _ = prediction_network(initial_state.representation)
1367
+ policy_probs = F.softmax(policy_logits, dim=-1)
1368
+ # Select actions with highest probabilities
1369
+ top_actions = torch.argmax(policy_probs, dim=-1)
1370
+ generated_actions = [index_to_action[idx.item()] for idx in top_actions[0]]
1371
+ return generated_actions
1372
+
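+ # Inference sketch (hypothetical query; the components tuple is assembled in main()):
+ #   thoughts = infer("How should we allocate limited resources?", world_model_components,
+ #                    root_thought_node, tokenizer, max_length=10,
+ #                    inference_mode='world_model_tree_of_thought')
+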
1373
+ def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
1374
+ representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, _ = world_model_components
1375
+ representation_network.train()
1376
+ dynamics_network.train()
1377
+ prediction_network.train()
1378
+ action_encoder.train()
1379
+ ppo_agent.policy_network.train()
1380
+
1381
+ total_loss = 0.0
1382
+ optimizer.zero_grad()
1383
+ print(f"Starting World Model training epoch with {len(train_loader)}batches...")
1384
+
1385
+ for i, batch in enumerate(train_loader):
1386
+ print(f"Processing batch {i+1}/{len(train_loader)}...")
1387
+
1388
+ # Move batches to the device
1389
+ src_batch = batch['input_ids'].to(device)
1390
+ tgt_batch = batch['labels'].to(device)
1391
+
1392
+ with torch.amp.autocast(device_type='cuda'):
1393
+ print("Forward pass through Transformer (frozen)...")
1394
+ with torch.no_grad():
1395
+ transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])
1396
+
1397
+ # World Model - Representation
1398
+ state_representation = representation_network(transformer_output) # On GPU
1399
+
1400
+ # For simplicity, let's assume true actions are provided (e.g., next tokens)
1401
+ true_actions = tgt_batch[:, :-1] # Shape: (batch_size, seq_len)
1402
+ action_sequences = true_actions
1403
+
1404
+ # Get action embeddings
1405
+ action_embeddings = action_encoder(action_sequences) # Shape: (batch_size, seq_len, embed_dim)
1406
+
1407
+ # Apply dynamics network
1408
+ predicted_next_state_batch = dynamics_network(state_representation, action_embeddings) # Shape: (batch_size, seq_len, state_dim)
1409
+
1410
+ # Prediction Network - Policy logits and value
1411
+ policy_logits, value_estimates = prediction_network(predicted_next_state_batch)
1412
+ # value_estimates now has shape (batch_size, seq_len)
1413
+
1414
+ # Define true_policy and true_value as placeholders on the GPU
1415
+ true_policy = F.one_hot(true_actions, num_classes=input_dim).float() # Shape: (batch_size, seq_len, input_dim)
1416
+ true_value = torch.zeros_like(value_estimates).to(device)
1417
+
1418
+ # Compute PPO loss
1419
+ actions_selected = true_actions # Shape: (batch_size, seq_len)
1420
+ old_log_probs = torch.zeros_like(actions_selected, dtype=torch.float32).to(device)
1421
+ returns = torch.zeros_like(actions_selected, dtype=torch.float32).to(device)
1422
+ advantages = torch.zeros_like(actions_selected, dtype=torch.float32).to(device)
1423
+
1424
+ # Compute PPO loss using states
1425
+ ppo_loss = ppo_agent.compute_loss(state_representation, old_log_probs, actions_selected, returns, advantages)
1426
+
1427
+ # Compute InfoNCE Loss
1428
+ z_i = state_representation.view(-1, state_dim) # Shape: (batch_size * seq_len, state_dim)
1429
+ z_j = F.dropout(z_i, p=0.1, training=True)
1430
+ info_nce = InfoNCE_Loss()(z_i, z_j)
1431
+
1432
+ # Compute other losses
1433
+ covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
1434
+ dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)
1435
+ perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
1436
+ thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)
1437
+ pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
1438
+ action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))
1439
+ mcts_best_values = torch.zeros(actions_selected.size(0)).to(device)
1440
+ etv = ExpectedThoughtValueLoss()(mcts_best_values)
1441
+ visit_counts = torch.ones(actions_selected.size(0), policy_logits.size(-1)).to(device)
1442
+ exploration = ExplorationRegularization()(visit_counts)
1443
+ old_policy = F.softmax(policy_logits.detach(), dim=-1)
1444
+ new_policy = F.softmax(policy_logits, dim=-1)
1445
+ kl_loss = KL_DivergenceLoss()(old_policy, new_policy)
1446
+
1447
+ # Total Loss
1448
+ loss = (
1449
+ ppo_loss +
1450
+ info_nce +
1451
+ covariance +
1452
+ dynamics_loss +
1453
+ thought_loss +
1454
+ pv_loss +
1455
+ action_diversity +
1456
+ etv +
1457
+ exploration +
1458
+ kl_loss
1459
+ )
1460
+ loss = loss / args.accumulation_steps
1461
+
1462
+ print("Backward pass...")
1463
+ scaler.scale(loss).backward()
1464
+
1465
+ if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
1466
+ print("Gradient clipping...")
1467
+ scaler.unscale_(optimizer)
1468
+ torch.nn.utils.clip_grad_norm_(
1469
+ [param for group in optimizer.param_groups for param in group['params']],
1470
+ args.max_grad_norm
1471
+ )
1472
+
1473
+ print("Optimizer step...")
1474
+ scaler.step(optimizer)
1475
+ scaler.update()
1476
+
1477
+ print("Zeroing gradients...")
1478
+ optimizer.zero_grad()
1479
+
1480
+ print("Updating learning rate...")
1481
+ scheduler.step()
1482
+
1483
+ total_loss += loss.item() * args.accumulation_steps
1484
+ print(f"Batch {i+1} completed. Current loss: {loss.item():.4f}")
1485
+
1486
+ avg_loss = total_loss / len(train_loader)
1487
+ print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
1488
+ return avg_loss
1489
+
1490
+
1491
+ def train_epoch_language_model(model, train_loader, optimizer, scheduler, scaler, args):
1492
+ model.train()
1493
+ total_loss = 0.0
1494
+ optimizer.zero_grad()
1495
+ print(f"Starting Language Model training epoch with {len(train_loader)} batches...")
1496
+
1497
+ for i, batch in enumerate(train_loader):
1498
+ input_ids = batch['input_ids'].to(device)
1499
+ labels = batch['labels'].to(device)
1500
+
1501
+ with autocast():
1502
+ outputs = model(input_ids, input_ids)
1503
+ logits = outputs.view(-1, outputs.size(-1))
1504
+ labels = labels.view(-1)
1505
+ loss = F.cross_entropy(logits, labels, ignore_index=model.embedding.padding_idx)
1506
+ loss = loss / args.accumulation_steps
1507
+
1508
+ scaler.scale(loss).backward()
1509
+
1510
+ if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
1511
+ scaler.unscale_(optimizer)
1512
+ torch.nn.utils.clip_grad_norm_(
1513
+ [param for group in optimizer.param_groups for param in group['params']],
1514
+ args.max_grad_norm
1515
+ )
1516
+ scaler.step(optimizer)
1517
+ scaler.update()
1518
+ optimizer.zero_grad()
1519
+ scheduler.step()
1520
+
1521
+ total_loss += loss.item() * args.accumulation_steps
1522
+ print(f"Batch {i + 1} completed. Current loss: {loss.item():.4f}")
1523
+
1524
+ avg_loss = total_loss / len(train_loader)
1525
+ print(f"Language Model training epoch completed. Average loss: {avg_loss:.4f}")
1526
+ return avg_loss
1527
+
1528
+
1529
+
1530
+ def main():
1531
+ args = parse_args()
1532
+ print("Arguments parsed successfully.")
1533
+
1534
+ # Create save directory
1535
+ os.makedirs(args.save_dir, exist_ok=True)
1536
+ print(f"Save directory created: {args.save_dir}")
1537
+
1538
+ # Load tokenizer
1539
+ print("Loading tokenizer...")
1540
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
1541
+ if tokenizer.pad_token is None:
1542
+ tokenizer.pad_token = tokenizer.eos_token
1543
+ print("Tokenizer loaded successfully.")
1544
+
1545
+ # Define padding_idx and input dimension based on tokenizer
1546
+ padding_idx = tokenizer.pad_token_id
1547
+ input_dim = len(tokenizer)
1548
+
1549
+ # Initialize the Transformer model on GPU
1550
+ print("Initializing Transformer model...")
1551
+ model_transformer = Transformer(
1552
+ input_dim=input_dim,
1553
+ d_model=128,
1554
+ num_heads=4,
1555
+ num_layers=4,
1556
+ d_ff=256,
1557
+ num_experts=2,
1558
+ output_dim=input_dim,
1559
+ dropout=0.1,
1560
+ top_k=2
1561
+ ).to(device)
1562
+ model_transformer.train()
1563
+ print("Transformer model initialized on device.")
1564
+
1565
+ # Define model parameters (adjusted for speed)
1566
+ d_model = 128
1567
+ state_dim = 128
1568
+ action_dim = d_model
1569
+ hidden_dim = 256
1570
+ vocab_dim = input_dim
1571
+ embed_dim = d_model
1572
+
1573
+ # Define World Model components
1574
+ representation_network = RepresentationNetwork(vocab_dim, d_model, state_dim).to(device)
1575
+ dynamics_network = DynamicsNetwork(state_dim, action_dim, hidden_dim).to(device)
1576
+ prediction_network = PredictionNetwork(state_dim, input_dim, 1).to(device)
1577
+ action_encoder = ActionEncoder(input_dim, action_dim).to(device)
1578
+
1579
+ # Initialize PPO Agent
1580
+ ppo_agent = PPOAgent(
1581
+ policy_network=prediction_network,
1582
+ optimizer=optim.AdamW(prediction_network.parameters(), lr=args.learning_rate),
1583
+ clip_epsilon=0.2,
1584
+ entropy_coef=0.01,
1585
+ value_coef=0.5
1586
+ )
1587
+
1588
+ # Bundle World Model components
1589
+ world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)
1590
+
1591
+ if args.mode == 'train':
1592
+ print("Loading and preprocessing data...")
1593
+ train_loader, eval_loader = load_data(args, tokenizer)
1594
+ print("Data loaded and preprocessed successfully.")
1595
+
1596
+ # Optimizer and Scheduler
1597
+ optimizer = optim.AdamW(
1598
+ list(representation_network.parameters()) +
1599
+ list(dynamics_network.parameters()) +
1600
+ list(prediction_network.parameters()) +
1601
+ list(action_encoder.parameters()),
1602
+ lr=args.learning_rate, weight_decay=args.weight_decay
1603
+ ) if args.train_mode == 'world_model' else optim.AdamW(model_transformer.parameters(), lr=args.learning_rate)
1604
+ scheduler = CosineAnnealingLR(optimizer, T_max=args.num_epochs)
1605
+ scaler = GradScaler()
1606
+
1607
+ print(f"Starting {args.train_mode} training...")
1608
+
1609
+ for epoch in range(args.num_epochs):
1610
+ if args.train_mode == 'world_model':
1611
+ avg_loss = train_epoch_world_model(
1612
+ world_model_components,
1613
+ train_loader,
1614
+ optimizer,
1615
+ scheduler,
1616
+ scaler,
1617
+ args,
1618
+ model_transformer,
1619
+ state_dim,
1620
+ embed_dim,
1621
+ input_dim
1622
+ )
1623
+ else:
1624
+ avg_loss = train_epoch_language_model(
1625
+ model_transformer,
1626
+ train_loader,
1627
+ optimizer,
1628
+ scheduler,
1629
+ scaler,
1630
+ args
1631
+ )
1632
+
1633
+ print(f"{args.train_mode.capitalize()} training epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
1634
+
1635
+ if args.train_mode == 'world_model':
1636
+ save_all_models(model_transformer, representation_network, dynamics_network, prediction_network, action_encoder, args.save_dir, epoch + 1)
1637
+ print(f"Models saved for epoch {epoch + 1}")
1638
+ else:
1639
+ torch.save(model_transformer.state_dict(), os.path.join(args.save_dir, f'language_model_epoch_{epoch + 1}.pt'))
1640
+ print(f"Language model saved for epoch {epoch + 1}")
1641
+
1642
+ print("Training completed.")
1643
+
1644
+ elif args.mode == 'inference':
1645
+ # Build Tree of Thought if needed
1646
+ tree_root = build_tree_of_thought()
1647
+ # Generate action list
1648
+ action_list = []
1649
+ traverse_tree(tree_root, action_list)
1650
+
1651
+ # Create mappings
1652
+ global action_to_index, index_to_action
1653
+ action_to_index = {action: idx for idx, action in enumerate(action_list)}
1654
+ index_to_action = {idx: action for action, idx in action_to_index.items()}
1655
+ action_vocab_size = len(action_list)
1656
+
1657
+ # Update action encoder and prediction network with new vocab size
1658
+ action_encoder = ActionEncoder(action_vocab_size, action_dim).to(device)
1659
+ prediction_network = PredictionNetwork(state_dim, action_vocab_size, 1).to(device)
1660
+
1661
+ # Load the saved models
1662
+ # Assuming the models are saved after training
1663
+ # You need to adjust the paths and epoch numbers as necessary
1664
+ model_transformer.load_state_dict(torch.load(os.path.join(args.save_dir, 'transformer_model_epoch_2.pt')))
1665
+ representation_network.load_state_dict(torch.load(os.path.join(args.save_dir, 'representation_network_epoch_2.pt')))
1666
+ dynamics_network.load_state_dict(torch.load(os.path.join(args.save_dir, 'dynamics_network_epoch_2.pt')))
1667
+ saved_state_dict = torch.load(os.path.join(args.save_dir, 'prediction_network_epoch_2.pt'))
1668
+ prediction_network.policy_head = nn.Linear(prediction_network.state_dim, 50257) # Update to match saved model size
1669
+ prediction_network.load_state_dict(saved_state_dict, strict=False)
1670
+
1671
+ # Resize policy_head back to the ToT action vocabulary size after loading
1672
+ prediction_network.policy_head = nn.Linear(prediction_network.state_dim, action_vocab_size).to(device)
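+ # The checkpointed head was sized for the GPT-2 vocabulary (50257), so we load
+ # with a matching head and then swap in a freshly initialised head sized to the
+ # ToT action vocabulary; the new head is untrained until fine-tuned.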
1673
+
1674
+
1675
+
1676
+ action_encoder.load_state_dict(torch.load(os.path.join(args.save_dir, 'action_encoder_epoch_2.pt')))
1677
+
1678
+ # Prepare the components
1679
+ world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)
1680
+
1681
+ # Perform inference
1682
+ if not args.query:
1683
+ args.query = input("Please enter your query: ")
1684
+
1685
+ result = infer(args.query, world_model_components, tree_root, tokenizer, inference_mode=args.inference_mode)
1686
+
1687
+ if args.inference_mode == 'without_world_model':
1688
+ print("Generated Text:")
1689
+ print(result)
1690
+ else:
1691
+ print("Generated Thought Sequence:")
1692
+ for thought in result:
1693
+ print(thought)
1694
+
1695
+ if __name__ == '__main__':
1696
+ main()
main_menu.py ADDED
@@ -0,0 +1,61 @@
1
+ # main_menu.py
2
+
3
+ import argparse
4
+ import sys
5
+ from train_agent import train_agent
6
+ from test_agent import TestAgent, run_test_session
7
+ from lightbulb import main as world_model_main
8
+
9
+ def parse_main_args():
10
+ parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
11
+ parser.add_argument('--task', type=str, choices=['train_llm_world', 'train_agent', 'test_agent'],
12
+ required=True, help='Choose task to execute: train_llm_world, train_agent, test_agent')
13
+ # Optional arguments for more granular control
14
+ parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
15
+ parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
16
+ parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
17
+ parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
18
+ parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
19
+ parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
20
+ parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
21
+ parser.add_argument('--query', type=str, default='', help='Query for the test_agent')
22
+ return parser.parse_args()
23
+
24
+ def main():
25
+ # Parse arguments for the main function
26
+ args = parse_main_args()
27
+
28
+ # Execute tasks based on user input
29
+ if args.task == 'train_llm_world':
30
+ print("Starting LLM and World Model Training...")
31
+ # Directly call the world model main function
32
+ sys.argv = ['lightbulb.py', '--mode', args.mode, '--model_name', args.model_name,
33
+ '--dataset_name', args.dataset_name, '--dataset_config', args.dataset_config,
34
+ '--batch_size', str(args.batch_size), '--num_epochs', str(args.num_epochs),
35
+ '--max_length', str(args.max_length)]
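+ # (lightbulb's main() re-parses sys.argv via parse_args(), so overwriting it
+ # here forwards these options without changing lightbulb's interface)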
36
+ world_model_main()
37
+
38
+ elif args.task == 'train_agent':
39
+ print("Starting Agent Training...")
40
+ # Call the train_agent function from train_agent.py
41
+ from twisted.internet import reactor, task
42
+ d = task.deferLater(reactor, 0, train_agent)
43
+ d.addErrback(lambda failure: print(f"An error occurred: {failure}"))
44
+ d.addBoth(lambda _: reactor.stop() if reactor.running else None)
45
+ reactor.run()
46
+
47
+ elif args.task == 'test_agent':
48
+ print("Starting Test Agent...")
49
+ from twisted.internet import reactor
+ test_agent = TestAgent()
50
+ if args.query:
51
+ # Directly process a single query
52
+ result = test_agent.process_query(args.query)
53
+ print("\nAgent's response:")
54
+ print(result)
55
+ else:
56
+ # Run the interactive session
57
+ reactor.callWhenRunning(run_test_session)
58
+ reactor.run()
59
+
60
+ if __name__ == "__main__":
61
+ main()
mcts.py ADDED
@@ -0,0 +1,225 @@
1
+ # mcts.py
2
+ import math
3
+ import random
4
+ from nltk.corpus import wordnet
5
+ from scrapy.crawler import CrawlerRunner
6
+ from scrapy.utils.log import configure_logging
7
+ from scrapy.utils.project import get_project_settings
8
+ from twisted.internet import reactor, defer
9
+ from scrapy import signals
10
+ import logging
11
+ from my_search_engine.my_search_engine.spiders.search_spider import SearchSpider
12
+ from sentence_transformers import SentenceTransformer, util
13
+ from ranking import train_ranking_model
14
+ import time
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class MCTSNode:
19
+ def __init__(self, state, parent=None, action=None):
20
+ self.state = state
21
+ self.parent = parent
22
+ self.action = action
23
+ self.children = []
24
+ self.visits = 0
25
+ self.value = 0
26
+ self.ucb_score = float('inf')
27
+
28
+ def is_leaf(self):
29
+ return len(self.children) == 0
30
+
31
+ def add_child(self, child_state, action=None):
32
+ child_node = MCTSNode(child_state, parent=self, action=action)
33
+ self.children.append(child_node)
34
+ return child_node
35
+
36
+ def update(self, reward):
37
+ self.visits += 1
38
+ self.value += reward
39
+ if self.parent: # Only calculate UCB if not root
40
+ self.ucb_score = self.calculate_ucb()
41
+
42
+ def calculate_ucb(self, exploration_weight=1.41):
43
+ if self.visits == 0 or not self.parent:
44
+ return float('inf')
45
+ exploitation = self.value / self.visits
46
+ exploration = exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
47
+ return exploitation + exploration
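+ # This is the standard UCB1 rule:
+ #   UCB(child) = value/visits + c * sqrt(ln(parent.visits) / visits)
+ # with exploration weight c = 1.41 (~ sqrt(2)); unvisited nodes score +inf so
+ # they are always tried first.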
48
+
49
+ class MCTS:
50
+ def __init__(self, initial_state, num_simulations=20, exploration_weight=1.41):
51
+ self.root = MCTSNode(initial_state)
52
+ self.num_simulations = num_simulations
53
+ self.exploration_weight = exploration_weight
54
+ self.query_model = SentenceTransformer('all-MiniLM-L6-v2')
55
+ self.results = []
56
+ self.crawler_runner = CrawlerRunner(get_project_settings())
57
+ self.initial_state = initial_state
58
+ self.num_iterations = 5
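+ # NOTE: run() loops num_iterations times; num_simulations is stored but not
+ # used by run() in this implementation.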
59
+
60
+ def select(self, node):
61
+ while not node.is_leaf():
62
+ if not node.children:
63
+ return node
64
+ node = max(node.children, key=lambda c: c.calculate_ucb(self.exploration_weight))
65
+ return node
66
+
67
+ def expand(self, node):
68
+ if node.visits == 0:
69
+ return node
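+ # An unvisited leaf is simulated as-is; children are only generated once the
+ # node has been visited at least once (standard MCTS expansion policy).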
70
+ possible_refinements = self.get_possible_refinements(node.state)
71
+ for refinement in possible_refinements:
72
+ node.add_child(refinement)
73
+ return random.choice(node.children) if node.children else node
74
+
75
+ def calculate_combined_reward(self, ranking_score, state):
76
+ state_length_reward = len(state) / 100
77
+ if state:
78
+ query_complexity = len(set(state.split())) / len(state.split())
79
+ else:
80
+ query_complexity = 0
81
+ semantic_similarity = self.calculate_semantic_similarity(state, self.root.state)
82
+
83
+ combined_reward = (
84
+ 0.5 * ranking_score +
85
+ 0.2 * state_length_reward +
86
+ 0.2 * query_complexity +
87
+ 0.1 * semantic_similarity
88
+ )
89
+ return combined_reward
90
+
91
+ def calculate_semantic_similarity(self, query1, query2):
92
+ embedding1 = self.query_model.encode(query1)
93
+ embedding2 = self.query_model.encode(query2)
94
+ return util.pytorch_cos_sim(embedding1, embedding2).item()
95
+
96
+ def backpropagate(self, node, reward):
97
+ while node is not None:
98
+ node.update(reward)
99
+ node = node.parent
100
+
101
+ def best_action(self):
102
+ if not self.root.children:
103
+ return self.root
104
+
105
+ def score(node):
106
+ if node.visits == 0:
107
+ return float('-inf')
108
+ return node.value / node.visits
109
+
110
+ return max(self.root.children, key=score)
111
+
112
+ def refine_query(self, query):
113
+ words = query.split()
114
+ refined_query = []
115
+
116
+ for word in words:
117
+ if word.lower() not in {"how", "to", "get", "an", "the", "and", "or", "of", "build"}:
118
+ synonyms = wordnet.synsets(word)
119
+ if synonyms:
120
+ synonym_words = [lemma.name() for lemma in synonyms[0].lemmas()
121
+ if len(lemma.name().split()) == 1 and word != lemma.name()]
122
+ if synonym_words:
123
+ refined_query.append(random.choice(synonym_words))
124
+ else:
125
+ refined_query.append(word)
126
+ else:
127
+ refined_query.append(word)
128
+ else:
129
+ refined_query.append(word)
130
+
131
+ possible_intent_keywords = ['guide', 'tutorial', 'LLM', 'language model', 'NLP', 'GPT']
132
+ refined_query.append(random.choice(possible_intent_keywords))
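+ # Example: for "how to build an LLM", stopwords and "build" are kept verbatim,
+ # other words may be swapped for a single-token WordNet synonym, and a random
+ # intent keyword (e.g. "tutorial") is appended.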
133
+
134
+ return ' '.join(refined_query)
135
+
136
+ def get_related_queries(self, query):
137
+ query_embedding = self.query_model.encode(query)
138
+ refined_query_variations = [query]
139
+ words_to_avoid = {'how', 'to', 'get'}
140
+ words = query.split()
141
+
142
+ for word in words:
143
+ if word.lower() not in words_to_avoid:
144
+ synonyms = wordnet.synsets(word)
145
+ if synonyms:
146
+ synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word]
147
+ if synonym_words:
148
+ refined_query = query.replace(word, random.choice(synonym_words))
149
+ refined_query_variations.append(refined_query)
150
+
151
+ refined_query_variations = list(set(refined_query_variations))
152
+ refined_query_embeddings = [self.query_model.encode(variation) for variation in refined_query_variations]
153
+ similarity_scores = util.pytorch_cos_sim(query_embedding, refined_query_embeddings).tolist()[0]
154
+
155
+ similarity_threshold = 0.8
156
+ filtered_queries = [variation for idx, variation in enumerate(refined_query_variations)
157
+ if similarity_scores[idx] > similarity_threshold]
158
+
159
+ return filtered_queries[:2] if filtered_queries else [query]
160
+
161
+ def get_possible_refinements(self, query):
162
+ refined_queries = self.get_related_queries(query)
163
+ return refined_queries + [self.refine_query(query)]
164
+
165
+ @defer.inlineCallbacks
166
+ def web_search(self, query, search_sites=None):
167
+ if not query.strip():
168
+ logger.error("Cannot perform web search with an empty query.")
169
+ defer.returnValue([])
170
+
171
+ logger.info(f"Starting web search for query: {query}")
172
+ configure_logging(install_root_handler=False)
173
+ logging.basicConfig(level=logging.INFO)
174
+
175
+ results = []
176
+
177
+ def crawler_results(item, response, spider):
178
+ logger.info(f"Received result: {item['title']}")
179
+ results.append(item)
180
+
181
+ try:
182
+ crawler = self.crawler_runner.create_crawler(SearchSpider)
183
+ crawler.signals.connect(crawler_results, signal=signals.item_scraped)
184
+
185
+ # Start crawling, passing query and search_sites to the spider
186
+ yield self.crawler_runner.crawl(crawler, query=query, search_sites=search_sites)
187
+ except Exception as e:
188
+ logger.error(f"Error during web search: {str(e)}")
189
+ defer.returnValue([])
190
+
191
+ logger.info(f"Web search completed. Found {len(results)} results.")
192
+ defer.returnValue(results)
193
+
194
+ @defer.inlineCallbacks
195
+ def run(self):
196
+ logger.info(f"Starting MCTS run with {self.num_iterations} iterations")
197
+ for i in range(self.num_iterations):
198
+ logger.debug(f"Iteration {i+1}/{self.num_iterations}")
199
+ leaf = self.select(self.root)
200
+ child = self.expand(leaf)
201
+ reward = yield self.simulate(child)
202
+ self.backpropagate(child, reward)
203
+
204
+ best_child = self.best_action()
205
+ logger.info(f"MCTS run completed. Best action: {best_child.state}")
206
+ defer.returnValue(best_child.state if best_child != self.root else self.root.state)
207
+
208
+ @defer.inlineCallbacks
209
+ def simulate(self, node):
210
+ query_results = yield self.web_search(node.state)
211
+ ranked_results = train_ranking_model(node.state, query_results)
212
+
213
+ if ranked_results:
214
+ top_score = ranked_results[0]['predicted_score']
215
+ else:
216
+ top_score = 0
217
+
218
+ reward = self.calculate_combined_reward(top_score, node.state)
219
+ defer.returnValue(reward)
220
+
221
+
222
+
223
+
224
+
225
+
my_search_engine/my_search_engine/__init__.py ADDED
File without changes
my_search_engine/my_search_engine/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (165 Bytes).
 
my_search_engine/my_search_engine/__pycache__/items.cpython-312.pyc ADDED
Binary file (799 Bytes).
 
my_search_engine/my_search_engine/items.py ADDED
@@ -0,0 +1,10 @@
1
+ import scrapy
2
+
3
+ class MySearchEngineItem(scrapy.Item):
4
+ title = scrapy.Field()
5
+ link = scrapy.Field()
6
+ content = scrapy.Field()
7
+ score = scrapy.Field() # Will be set later during ranking (MCTS or NLP)
8
+ meta = scrapy.Field()
9
+ predicted_score = scrapy.Field()
10
+ summary = scrapy.Field()
my_search_engine/my_search_engine/middlewares.py ADDED
@@ -0,0 +1,56 @@
1
+ # middlewares.py
2
+
3
+ import random
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ class RotateUserAgentMiddleware:
9
+ """Middleware for rotating user agents to avoid detection."""
10
+
11
+ USER_AGENTS = [
12
+ # Chrome User Agents
13
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
14
+ " Chrome/93.0.4577.63 Safari/537.36",
15
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
16
+ " Chrome/93.0.4577.63 Safari/537.36",
17
+ "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko)"
18
+ " Chrome/93.0.4577.63 Safari/537.36",
19
+ # Firefox User Agents
20
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:92.0) Gecko/20100101 Firefox/92.0",
21
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 11.5; rv:92.0) Gecko/20100101 Firefox/92.0",
22
+ "Mozilla/5.0 (X11; Linux x86_64; rv:92.0) Gecko/20100101 Firefox/92.0",
23
+ # Safari User Agents
24
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Safari/605.1.15",
25
+ "Mozilla/5.0 (iPhone; CPU iPhone OS 14_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Mobile/15E148 Safari/604.1",
26
+ # Edge User Agents
27
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36 Edg/93.0.961.38",
28
+ # Opera User Agents
29
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36 OPR/78.0.4093.184",
30
+ # Mobile User Agents
31
+ "Mozilla/5.0 (Linux; Android 11; SM-G981B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.62 Mobile Safari/537.36",
32
+ "Mozilla/5.0 (iPhone; CPU iPhone OS 14_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) CriOS/93.0.4577.62 Mobile/15E148 Safari/604.1",
33
+ # Add more user agents from different browsers and devices
34
+ ]
35
+
36
+ def process_request(self, request, spider):
37
+ """Assign a random user agent to each request."""
38
+ user_agent = random.choice(self.USER_AGENTS)
39
+ request.headers['User-Agent'] = user_agent
40
+ logger.debug(f"Using User-Agent: {user_agent}")
41
+
42
+ # Optional: Proxy Middleware
43
+ class ProxyMiddleware:
44
+ """Middleware for rotating proxies."""
45
+
46
+ PROXIES = [
47
+ # Add proxy URLs if using proxies
48
+ # 'http://proxy1.example.com:8000',
49
+ # 'http://proxy2.example.com:8031',
50
+ ]
51
+
52
+ def process_request(self, request, spider):
53
+ if self.PROXIES:
54
+ proxy = random.choice(self.PROXIES)
55
+ request.meta['proxy'] = proxy
56
+ logger.debug(f"Using Proxy: {proxy}")
my_search_engine/my_search_engine/pipelines.py ADDED
@@ -0,0 +1,53 @@
1
+ # pipelines.py
2
+
3
+ import json
4
+
5
+ class SaveToJSONPipeline:
6
+ """Pipeline that saves scraped data to a JSON file."""
7
+
8
+ def open_spider(self, spider):
9
+ """Open the file when the spider starts."""
10
+ self.file = open('scraped_results.json', 'w', encoding='utf-8')
11
+
12
+ def close_spider(self, spider):
13
+ """Close the file when the spider finishes."""
14
+ self.file.close()
15
+
16
+ def process_item(self, item, spider):
17
+ """Write each scraped item to the JSON file."""
18
+ line = json.dumps(dict(item), ensure_ascii=False) + "\n"
19
+ self.file.write(line)
20
+ return item
21
+
22
+
23
+ class ContentCleanupPipeline:
24
+ """Pipeline to clean up content by removing unnecessary whitespace."""
25
+
26
+ def process_item(self, item, spider):
27
+ """Clean up content field."""
28
+ item['content'] = ' '.join(item['content'].split()) # Clean up content by removing extra spaces
29
+ return item
30
+
31
+
32
+ class DisplayResultsPipeline:
33
+ """Pipeline that formats and prints the search results in a Google-like format."""
34
+
35
+ def open_spider(self, spider):
36
+ """Initialize an empty results list when the spider starts."""
37
+ self.results = []
38
+
39
+ def process_item(self, item, spider):
40
+ """Store the item in the results list."""
41
+ self.results.append({
42
+ 'title': item['title'],
43
+ 'summary': item.get('summary', item['content']),  # prefer the spider-generated summary
44
+ 'link': item['link']
45
+ })
46
+ return item
47
+
48
+ def close_spider(self, spider):
49
+ """Print out the formatted results when the spider finishes."""
50
+ print("\nTop 10 Related Links for the Search Query:")
51
+ for i, result in enumerate(self.results[:10], start=1):
52
+ print(f"{i}. {result['title']}\n {result['summary'][:200]}...\n {result['link']}\n")
53
+
my_search_engine/my_search_engine/settings.py ADDED
@@ -0,0 +1,49 @@
1
+ # settings.py
2
+
3
+ # Scrapy configurations
4
+ BOT_NAME = 'my_search_engine'
5
+ SPIDER_MODULES = ['my_search_engine.spiders']
6
+ NEWSPIDER_MODULE = 'my_search_engine.spiders'
7
+
8
+ # Obey robots.txt rules
9
+ ROBOTSTXT_OBEY = True
10
+
11
+ # Configure maximum concurrent requests performed by Scrapy
12
+ CONCURRENT_REQUESTS = 16
13
+ CONCURRENT_REQUESTS_PER_DOMAIN = 1
14
+
15
+ # Configure a delay for requests to the same website
16
+ DOWNLOAD_DELAY = 2 # Fixed delay of 2 seconds
17
+
18
+ # Disable cookies (enabled by default)
19
+ COOKIES_ENABLED = False
20
+
21
+ # Enable AutoThrottle
22
+ AUTOTHROTTLE_ENABLED = True
23
+ AUTOTHROTTLE_START_DELAY = 1 # Initial download delay
24
+ AUTOTHROTTLE_MAX_DELAY = 10 # Maximum download delay in case of high latencies
25
+ AUTOTHROTTLE_TARGET_CONCURRENCY = 1.0 # Average number of requests Scrapy should be sending in parallel
26
+
27
+ # User Agent (default, only used if RotateUserAgentMiddleware fails)
28
+ USER_AGENT = "Mozilla/5.0 (compatible; MySearchEngine/1.0)"
29
+
30
+ # Downloader middlewares
31
+ DOWNLOADER_MIDDLEWARES = {
32
+ 'my_search_engine.middlewares.RotateUserAgentMiddleware': 543,
33
+ # Uncomment the following line if using the ProxyMiddleware
34
+ # 'my_search_engine.middlewares.ProxyMiddleware': 544,
35
+ 'scrapy.downloadermiddlewares.useragent.UserAgentMiddleware': None, # Disable default user agent middleware
36
+ }
37
+
38
+ # Item pipelines
39
+ ITEM_PIPELINES = {
40
+ 'my_search_engine.pipelines.SaveToJSONPipeline': 300,
41
+ 'my_search_engine.pipelines.ContentCleanupPipeline': 400,
42
+ 'my_search_engine.pipelines.DisplayResultsPipeline': 200,
43
+ }
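+ # Pipelines run in ascending priority order: DisplayResultsPipeline (200) sees
+ # items first, then SaveToJSONPipeline (300), then ContentCleanupPipeline (400).
+ # Note this means the JSON file stores content *before* cleanup; move the
+ # cleanup priority below 300 if cleaned content should be persisted.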
44
+
45
+ # Enable logging
46
+ LOG_ENABLED = True
47
+ LOG_LEVEL = 'INFO' # Set to 'DEBUG' to see all logs, including middleware logs
48
+
49
+ # Additional settings can be added below as needed
my_search_engine/my_search_engine/spiders/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # This package will contain the spiders of your Scrapy project
2
+ #
3
+ # Please refer to the documentation for information on how to create and manage
4
+ # your spiders.
my_search_engine/my_search_engine/spiders/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (173 Bytes).
 
my_search_engine/my_search_engine/spiders/__pycache__/search_spider.cpython-312.pyc ADDED
Binary file (11.2 kB).
 
my_search_engine/my_search_engine/spiders/search_spider.py ADDED
@@ -0,0 +1,176 @@
1
+ # search_spider.py
2
+ import scrapy
3
+ from bs4 import BeautifulSoup
4
+ from my_search_engine.my_search_engine.items import MySearchEngineItem
5
+ import random
6
+ from urllib.parse import urlparse, urljoin
7
+ import traceback
8
+ import re
9
+ from twisted.internet.error import TCPTimedOutError, ConnectionRefusedError, TimeoutError
10
+ from scrapy.exceptions import CloseSpider
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class SearchSpider(scrapy.Spider):
16
+ name = "search_spider"
17
+ allowed_domains = [] # To be set dynamically from search_sites
18
+
19
+ def __init__(self, query=None, search_sites=None, max_depth=2, max_links_per_page=3, *args, **kwargs):
20
+ super(SearchSpider, self).__init__(*args, **kwargs)
21
+ self.query = query
22
+ if not self.query:
23
+ raise CloseSpider("No search query provided")
24
+ self.max_depth = max_depth
25
+ self.max_links_per_page = max_links_per_page
26
+ if search_sites is None:
27
+ self.search_sites = [
28
+ f"https://en.wikibooks.org/w/index.php?search={self.query}",
29
+ f"https://en.wikiversity.org/w/index.php?search={self.query}",
30
+ f"https://commons.wikimedia.org/w/index.php?search={self.query}",
31
+ f"https://stackexchange.com/search?q={self.query}",
32
+ f"https://arxiv.org/search/?query={self.query}&searchtype=all",
33
+ f"https://www.ncbi.nlm.nih.gov/pmc/?term={self.query}",
34
+ f"https://www.gutenberg.org/ebooks/search/?query={self.query}",
35
+ f"https://openlibrary.org/search?q={self.query}",
36
+ f"https://doaj.org/search/articles?ref=homepage&q={self.query}",
37
+ f"https://www.ted.com/search?q={self.query}",
38
+ f"https://en.citizendium.org/wiki?search={self.query}",
39
+ f"https://www.jstor.org/action/doBasicSearch?Query={self.query}",
40
+ f"https://archive.org/search.php?query={self.query}",
41
+ f"https://search.scielo.org/?q={self.query}",
42
+ f"https://paperswithcode.com/search?q={self.query}",
43
+ f"https://www.reddit.com/search/?q={self.query}",
44
+ f"https://huggingface.co/models?search={self.query}",
45
+ f"https://huggingface.co/datasets?search={self.query}",
46
+ f"https://machinelearningmastery.com/?s={self.query}",
47
+ f"https://www.kaggle.com/search?q={self.query}",
48
+ f"https://towardsdatascience.com/search?q={self.query}",
49
+ f"https://github.com/search?q={self.query}",
50
+ f"https://stackoverflow.com/search?q={self.query}",
51
+ f"https://www.youtube.com/results?search_query={self.query}",
52
+ f"https://www.slideshare.net/search/slideshow?searchfrom=header&q={self.query}"
53
+ ]
54
+
55
+ else:
56
+ self.search_sites = search_sites
57
+
58
+
59
+ def start_requests(self):
60
+ if not self.query:
61
+ logger.error("No search query provided in start_requests")
62
+ return
63
+
64
+ self.allowed_domains = list(set([urlparse(url).netloc for url in self.search_sites]))
65
+ logger.info(f"Starting requests for query: {self.query}")
66
+
67
+ for url in self.search_sites:
68
+ yield scrapy.Request(
69
+ url,
70
+ callback=self.parse,
71
+ meta={
72
+ 'dont_retry': True,
73
+ 'handle_httpstatus_list': [302, 403, 404, 420, 429, 500, 503],
74
+ 'depth': 1 # Start at depth 1
75
+ },
76
+ errback=self.errback_httpbin
77
+ )
78
+
79
+ def parse(self, response):
80
+ depth = response.meta.get('depth', 1)
81
+ if depth > self.max_depth:
82
+ logger.debug(f"Reached max depth at {response.url}")
83
+ return
84
+
85
+ logger.info(f"Parsing response from {response.url} at depth {depth}")
86
+
87
+ try:
88
+ soup = BeautifulSoup(response.text, 'html.parser')
89
+
90
+ # Check for irrelevant or blocked content
91
+ if any(term in soup.text.lower() for term in ['captcha', 'verification', 'no items found', 'no results', 'access denied']):
92
+ logger.warning(f"Irrelevant page detected: {response.url}")
93
+ return
94
+
95
+ title = soup.find('title').get_text().strip() if soup.find('title') else 'No title'
96
+ meta_description = soup.find('meta', {'name': 'description'})
97
+ meta_description = meta_description['content'].strip() if meta_description else 'No description'
98
+
99
+ content = self.extract_main_content(soup)
100
+ summary = self.generate_summary(content, 200)
101
+ total_links = len(soup.find_all('a', href=True))
102
+ content_length = len(content.split())
103
+
104
+ if content_length < 100:
105
+ logger.info(f"Content too short ({content_length} words) for {response.url}")
106
+ return
107
+
108
+ item = MySearchEngineItem()
109
+ item['title'] = title
110
+ item['link'] = response.url
111
+ item['content'] = content
112
+ item['summary'] = summary
113
+ item['meta'] = {
114
+ 'description': meta_description,
115
+ 'total_links': total_links,
116
+ 'content_length': content_length,
117
+ 'domain': urlparse(response.url).netloc,
118
+ }
119
+ yield item
120
+
121
+ # Limit the number of links per page
122
+ links = soup.find_all('a', href=True)
123
+ random.shuffle(links)
124
+ links = links[:self.max_links_per_page] # Limit the number of links
125
+
126
+ for link in links:
127
+ href = link.get('href')
128
+ full_url = urljoin(response.url, href)
129
+ if self.is_valid_link(full_url):
130
+ logger.debug(f"Following link: {full_url}")
131
+ yield scrapy.Request(
132
+ url=full_url,
133
+ callback=self.parse,
134
+ meta={'depth': depth + 1},
135
+ errback=self.errback_httpbin
136
+ )
137
+ except Exception as e:
138
+ logger.error(f"Error parsing {response.url}: {str(e)}")
139
+ logger.error(traceback.format_exc())
140
+
141
+ def extract_main_content(self, soup):
142
+ for element in soup(['script', 'style', 'nav', 'header', 'footer']):
143
+ element.decompose()
144
+
145
+ main_content = soup.find('main') or soup.find('article') or soup.find('div', class_='content')
146
+
147
+ if main_content:
148
+ return ' '.join(main_content.stripped_strings)
149
+
150
+ paragraphs = soup.find_all('p')
151
+ return ' '.join([p.get_text().strip() for p in paragraphs])
152
+
153
+ def generate_summary(self, content, max_length=200):
154
+ sentences = re.split(r'(?<=[.!?])\s+', content)
155
+ summary = ""
156
+ for sentence in sentences:
157
+ if len(summary) + len(sentence) <= max_length:
158
+ summary += sentence + " "
159
+ else:
160
+ break
161
+ return summary.strip()
162
+
163
+ def is_valid_link(self, url):
164
+ parsed_url = urlparse(url)
165
+ return any(domain in parsed_url.netloc for domain in self.allowed_domains)
166
+
167
+ def errback_httpbin(self, failure):
168
+ logger.error(f"Error on {failure.request.url}: {str(failure.value)}")
169
+ logger.error(failure.getTraceback())  # format_exc() has no active exception in an errback
170
+
171
+ if failure.check(ConnectionRefusedError):
172
+ logger.warning(f"Connection refused: {failure.request.url}")
173
+ elif failure.check(TimeoutError, TCPTimedOutError):
174
+ logger.warning(f"Timeout: {failure.request.url}")
175
+ else:
176
+ logger.error(f"Failed to process: {failure.request.url}")
my_search_engine/scrapy.cfg ADDED
@@ -0,0 +1,11 @@
1
+ # Automatically created by: scrapy startproject
2
+ #
3
+ # For more information about the [deploy] section see:
4
+ # https://scrapyd.readthedocs.io/en/latest/deploy.html
5
+
6
+ [settings]
7
+ default = my_search_engine.settings
8
+
9
+ [deploy]
10
+ #url = http://localhost:6800/
11
+ project = my_search_engine
ranking.py ADDED
@@ -0,0 +1,239 @@
1
+ # ranking.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ import pandas as pd
6
+ from sentence_transformers import SentenceTransformer, util
7
+ import numpy as np
8
+ from sklearn.preprocessing import MinMaxScaler
9
+ from collections import Counter
10
+ import re
11
+ import string
13
+ from sklearn.feature_extraction.text import TfidfVectorizer
14
+ from nltk.corpus import stopwords
15
+ from nltk.stem import WordNetLemmatizer
16
+ from nltk.tokenize import word_tokenize
17
+ import spacy
18
+
19
+ def truncate_text(text, max_length=1024):
20
+ tokens = text.split()
21
+ if len(tokens) > max_length:
22
+ return ' '.join(tokens[:max_length])
23
+ return text
24
+
25
+ class RankingNN(nn.Module):
26
+ def __init__(self, input_size=7):
27
+ super(RankingNN, self).__init__()
28
+ self.fc1 = nn.Linear(input_size, 64)
29
+ self.fc2 = nn.Linear(64, 32)
30
+ self.fc3 = nn.Linear(32, 16)
31
+ self.fc4 = nn.Linear(16, 1)
32
+ self.dropout = nn.Dropout(0.2)
33
+
34
+ def forward(self, x):
35
+ x = torch.relu(self.fc1(x))
36
+ x = self.dropout(x)
37
+ x = torch.relu(self.fc2(x))
38
+ x = self.dropout(x)
39
+ x = torch.relu(self.fc3(x))
40
+ x = self.fc4(x)
41
+ return x
42
+
43
+ transformer_model = SentenceTransformer('all-MiniLM-L6-v2')
44
+ ranking_model = RankingNN()
45
+ optimizer = optim.Adam(ranking_model.parameters(), lr=0.001, weight_decay=1e-5)
46
+ criterion = nn.MSELoss()
47
+ scaler = MinMaxScaler()
48
+
49
+ # Download necessary resources
50
+ import nltk
51
+ nltk.download('punkt')
52
+ nltk.download('stopwords')
53
+ nltk.download('wordnet')
54
+
55
+ # Initialize resources
56
+ stop_words = set(stopwords.words('english'))
57
+ lemmatizer = WordNetLemmatizer()
58
+ nlp = spacy.load("en_core_web_sm") # Small model to keep compute low
59
+
60
+ def preprocess_text(text):
61
+ """
62
+ Preprocess the input text by lowercasing, removing punctuation, and filtering out stopwords.
63
+ Lemmatization is applied as well.
64
+ """
65
+ # Lowercase the text
66
+ text = text.lower()
67
+
68
+ # Remove punctuation using regex
69
+ text = re.sub('[' + re.escape(string.punctuation) + ']', ' ', text)
70
+
71
+ # Tokenize the text into words
72
+ words = word_tokenize(text)
73
+
74
+ # Lemmatize, filter out stopwords and non-alphabetic words
75
+ processed_words = [lemmatizer.lemmatize(word) for word in words if word.isalpha() and word not in stop_words]
76
+
77
+ return processed_words
78
+
79
+ def extract_named_entities(text):
80
+ """
81
+ Extract named entities (e.g., people, organizations, locations) from the text.
82
+ """
83
+ doc = nlp(text)
84
+ named_entities = [ent.text for ent in doc.ents if ent.label_ in {"PERSON", "ORG", "GPE", "LOC"}]
85
+ return named_entities
86
+
87
+ def extract_keywords_tfidf(corpus, text, n=5):
88
+ """
89
+ Extract keywords from the text using TF-IDF, combined with Named Entity Recognition and lemmatization.
90
+ """
91
+ # Preprocess the text and the entire corpus
92
+ preprocessed_texts = [' '.join(preprocess_text(doc)) for doc in corpus]
93
+ preprocessed_text = ' '.join(preprocess_text(text))
94
+
95
+ # Named entities extraction
96
+ named_entities = extract_named_entities(text)
97
+
98
+ # Use TF-IDF vectorizer to find the most important words
99
+ vectorizer = TfidfVectorizer(max_features=1000) # Keep it light, max 1000 features
100
+ X = vectorizer.fit_transform(preprocessed_texts)
101
+
102
+ # Get the feature names (i.e., the words)
103
+ feature_names = vectorizer.get_feature_names_out()
104
+
105
+ # Transform the current text into TF-IDF scores
106
+ response = vectorizer.transform([preprocessed_text])
107
+ tfidf_scores = zip(feature_names, response.toarray()[0])
108
+
109
+ # Sort by TF-IDF score
110
+ sorted_tfidf = sorted(tfidf_scores, key=lambda x: x[1], reverse=True)
111
+
112
+ # Combine top TF-IDF words with named entities for more richness
113
+ keywords = [word for word, score in sorted_tfidf[:n]]
114
+ combined_keywords = keywords + named_entities
115
+
116
+ return combined_keywords[:n]
117
+
118
+ def extract_keywords(text, corpus, n=5):
119
+ """
120
+ Wrapper function that combines preprocessing, TF-IDF, and Named Entity Recognition to extract top N keywords.
121
+ """
122
+ if not text.strip():
123
+ return []
124
+
125
+ # Extract keywords using the TF-IDF based approach
126
+ keywords = extract_keywords_tfidf(corpus, text, n)
127
+
128
+ # If no meaningful keywords are found, fallback to keyword frequency
129
+ if not keywords:
130
+ return extract_fallback_keywords(text, n)
131
+
132
+ return keywords
133
+
134
+ def extract_fallback_keywords(text, n=5):
135
+ """
136
+ Fallback method to extract keywords based on word frequency in case TF-IDF or NER fails.
137
+ """
138
+ words = preprocess_text(text)
139
+ word_freq = Counter(words)
140
+ return [word for word, _ in word_freq.most_common(n)]
141
+
142
+ def calculate_keyword_overlap(query_keywords, result_keywords):
143
+ if len(query_keywords) == 0:
144
+ return 0 # No keywords in query, so overlap is 0
145
+ return len(set(query_keywords) & set(result_keywords)) / len(query_keywords)
146
+
147
+ def train_ranking_model(query, results, corpus=None, epochs=1):
148
+ query = truncate_text(query)
149
+ if not results:
150
+ print("No results available. Skipping training.")
151
+ return []
152
+
153
+ if corpus is None:
154
+ # If no corpus is provided, use results as a fallback
155
+ corpus = [truncate_text(result['content']) for result in results if 'content' in result]
156
+
157
+ query_embedding = transformer_model.encode(query)
158
+ query_keywords = extract_keywords(query, corpus)
159
+
160
+ training_data = []
161
+ target_scores = []
162
+
163
+ for result in results:
164
+ # Truncate content
165
+ content = truncate_text(result['content'])
166
+ content_embedding = transformer_model.encode(content)
167
+
168
+ # Handle missing 'title' and 'meta' fields with default values, and truncate
169
+ title = truncate_text(result.get('title', ''))
170
+ title_embedding = transformer_model.encode(title)
171
+
172
+ meta_description = truncate_text(result.get('meta', {}).get('description', ''))
173
+ meta_description_embedding = transformer_model.encode(meta_description)
174
+
175
+ content_similarity = util.pytorch_cos_sim(query_embedding, content_embedding).item()
176
+ title_similarity = util.pytorch_cos_sim(query_embedding, title_embedding).item()
177
+ meta_description_similarity = util.pytorch_cos_sim(query_embedding, meta_description_embedding).item()
178
+
179
+ # Handle missing metadata by providing default values
180
+ content_length = result.get('meta', {}).get('content_length', 0)
181
+ total_links = result.get('meta', {}).get('total_links', 0)
182
+
183
+ result_keywords = extract_keywords(content, corpus)
184
+ keyword_overlap = calculate_keyword_overlap(query_keywords, result_keywords)
185
+ domain_authority = get_domain_authority(result.get('link', ''))
186
+
187
+ features = [
188
+ content_similarity, title_similarity, meta_description_similarity,
189
+ content_length, total_links, keyword_overlap, domain_authority
190
+ ]
191
+
192
+ training_data.append(features)
193
+
194
+ target_score = (0.4 * content_similarity + 0.3 * title_similarity +
195
+ 0.2 * meta_description_similarity + 0.1 * keyword_overlap)
196
+ target_scores.append(target_score)
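+ # The network is trained to imitate this hand-weighted relevance heuristic
+ # (0.4 content / 0.3 title / 0.2 meta / 0.1 keyword overlap); with labelled
+ # click or relevance data, target_score could be replaced by real labels.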
197
+
198
+ # Normalize features
199
+ training_data = scaler.fit_transform(training_data)
200
+ training_data_tensor = torch.tensor(training_data, dtype=torch.float32)
201
+ target_scores_tensor = torch.tensor(target_scores, dtype=torch.float32).unsqueeze(1)
202
+
203
+ # Training loop
204
+ for epoch in range(epochs):
205
+ optimizer.zero_grad()
206
+ predicted_scores = ranking_model(training_data_tensor)
207
+ loss = criterion(predicted_scores, target_scores_tensor)
208
+ loss.backward()
209
+ optimizer.step()
210
+
211
+ if (epoch + 1) % 5 == 0:
212
+ print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
213
+
214
+ # Predict final scores and rank results
215
+ with torch.no_grad():
216
+ final_scores = ranking_model(training_data_tensor).squeeze().tolist()
217
+
218
+ # Ensure final_scores is always a list
219
+ if isinstance(final_scores, float):
220
+ final_scores = [final_scores]
221
+
222
+ for result, score in zip(results, final_scores):
223
+ result['predicted_score'] = score
224
+
225
+ ranked_results = sorted(results, key=lambda x: x['predicted_score'], reverse=True)
226
+ return ranked_results
227
+
228
+ def get_domain_authority(url):
229
+ # Placeholder function - replace with actual domain authority data if available
230
+ high_authority_domains = ['arxiv.org', 'ncbi.nlm.nih.gov', 'nature.com', 'science.org']
231
+ medium_authority_domains = ['wikipedia.org', 'stackexchange.com', 'github.com']
232
+
233
+ for domain in high_authority_domains:
234
+ if domain in url:
235
+ return 1.0
236
+ for domain in medium_authority_domains:
237
+ if domain in url:
238
+ return 0.7
239
+ return 0.5
test_agent.py ADDED
@@ -0,0 +1,148 @@
1
+ # test_agent.py
2
+
3
+ import logging
4
+ from twisted.internet import reactor, defer, threads
5
+ from agent import AutonomousWebAgent
6
+ from ToTSearch import ToTSearch
7
+
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
+
11
+ # Initialize the logger
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Suppress detailed logs for some libraries (like Scrapy or Transformers)
15
+ logging.getLogger('scrapy').setLevel(logging.ERROR)
16
+ logging.getLogger('transformers').setLevel(logging.ERROR)
17
+ logging.getLogger('twisted').setLevel(logging.ERROR)
18
+
19
+ import warnings
20
+ warnings.filterwarnings("ignore", category=FutureWarning)
21
+
22
+
23
+ class TestAgent:
24
+ def __init__(self):
25
+ # Initialize the AutonomousWebAgent
26
+ state_size = 7 # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
27
+ action_size = 3 # 0: Click Link, 1: Summarize, 2: RAG Generate
28
+ num_options = 3 # 0: Search, 1: Summarize, 2: RAG Generate
29
+
30
+ self.agent = AutonomousWebAgent(
31
+ state_size=state_size,
32
+ action_size=action_size,
33
+ num_options=num_options,
34
+ hidden_size=64,
35
+ learning_rate=0.001,
36
+ gamma=0.99,
37
+ epsilon=1.0,
38
+ epsilon_decay=0.995,
39
+ epsilon_min=0.01,
40
+ knowledge_base_path='knowledge_base.json'
41
+ )
42
+
43
+ # Initialize ToTSearch with the agent
44
+ self.tot_search = ToTSearch(self.agent)
45
+
46
+ # Few-shot examples for Tree of Thoughts
47
+ self.few_shot_examples = [
48
+ {
49
+ "query": "What are the effects of climate change on biodiversity?",
50
+ "thoughts": [
51
+ "Loss of habitats due to rising sea levels and changing temperatures",
52
+ "Disruption of ecosystems and food chains",
53
+ "Increased extinction rates for vulnerable species"
54
+ ],
55
+ "answer": "Climate change significantly impacts biodiversity through habitat loss, ecosystem disruption, and increased extinction rates. Rising temperatures and sea levels alter habitats, forcing species to adapt or migrate. This disrupts established ecosystems and food chains. Species unable to adapt quickly face a higher risk of extinction, particularly those with specialized habitats or limited ranges."
56
+ },
57
+ {
58
+ "query": "How can we promote sustainable energy adoption?",
59
+ "thoughts": [
60
+ "Government policies and incentives",
61
+ "Public awareness and education campaigns",
62
+ "Technological advancements and cost reduction"
63
+ ],
64
+ "answer": "Promoting sustainable energy adoption requires a multi-faceted approach. Government policies and incentives can encourage both businesses and individuals to switch to renewable sources. Public awareness and education campaigns help people understand the importance and benefits of sustainable energy. Continued technological advancements and cost reductions make sustainable energy more accessible and economically viable for widespread adoption."
65
+ }
66
+ ]
67
+
68
+ @defer.inlineCallbacks
69
+ def process_query(self, query, is_few_shot=False):
70
+ logger.info(f"Processing query: {query}")
71
+ try:
72
+ if is_few_shot:
73
+ few_shot_prompt = self.create_few_shot_prompt(query)
74
+ enhanced_query = f"{few_shot_prompt}\n\nQuery: {query}"
75
+ logger.debug(f"Enhanced query for few-shot learning: {enhanced_query[:100]}...")
76
+ final_answer = yield self.tot_search.search(enhanced_query)
77
+ else:
78
+ final_answer = yield self.tot_search.search(query)
79
+
80
+ logger.info(f"Final answer for '{query}':")
81
+ logger.info(final_answer)
82
+
83
+ yield self.agent.add_document_to_kb(title=f"ToT Search Result: {query}", content=final_answer)
84
+
85
+ yield self.agent.replay_worker(batch_size=32)
86
+ yield self.agent.replay_manager(batch_size=32)
87
+
88
+ return final_answer
89
+ except Exception as e:
90
+ logger.error(f"Error processing query '{query}': {str(e)}", exc_info=True)
91
+ return f"An error occurred: {str(e)}"
92
+
93
+ def create_few_shot_prompt(self, query):
94
+ prompt = "Here are some examples of how to approach queries using a Tree of Thoughts:\n\n"
95
+ for example in self.few_shot_examples:
96
+ prompt += f"Query: {example['query']}\n"
97
+ prompt += "Thoughts:\n"
98
+ for thought in example['thoughts']:
99
+ prompt += f"- {thought}\n"
100
+ prompt += f"Answer: {example['answer']}\n\n"
101
+ prompt += f"Now, let's approach the following query in a similar manner:\n\nQuery: {query}\n"
102
+ return prompt
103
+
104
+ def save_models(self):
105
+ self.agent.save_worker_model("worker_model_final.pth")
106
+ self.agent.save_manager_model("manager_model_final.pth")
107
+ logger.info("Agent models saved.")
108
+
109
+
110
+ def get_user_input():
111
+ return input("Enter your query (or 'quit' to exit): ")
112
+
113
+
114
+ @defer.inlineCallbacks
115
+ def run_test_session():
116
+ test_agent = TestAgent()
117
+
118
+ logger.info("Starting few-shot learning phase...")
119
+ for example in test_agent.few_shot_examples:
120
+ logger.info(f"Processing few-shot example: {example['query']}")
121
+ try:
122
+ yield test_agent.process_query(example['query'], is_few_shot=True)
123
+ except Exception as e:
124
+ logger.error(f"Error in few-shot learning: {str(e)}", exc_info=True)
125
+
126
+ logger.info("Few-shot learning phase completed. Starting interactive session.")
127
+
128
+ while True:
129
+ query = yield threads.deferToThread(get_user_input)
130
+
131
+ if query.lower() == 'quit':
132
+ break
133
+
134
+ try:
135
+ answer = yield test_agent.process_query(query)
136
+ print("\nAgent's response:")
137
+ print(answer)
138
+ print("\n" + "-"*50 + "\n")
139
+ except Exception as e:
140
+ logger.error(f"Error in interactive session: {str(e)}", exc_info=True)
141
+
142
+ test_agent.save_models()
143
+ reactor.stop()
144
+
145
+
146
+ if __name__ == "__main__":
147
+ reactor.callWhenRunning(run_test_session)
148
+ reactor.run()
train_agent.py ADDED
@@ -0,0 +1,116 @@
1
+ # train_agent.py
2
+
3
+ from twisted.internet import reactor, defer, task
4
+ from agent import AutonomousWebAgent
5
+ import random
6
+ import logging
7
+ import sys
8
+ import time
9
+ import codecs
10
+
11
+ # Configure logging
12
+ logging.basicConfig(level=logging.INFO,
13
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
14
+ handlers=[
15
+ logging.FileHandler("agent_training.log", encoding='utf-8'),
16
+ logging.StreamHandler(codecs.getwriter('utf-8')(sys.stdout.buffer))
17
+ ])
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # List of diverse queries
22
+ QUERIES = [
23
+ "machine learning", "climate change", "renewable energy", "artificial intelligence",
24
+ "quantum computing", "blockchain technology", "gene editing", "virtual reality",
25
+ "space exploration", "cybersecurity", "autonomous vehicles", "Internet of Things",
26
+ "3D printing", "nanotechnology", "bioinformatics", "augmented reality", "robotics",
27
+ "data science", "neural networks", "cloud computing", "edge computing", "5G technology",
28
+ "cryptocurrency", "natural language processing", "computer vision"
29
+ ]
30
+
31
+ @defer.inlineCallbacks
32
+ def train_agent():
33
+ # Updated state_size to 7 to match the feature extraction in AutonomousWebAgent
34
+ state_size = 7 # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
35
+ action_size = 3 # 0: Click Link, 1: Summarize, 2: RAG Generate
36
+ num_options = 3 # 0: Search, 1: Summarize, 2: RAG Generate
37
+
38
+ # Initialize the AutonomousWebAgent with the required arguments
39
+ agent = AutonomousWebAgent(
40
+ state_size=state_size,
41
+ action_size=action_size,
42
+ num_options=num_options, # Added parameter for HRL
43
+ hidden_size=64,
44
+ learning_rate=0.001,
45
+ gamma=0.99,
46
+ epsilon=1.0,
47
+ epsilon_decay=0.995,
48
+ epsilon_min=0.01,
49
+ knowledge_base_path='knowledge_base.json'
50
+ )
51
+ logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")
52
+
53
+ num_episodes = 10 # Adjust as needed
54
+ total_training_reward = 0
55
+ start_time = time.time()
56
+
57
+ for episode in range(num_episodes):
58
+ query = random.choice(QUERIES)
59
+ logger.info(f"Starting episode {episode + 1}/{num_episodes} with query: {query}")
60
+ episode_start_time = time.time()
61
+
62
+ try:
63
+ # Initiate the search process
64
+ search_deferred = agent.search(query)
65
+ search_deferred.addTimeout(300, reactor) # 5-minute timeout
66
+ total_reward = yield search_deferred
67
+ total_training_reward += total_reward
68
+ episode_duration = time.time() - episode_start_time
69
+ logger.info(f"Episode {episode + 1}/{num_episodes}, Query: {query}, Total Reward: {total_reward}, Duration: {episode_duration:.2f} seconds")
70
+ except defer.TimeoutError:
71
+ logger.error(f"Episode {episode + 1} timed out")
72
+ total_reward = -1 # Assign a negative reward for timeout
73
+ total_training_reward += total_reward
74
+ except Exception as e:
75
+ logger.error(f"Error in episode {episode + 1}: {str(e)}", exc_info=True)
76
+ total_reward = -1 # Assign a negative reward for errors
77
+ total_training_reward += total_reward
78
+
79
+ # Update target models periodically
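+ # (with num_episodes = 10 this fires exactly once, on the final episode;
+ # shrink the interval for more frequent target syncs)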
80
+ if (episode + 1) % 10 == 0:
81
+ logger.info(f"Updating target models at episode {episode + 1}")
82
+ agent.update_worker_target_model()
83
+ agent.update_manager_target_model()
84
+ agent.manager.update_target_model()
85
+
86
+ # Log overall progress
87
+ progress = (episode + 1) / num_episodes
88
+ elapsed_time = time.time() - start_time
89
+ estimated_total_time = elapsed_time / progress if progress > 0 else 0
90
+ remaining_time = estimated_total_time - elapsed_time
91
+ logger.info(f"Overall progress: {progress:.2%}, Elapsed time: {elapsed_time:.2f}s, Estimated remaining time: {remaining_time:.2f}s")
92
+
93
+ total_training_time = time.time() - start_time
94
+ average_reward = total_training_reward / num_episodes
95
+ logger.info(f"Training completed. Total reward: {total_training_reward}, Average reward per episode: {average_reward:.2f}")
96
+ logger.info(f"Total training time: {total_training_time:.2f} seconds")
97
+ logger.info("Saving models.")
98
+
99
+ # Save both Worker and Manager models
100
+ agent.save_worker_model("worker_model.pth")
101
+ agent.save_manager_model("manager_model.pth")
102
+ agent.save("web_agent_model.pth") # Assuming this saves additional components if needed
103
+
104
+ if reactor.running:
105
+ logger.info("Stopping reactor")
106
+ reactor.stop()
107
+
108
+ def main():
109
+ logger.info("Starting agent training")
110
+ d = task.deferLater(reactor, 0, train_agent)
111
+ d.addErrback(lambda failure: logger.error(f"An error occurred: {failure}", exc_info=True))
112
+ d.addBoth(lambda _: reactor.stop() if reactor.running else None)
113
+ reactor.run()
114
+
115
+ if __name__ == "__main__":
116
+ main()