Upload 20 files
Browse files- ToTSearch.py +219 -0
- agent.py +1082 -0
- lightbulb.py +1696 -0
- main_menu.py +61 -0
- mcts.py +225 -0
- my_search_engine/my_search_engine/__init__.py +0 -0
- my_search_engine/my_search_engine/__pycache__/__init__.cpython-312.pyc +0 -0
- my_search_engine/my_search_engine/__pycache__/items.cpython-312.pyc +0 -0
- my_search_engine/my_search_engine/items.py +10 -0
- my_search_engine/my_search_engine/middlewares.py +56 -0
- my_search_engine/my_search_engine/pipelines.py +53 -0
- my_search_engine/my_search_engine/settings.py +49 -0
- my_search_engine/my_search_engine/spiders/__init__.py +4 -0
- my_search_engine/my_search_engine/spiders/__pycache__/__init__.cpython-312.pyc +0 -0
- my_search_engine/my_search_engine/spiders/__pycache__/search_spider.cpython-312.pyc +0 -0
- my_search_engine/my_search_engine/spiders/search_spider.py +176 -0
- my_search_engine/scrapy.cfg +11 -0
- ranking.py +239 -0
- test_agent.py +148 -0
- train_agent.py +116 -0
ToTSearch.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ToTSearch.py
|
2 |
+
import random
|
3 |
+
from typing import List, Dict, Any, Generator
|
4 |
+
from sentence_transformers import SentenceTransformer, util
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from twisted.internet import defer
|
8 |
+
from agent import AutonomousWebAgent
|
9 |
+
from mcts import MCTS, MCTSNode
|
10 |
+
import logging
|
11 |
+
from twisted.internet.defer import Deferred
|
12 |
+
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
class ToTNode:
|
17 |
+
def __init__(self, thought, parent=None):
|
18 |
+
self.thought = thought
|
19 |
+
self.parent = parent
|
20 |
+
self.children = []
|
21 |
+
self.visits = 0
|
22 |
+
self.value = 0
|
23 |
+
self.search_results = []
|
24 |
+
self.mcts_node = None
|
25 |
+
|
26 |
+
def add_child(self, child_thought):
|
27 |
+
child = ToTNode(child_thought, self)
|
28 |
+
self.children.append(child)
|
29 |
+
return child
|
30 |
+
|
31 |
+
def update(self, reward):
|
32 |
+
self.visits += 1
|
33 |
+
self.value += reward
|
34 |
+
|
35 |
+
class ToTSearch:
|
36 |
+
def __init__(self, agent: AutonomousWebAgent, model='all-MiniLM-L6-v2', max_depth=3, num_thoughts=3, num_simulations=100):
|
37 |
+
self.agent = agent
|
38 |
+
self.model = SentenceTransformer(model)
|
39 |
+
self.max_depth = max_depth
|
40 |
+
self.num_thoughts = num_thoughts
|
41 |
+
self.num_simulations = num_simulations
|
42 |
+
self.mcts = MCTS(initial_state="", num_simulations=num_simulations)
|
43 |
+
|
44 |
+
def generate_thoughts(self, query: str) -> List[str]:
|
45 |
+
prompt = f"""Given the query "{query}", generate {self.num_thoughts} distinct thoughts or approaches to address it.
|
46 |
+
Each thought should be a complete sentence and offer a unique perspective or solution path."""
|
47 |
+
|
48 |
+
thoughts = self.agent.generate_text(prompt).split('\n')
|
49 |
+
return [thought.strip() for thought in thoughts if thought.strip()]
|
50 |
+
|
51 |
+
def expand_thought(self, thought: str) -> List[str]:
|
52 |
+
prompt = f"""Expand on the following thought: "{thought}"
|
53 |
+
Generate {self.num_thoughts} more specific sub-thoughts or considerations.
|
54 |
+
Each sub-thought should be a complete sentence and offer additional detail or a new angle."""
|
55 |
+
|
56 |
+
expansions = self.agent.generate_text(prompt).split('\n')
|
57 |
+
return [exp.strip() for exp in expansions if exp.strip()]
|
58 |
+
|
59 |
+
def evaluate_thought(self, thought: str, query: str) -> float:
|
60 |
+
thought_embedding = self.model.encode(thought)
|
61 |
+
query_embedding = self.model.encode(query)
|
62 |
+
return util.pytorch_cos_sim(thought_embedding, query_embedding).item()
|
63 |
+
|
64 |
+
@defer.inlineCallbacks
|
65 |
+
def search_and_augment(self, thought: str) -> Generator[Deferred, Any, List[Dict[str, Any]]]:
|
66 |
+
search_results = yield self.agent.retrieve_from_web(thought)
|
67 |
+
for result in search_results:
|
68 |
+
result['originating_thought'] = thought
|
69 |
+
defer.returnValue(search_results)
|
70 |
+
|
71 |
+
def select(self, node: ToTNode) -> ToTNode:
|
72 |
+
while node.children:
|
73 |
+
# Choose a node with zero visits or select based on the value/visits ratio
|
74 |
+
if any(child.visits == 0 for child in node.children):
|
75 |
+
zero_visit_nodes = [child for child in node.children if child.visits == 0]
|
76 |
+
selected_node = random.choice(zero_visit_nodes)
|
77 |
+
logger.debug(f"Selected node with 0 visits: {selected_node.thought}")
|
78 |
+
return selected_node
|
79 |
+
else:
|
80 |
+
selected_node = max(node.children, key=lambda child: (child.value / child.visits) if child.visits > 0 else float('-inf'))
|
81 |
+
logger.debug(f"Selected node based on value/visits ratio: {selected_node.thought}, value: {selected_node.value}, visits: {selected_node.visits}")
|
82 |
+
return selected_node
|
83 |
+
return node
|
84 |
+
|
85 |
+
|
86 |
+
def expand(self, node: ToTNode, query: str) -> ToTNode:
|
87 |
+
if not node.children and len(node.thought.split()) > 2:
|
88 |
+
expansions = self.expand_thought(node.thought)
|
89 |
+
for expansion in expansions:
|
90 |
+
node.add_child(expansion)
|
91 |
+
return random.choice(node.children) if node.children else node
|
92 |
+
|
93 |
+
@defer.inlineCallbacks
|
94 |
+
def simulate(self, node: ToTNode, query: str):
|
95 |
+
current_node = node
|
96 |
+
depth = 0
|
97 |
+
while depth < self.max_depth:
|
98 |
+
if not current_node.children:
|
99 |
+
break
|
100 |
+
current_node = random.choice(current_node.children)
|
101 |
+
depth += 1
|
102 |
+
|
103 |
+
logger.debug(f"Simulating for thought: {current_node.thought}")
|
104 |
+
|
105 |
+
search_results = yield self.search_and_augment(current_node.thought)
|
106 |
+
current_node.search_results = search_results
|
107 |
+
|
108 |
+
logger.debug(f"Search results count: {len(search_results)}")
|
109 |
+
|
110 |
+
ranked_results = self.agent.calculate_reward(current_node.thought, query)
|
111 |
+
logger.debug(f"Ranked results: {ranked_results}")
|
112 |
+
|
113 |
+
mcts_node = MCTSNode(current_node.thought)
|
114 |
+
current_node.mcts_node = mcts_node
|
115 |
+
mcts_total_reward = 0
|
116 |
+
|
117 |
+
for _ in range(self.num_simulations):
|
118 |
+
mcts_reward = yield self.mcts.simulate(mcts_node)
|
119 |
+
mcts_total_reward += mcts_reward
|
120 |
+
self.mcts.backpropagate(mcts_node, mcts_reward)
|
121 |
+
|
122 |
+
logger.debug(f"MCTS node visits: {mcts_node.visits}, total reward: {mcts_total_reward}")
|
123 |
+
|
124 |
+
if mcts_node.visits == 0 or ranked_results == 0:
|
125 |
+
logger.warning(f"Avoiding division by zero. MCTS visits: {mcts_node.visits}, Ranked results: {ranked_results}")
|
126 |
+
combined_reward = 0
|
127 |
+
else:
|
128 |
+
combined_reward = (ranked_results + mcts_value) / 2
|
129 |
+
|
130 |
+
if mcts_node.visits > 0:
|
131 |
+
mcts_value = mcts_total_reward / mcts_node.visits
|
132 |
+
logger.debug(f"MCTS value: {mcts_value}")
|
133 |
+
else:
|
134 |
+
mcts_value = 0
|
135 |
+
logger.warning(f"MCTS node has 0 visits, assigning value 0")
|
136 |
+
|
137 |
+
combined_reward = (ranked_results + mcts_value) / 2
|
138 |
+
logger.debug(f"Combined reward: {combined_reward}")
|
139 |
+
|
140 |
+
defer.returnValue(combined_reward)
|
141 |
+
|
142 |
+
def backpropagate(self, node: ToTNode, reward: float):
|
143 |
+
while node:
|
144 |
+
node.update(reward)
|
145 |
+
node = node.parent
|
146 |
+
|
147 |
+
@defer.inlineCallbacks
|
148 |
+
def tot_search(self, query: str) -> Generator[Deferred, Any, ToTNode]:
|
149 |
+
root = ToTNode(query)
|
150 |
+
for _ in range(self.num_simulations):
|
151 |
+
node = self.select(root)
|
152 |
+
node = self.expand(node, query)
|
153 |
+
reward = yield self.simulate(node, query)
|
154 |
+
self.backpropagate(node, reward)
|
155 |
+
|
156 |
+
# Update agent's experience replay
|
157 |
+
state = self.agent.extract_features(node.thought, query)
|
158 |
+
next_state = self.agent.extract_features(node.children[0].thought if node.children else node.thought, query)
|
159 |
+
self.agent.remember_worker(state, 0, reward, next_state, False)
|
160 |
+
|
161 |
+
# Perform agent's replay to update RL models
|
162 |
+
self.agent.replay_worker()
|
163 |
+
self.agent.replay_manager()
|
164 |
+
|
165 |
+
defer.returnValue(root)
|
166 |
+
|
167 |
+
def get_best_path(self, root: ToTNode) -> List[str]:
|
168 |
+
path = [root.thought]
|
169 |
+
current = root
|
170 |
+
while current.children:
|
171 |
+
current = max(current.children, key=lambda child: child.value / child.visits if child.visits > 0 else float('-inf'))
|
172 |
+
path.append(current.thought)
|
173 |
+
return path
|
174 |
+
|
175 |
+
@defer.inlineCallbacks
|
176 |
+
def synthesize_results(self, root: ToTNode, query: str) -> Generator[Deferred, Any, str]:
|
177 |
+
best_path = self.get_best_path(root)
|
178 |
+
all_results = []
|
179 |
+
|
180 |
+
def collect_results(node):
|
181 |
+
all_results.extend(node.search_results)
|
182 |
+
for child in node.children:
|
183 |
+
collect_results(child)
|
184 |
+
|
185 |
+
collect_results(root)
|
186 |
+
|
187 |
+
# Sort results by relevance
|
188 |
+
all_results.sort(key=lambda x: self.evaluate_thought(x['content'], query), reverse=True)
|
189 |
+
|
190 |
+
# Generate a summary of the top results
|
191 |
+
top_results = all_results[:5] # Adjust the number as needed
|
192 |
+
summary_prompt = f"Synthesize the following information into a coherent answer for the query '{query}':\n\n"
|
193 |
+
summary_prompt += f"Thought path: {' -> '.join(best_path)}\n\n"
|
194 |
+
for result in top_results:
|
195 |
+
summary_prompt += f"- {result['content'][:200]}...\n"
|
196 |
+
|
197 |
+
# Use the agent's RAG capabilities for final answer generation
|
198 |
+
final_answer = yield self.agent.generate_rag_response(query, top_results)
|
199 |
+
|
200 |
+
# Save the generated answer and thought path to the agent's knowledge base
|
201 |
+
self.agent.add_document_to_kb(
|
202 |
+
title=f"ToT Search Result: {query}",
|
203 |
+
content=final_answer,
|
204 |
+
metadata={"thought_path": best_path}
|
205 |
+
)
|
206 |
+
|
207 |
+
defer.returnValue(final_answer)
|
208 |
+
|
209 |
+
@defer.inlineCallbacks
|
210 |
+
def search(self, query: str) -> Generator[Deferred, Any, str]:
|
211 |
+
logger.info(f"Starting ToT search for query: {query}")
|
212 |
+
root = yield self.tot_search(query)
|
213 |
+
final_answer = yield self.synthesize_results(root, query)
|
214 |
+
logger.info(f"ToT search completed for query: {query}")
|
215 |
+
defer.returnValue(final_answer)
|
216 |
+
|
217 |
+
# Usage example:
|
218 |
+
# tot_search = ToTSearch(agent)
|
219 |
+
# final_answer = yield tot_search.search("What are the latest advancements in renewable energy?")
|
agent.py
ADDED
@@ -0,0 +1,1082 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# agent.py
|
3 |
+
# agent.py
|
4 |
+
import numpy as np
|
5 |
+
from mcts import MCTS
|
6 |
+
from ranking import train_ranking_model
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.optim as optim
|
11 |
+
from collections import deque, OrderedDict
|
12 |
+
import random
|
13 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
14 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
15 |
+
from sentence_transformers import SentenceTransformer
|
16 |
+
import hashlib
|
17 |
+
from twisted.internet import defer
|
18 |
+
import logging
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
from urllib.parse import urlparse
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
# ==========================
|
26 |
+
# Prioritized Experience Replay
|
27 |
+
# ==========================
|
28 |
+
|
29 |
+
class SumTree:
|
30 |
+
"""
|
31 |
+
SumTree data structure where the parent’s value is the sum of its children.
|
32 |
+
Leaf nodes contain the priorities of experiences.
|
33 |
+
"""
|
34 |
+
def __init__(self, capacity):
|
35 |
+
self.capacity = capacity
|
36 |
+
self.tree = np.zeros(2 * capacity - 1)
|
37 |
+
self.data = np.zeros(capacity, dtype=object)
|
38 |
+
self.write = 0
|
39 |
+
self.n_entries = 0
|
40 |
+
|
41 |
+
def _propagate(self, idx, change):
|
42 |
+
parent = (idx - 1) // 2
|
43 |
+
self.tree[parent] += change
|
44 |
+
if parent != 0:
|
45 |
+
self._propagate(parent, change)
|
46 |
+
|
47 |
+
def _retrieve(self, idx, s):
|
48 |
+
left = 2 * idx + 1
|
49 |
+
right = left + 1
|
50 |
+
|
51 |
+
if left >= len(self.tree):
|
52 |
+
return idx
|
53 |
+
|
54 |
+
if s <= self.tree[left]:
|
55 |
+
return self._retrieve(left, s)
|
56 |
+
else:
|
57 |
+
return self._retrieve(right, s - self.tree[left])
|
58 |
+
|
59 |
+
def total(self):
|
60 |
+
return self.tree[0]
|
61 |
+
|
62 |
+
def add(self, p, data):
|
63 |
+
idx = self.write + self.capacity - 1
|
64 |
+
|
65 |
+
self.data[self.write] = data
|
66 |
+
self.update(idx, p)
|
67 |
+
|
68 |
+
self.write += 1
|
69 |
+
if self.write >= self.capacity:
|
70 |
+
self.write = 0
|
71 |
+
|
72 |
+
if self.n_entries < self.capacity:
|
73 |
+
self.n_entries += 1
|
74 |
+
|
75 |
+
def update(self, idx, p):
|
76 |
+
change = p - self.tree[idx]
|
77 |
+
self.tree[idx] = p
|
78 |
+
self._propagate(idx, change)
|
79 |
+
|
80 |
+
def get(self, s):
|
81 |
+
idx = self._retrieve(0, s)
|
82 |
+
data_idx = idx - self.capacity + 1
|
83 |
+
|
84 |
+
return (idx, self.tree[idx], self.data[data_idx])
|
85 |
+
|
86 |
+
class PrioritizedReplayMemory:
|
87 |
+
def __init__(self, capacity, alpha=0.6):
|
88 |
+
self.tree = SumTree(capacity)
|
89 |
+
self.alpha = alpha # [0,1] convert the importance of TD error to priority
|
90 |
+
self.epsilon = 1e-6 # small amount to avoid zero priority
|
91 |
+
|
92 |
+
def add(self, error, sample):
|
93 |
+
p = (np.abs(error) + self.epsilon) ** self.alpha
|
94 |
+
self.tree.add(p, sample)
|
95 |
+
|
96 |
+
def sample(self, batch_size, beta=0.4):
|
97 |
+
batch = []
|
98 |
+
idxs = []
|
99 |
+
segment = self.tree.total() / batch_size
|
100 |
+
priorities = []
|
101 |
+
|
102 |
+
for i in range(batch_size):
|
103 |
+
a = segment * i
|
104 |
+
b = segment * (i + 1)
|
105 |
+
s = random.uniform(a, b)
|
106 |
+
idx, p, data = self.tree.get(s)
|
107 |
+
batch.append(data)
|
108 |
+
idxs.append(idx)
|
109 |
+
priorities.append(p)
|
110 |
+
|
111 |
+
total = self.tree.total()
|
112 |
+
probs = priorities / total
|
113 |
+
weights = (self.tree.n_entries * probs) ** (-beta)
|
114 |
+
weights /= weights.max()
|
115 |
+
return batch, idxs, weights
|
116 |
+
|
117 |
+
def update(self, idx, error):
|
118 |
+
p = (np.abs(error) + self.epsilon) ** self.alpha
|
119 |
+
self.tree.update(idx, p)
|
120 |
+
|
121 |
+
# ==========================
|
122 |
+
# Hierarchical Reinforcement Learning (HRL)
|
123 |
+
# ==========================
|
124 |
+
|
125 |
+
class ManagerModel(nn.Module):
|
126 |
+
"""
|
127 |
+
High-level policy model (Manager) that decides which option to execute.
|
128 |
+
"""
|
129 |
+
def __init__(self, input_size, hidden_size, num_options):
|
130 |
+
super(ManagerModel, self).__init__()
|
131 |
+
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
|
132 |
+
self.fc = nn.Linear(hidden_size, num_options)
|
133 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
134 |
+
|
135 |
+
def forward(self, x, hidden=None):
|
136 |
+
if x.dim() == 2:
|
137 |
+
x = x.unsqueeze(1) # Add a time dimension
|
138 |
+
out, hidden = self.lstm(x, hidden)
|
139 |
+
last_output = out[:, -1, :]
|
140 |
+
last_output = self.layer_norm(last_output)
|
141 |
+
option_scores = self.fc(last_output)
|
142 |
+
return option_scores, hidden
|
143 |
+
|
144 |
+
class WorkerModel(nn.Module):
|
145 |
+
"""
|
146 |
+
Low-level policy model (Worker) that executes actions based on the selected option.
|
147 |
+
"""
|
148 |
+
def __init__(self, input_size, hidden_size, action_size):
|
149 |
+
super(WorkerModel, self).__init__()
|
150 |
+
self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
|
151 |
+
self.fc = nn.Linear(hidden_size, action_size)
|
152 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
153 |
+
self.action_size = action_size # Store action_size for reference
|
154 |
+
|
155 |
+
def forward(self, x, hidden=None):
|
156 |
+
if x.dim() == 2:
|
157 |
+
x = x.unsqueeze(1) # Add a time dimension
|
158 |
+
out, hidden = self.lstm(x, hidden)
|
159 |
+
last_output = out[:, -1, :]
|
160 |
+
last_output = self.layer_norm(last_output)
|
161 |
+
action_scores = self.fc(last_output)
|
162 |
+
return action_scores, hidden
|
163 |
+
|
164 |
+
def act(self, state, epsilon=0.1):
|
165 |
+
"""
|
166 |
+
Selects an action using epsilon-greedy policy.
|
167 |
+
"""
|
168 |
+
if random.random() < epsilon:
|
169 |
+
action = random.randint(0, self.action_size - 1)
|
170 |
+
return action
|
171 |
+
state = torch.FloatTensor(state).unsqueeze(0).to(next(self.parameters()).device)
|
172 |
+
with torch.no_grad():
|
173 |
+
action_scores, _ = self(state)
|
174 |
+
action = torch.argmax(action_scores, dim=1).item()
|
175 |
+
return action
|
176 |
+
|
177 |
+
# ==========================
|
178 |
+
# RAGSummarizer Class
|
179 |
+
# ==========================
|
180 |
+
|
181 |
+
class RAGSummarizer:
|
182 |
+
def __init__(self, model_name='gpt2', embedding_model='all-MiniLM-L6-v2',
|
183 |
+
max_length=150, cache_capacity=100, persistent_cache_path='rag_cache.json'):
|
184 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
185 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
186 |
+
self.model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
|
187 |
+
# Explicitly set the device for SentenceTransformer
|
188 |
+
self.embedding_model = SentenceTransformer(embedding_model, device=self.device)
|
189 |
+
self.max_length = max_length
|
190 |
+
self.cache = LRUCache(cache_capacity)
|
191 |
+
self.persistent_cache_path = persistent_cache_path
|
192 |
+
self.load_persistent_cache()
|
193 |
+
|
194 |
+
def load_persistent_cache(self):
|
195 |
+
if os.path.exists(self.persistent_cache_path):
|
196 |
+
with open(self.persistent_cache_path, 'r', encoding='utf-8') as f:
|
197 |
+
try:
|
198 |
+
persistent_data = json.load(f)
|
199 |
+
for key, value in persistent_data.items():
|
200 |
+
self.cache.put(key, value)
|
201 |
+
logger.info(f"Loaded persistent cache with {len(persistent_data)} entries.")
|
202 |
+
except json.JSONDecodeError:
|
203 |
+
logger.warning("Persistent cache file is corrupted. Initializing empty cache.")
|
204 |
+
else:
|
205 |
+
logger.info("No persistent cache found. Starting with empty cache.")
|
206 |
+
|
207 |
+
def save_persistent_cache(self):
|
208 |
+
with open(self.persistent_cache_path, 'w', encoding='utf-8') as f:
|
209 |
+
json.dump(self.cache.cache, f, indent=2)
|
210 |
+
logger.info(f"Saved persistent cache with {len(self.cache.cache)} entries.")
|
211 |
+
|
212 |
+
def save_rag_data(self, query, chunks, embeddings):
|
213 |
+
data = {
|
214 |
+
"query": query,
|
215 |
+
"chunks": chunks,
|
216 |
+
"embeddings": embeddings.tolist()
|
217 |
+
}
|
218 |
+
|
219 |
+
os.makedirs("rag_data", exist_ok=True)
|
220 |
+
|
221 |
+
filename = f"rag_data/{hash(query)}.json"
|
222 |
+
with open(filename, 'w') as f:
|
223 |
+
json.dump(data, f, indent=2)
|
224 |
+
|
225 |
+
logger.info(f"Saved RAG data to {filename}")
|
226 |
+
|
227 |
+
def split_into_chunks(self, text, chunk_size=200):
|
228 |
+
words = text.split()
|
229 |
+
return [' '.join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
|
230 |
+
|
231 |
+
def retrieve_relevant_chunks(self, query, chunks, embeddings, top_k=3):
|
232 |
+
if embeddings.size(0) == 0:
|
233 |
+
logger.warning("Embeddings are empty. Cannot retrieve relevant chunks.")
|
234 |
+
return []
|
235 |
+
query_embedding = self.embedding_model.encode([query], convert_to_tensor=True)
|
236 |
+
cosine_scores = cosine_similarity(query_embedding.cpu().numpy(), embeddings.cpu().numpy())[0]
|
237 |
+
top_indices = cosine_scores.argsort()[-top_k:][::-1]
|
238 |
+
# Ensure indices are within bounds
|
239 |
+
top_indices = [idx for idx in top_indices if idx < len(chunks)]
|
240 |
+
return [chunks[i] for i in top_indices]
|
241 |
+
|
242 |
+
def get_embeddings(self, chunks):
|
243 |
+
# Implement batch processing
|
244 |
+
batch_size = 32
|
245 |
+
embeddings = []
|
246 |
+
for i in range(0, len(chunks), batch_size):
|
247 |
+
batch = chunks[i:i+batch_size]
|
248 |
+
batch_embeddings = self.embedding_model.encode(batch, convert_to_tensor=True)
|
249 |
+
embeddings.append(batch_embeddings)
|
250 |
+
if embeddings:
|
251 |
+
return torch.cat(embeddings, dim=0)
|
252 |
+
else:
|
253 |
+
return torch.tensor([])
|
254 |
+
|
255 |
+
def generate_summary(self, query, relevant_chunks):
|
256 |
+
cache_key = hashlib.md5((query + ''.join(relevant_chunks)).encode()).hexdigest()
|
257 |
+
cached_summary = self.cache.get(cache_key)
|
258 |
+
if cached_summary:
|
259 |
+
return cached_summary
|
260 |
+
|
261 |
+
context = " ".join(relevant_chunks)
|
262 |
+
prompt = f"Summarize the following content in relation to '{query}': {context}\n\nSummary:"
|
263 |
+
|
264 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
265 |
+
|
266 |
+
try:
|
267 |
+
output = self.model.generate(
|
268 |
+
input_ids,
|
269 |
+
max_length=input_ids.shape[1] + self.max_length,
|
270 |
+
num_return_sequences=1,
|
271 |
+
no_repeat_ngram_size=2,
|
272 |
+
top_k=50,
|
273 |
+
top_p=0.95,
|
274 |
+
temperature=0.7,
|
275 |
+
early_stopping=True
|
276 |
+
)
|
277 |
+
except Exception as e:
|
278 |
+
logger.error(f"Error during summary generation: {str(e)}")
|
279 |
+
return "Summary generation failed."
|
280 |
+
|
281 |
+
self.save_rag_data(query, relevant_chunks, self.get_embeddings(relevant_chunks))
|
282 |
+
|
283 |
+
summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
284 |
+
summary = summary.split("Summary:")[-1].strip()
|
285 |
+
|
286 |
+
self.cache.put(cache_key, summary)
|
287 |
+
self.save_persistent_cache()
|
288 |
+
|
289 |
+
return summary
|
290 |
+
|
291 |
+
# ==========================
|
292 |
+
# WorldModel Class
|
293 |
+
# ==========================
|
294 |
+
|
295 |
+
class WorldModel(nn.Module):
|
296 |
+
def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.3):
|
297 |
+
super(WorldModel, self).__init__()
|
298 |
+
self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
|
299 |
+
batch_first=True, dropout=dropout)
|
300 |
+
self.fc = nn.Linear(hidden_size, output_size)
|
301 |
+
self.value_head = nn.Linear(hidden_size, 1)
|
302 |
+
self.layer_norm = nn.LayerNorm(hidden_size)
|
303 |
+
|
304 |
+
def forward(self, x, hidden=None):
|
305 |
+
if x.dim() == 2:
|
306 |
+
x = x.unsqueeze(1) # Add a time dimension
|
307 |
+
out, hidden = self.lstm(x, hidden)
|
308 |
+
last_output = out[:, -1, :]
|
309 |
+
last_output = self.layer_norm(last_output)
|
310 |
+
action_scores = self.fc(last_output)
|
311 |
+
state_value = self.value_head(last_output)
|
312 |
+
return action_scores, state_value, hidden
|
313 |
+
|
314 |
+
# ==========================
|
315 |
+
# Manager and Worker Classes for HRL
|
316 |
+
# ==========================
|
317 |
+
|
318 |
+
class Manager:
|
319 |
+
def __init__(self, state_size, num_options, hidden_size=128, learning_rate=0.001, gamma=0.99,
|
320 |
+
epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01, memory_capacity=1000, device=torch.device("cpu")):
|
321 |
+
self.state_size = state_size
|
322 |
+
self.num_options = num_options
|
323 |
+
self.gamma = gamma
|
324 |
+
self.epsilon = epsilon
|
325 |
+
self.epsilon_decay = epsilon_decay
|
326 |
+
self.epsilon_min = epsilon_min
|
327 |
+
self.device = device
|
328 |
+
|
329 |
+
self.model = ManagerModel(state_size, hidden_size, num_options).to(self.device)
|
330 |
+
self.target_model = ManagerModel(state_size, hidden_size, num_options).to(self.device)
|
331 |
+
self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
|
332 |
+
self.loss_fn = nn.MSELoss()
|
333 |
+
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=5, factor=0.5, verbose=True)
|
334 |
+
|
335 |
+
self.memory = PrioritizedReplayMemory(capacity=memory_capacity, alpha=0.6)
|
336 |
+
|
337 |
+
self.update_target_model()
|
338 |
+
|
339 |
+
def update_target_model(self):
|
340 |
+
self.target_model.load_state_dict(self.model.state_dict())
|
341 |
+
|
342 |
+
def remember(self, state, option, reward, next_state, done, td_error):
|
343 |
+
sample = (state, option, reward, next_state, done)
|
344 |
+
self.memory.add(td_error, sample)
|
345 |
+
|
346 |
+
def act(self, state):
|
347 |
+
if random.random() < self.epsilon:
|
348 |
+
option = random.randint(0, self.num_options - 1)
|
349 |
+
return option
|
350 |
+
state = torch.FloatTensor(state).unsqueeze(0).to(self.model.lstm.weight.device)
|
351 |
+
with torch.no_grad():
|
352 |
+
option_scores, _ = self.model(state)
|
353 |
+
option = torch.argmax(option_scores).item()
|
354 |
+
return option
|
355 |
+
|
356 |
+
def replay(self, batch_size, beta=0.4):
|
357 |
+
if self.memory.tree.n_entries < batch_size:
|
358 |
+
return
|
359 |
+
batch, idxs, weights = self.memory.sample(batch_size, beta)
|
360 |
+
states, options, rewards, next_states, dones = zip(*batch)
|
361 |
+
|
362 |
+
states = torch.FloatTensor(states).to(self.model.lstm.weight.device)
|
363 |
+
next_states = torch.FloatTensor(next_states).to(self.model.lstm.weight.device)
|
364 |
+
options = torch.LongTensor(options).unsqueeze(1).to(self.model.lstm.weight.device)
|
365 |
+
rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.model.lstm.weight.device)
|
366 |
+
dones = torch.FloatTensor(dones).unsqueeze(1).to(self.model.lstm.weight.device)
|
367 |
+
weights = torch.FloatTensor(weights).unsqueeze(1).to(self.model.lstm.weight.device)
|
368 |
+
|
369 |
+
# Current Q values
|
370 |
+
current_q_values, _ = self.model(states)
|
371 |
+
current_q_values = current_q_values.gather(1, options)
|
372 |
+
|
373 |
+
# Target Q values
|
374 |
+
with torch.no_grad():
|
375 |
+
next_q_values, _ = self.target_model(next_states)
|
376 |
+
max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
|
377 |
+
target_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))
|
378 |
+
|
379 |
+
# Compute TD errors
|
380 |
+
td_errors = target_q_values - current_q_values
|
381 |
+
|
382 |
+
# Compute loss with importance-sampling weights
|
383 |
+
loss = (td_errors.pow(2) * weights).mean()
|
384 |
+
|
385 |
+
# Optimize the model
|
386 |
+
self.optimizer.zero_grad()
|
387 |
+
loss.backward()
|
388 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
389 |
+
self.optimizer.step()
|
390 |
+
self.scheduler.step(loss.item())
|
391 |
+
|
392 |
+
# Update priorities
|
393 |
+
td_errors_np = td_errors.detach().cpu().numpy().squeeze()
|
394 |
+
for idx, td_error in zip(idxs, td_errors_np):
|
395 |
+
self.memory.update(idx, np.abs(td_error))
|
396 |
+
|
397 |
+
# Decay epsilon
|
398 |
+
if self.epsilon > self.epsilon_min:
|
399 |
+
self.epsilon *= self.epsilon_decay
|
400 |
+
|
401 |
+
# ==========================
|
402 |
+
# AutonomousWebAgent Class
|
403 |
+
# ==========================
|
404 |
+
|
405 |
+
def truncate_text(text, max_length=1024):
|
406 |
+
tokens = text.split()
|
407 |
+
if len(tokens) > max_length:
|
408 |
+
return ' '.join(tokens[:max_length])
|
409 |
+
return text
|
410 |
+
|
411 |
+
class AutonomousWebAgent:
|
412 |
+
def __init__(self, state_size, action_size, num_options, hidden_size=64, learning_rate=0.001,
|
413 |
+
gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
|
414 |
+
knowledge_base_path='knowledge_base.json'):
|
415 |
+
self.state_size = state_size
|
416 |
+
self.action_size = action_size
|
417 |
+
self.num_options = num_options # Number of high-level options for HRL
|
418 |
+
self.gamma = gamma
|
419 |
+
self.epsilon = epsilon
|
420 |
+
self.epsilon_decay = epsilon_decay
|
421 |
+
self.epsilon_min = epsilon_min
|
422 |
+
|
423 |
+
# Initialize RAGSummarizer first to get the device
|
424 |
+
self.summarizer = RAGSummarizer()
|
425 |
+
self.device = self.summarizer.device
|
426 |
+
|
427 |
+
# Initialize SentenceTransformer with the correct device
|
428 |
+
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device)
|
429 |
+
|
430 |
+
# Low-level (Worker) Model
|
431 |
+
self.worker_model = WorldModel(state_size, hidden_size, action_size).to(self.device)
|
432 |
+
self.worker_target_model = WorldModel(state_size, hidden_size, action_size).to(self.device)
|
433 |
+
self.worker_optimizer = optim.AdamW(self.worker_model.parameters(), lr=learning_rate, weight_decay=1e-5)
|
434 |
+
self.worker_loss_fn = nn.MSELoss()
|
435 |
+
self.worker_scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.worker_optimizer, 'min', patience=5, factor=0.5, verbose=True)
|
436 |
+
self.worker_memory = PrioritizedReplayMemory(capacity=2000, alpha=0.6)
|
437 |
+
self.update_worker_target_model()
|
438 |
+
|
439 |
+
# High-level (Manager) Model
|
440 |
+
self.manager = Manager(state_size, num_options, hidden_size=128, learning_rate=learning_rate,
|
441 |
+
gamma=gamma, epsilon=epsilon, epsilon_decay=epsilon_decay,
|
442 |
+
epsilon_min=epsilon_min, memory_capacity=1000, device=self.device)
|
443 |
+
|
444 |
+
self.mcts = MCTS(initial_state="")
|
445 |
+
logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")
|
446 |
+
|
447 |
+
self.site_performance = {} # {(site, query): performance_score}
|
448 |
+
|
449 |
+
# List of all search sites (base URLs without the query)
|
450 |
+
self.all_search_sites = [
|
451 |
+
"https://en.wikibooks.org/w/index.php?search=",
|
452 |
+
"https://en.wikiversity.org/w/index.php?search=",
|
453 |
+
"https://commons.wikimedia.org/w/index.php?search=",
|
454 |
+
"https://stackexchange.com/search?q=",
|
455 |
+
"https://arxiv.org/search/?query=",
|
456 |
+
"https://www.ncbi.nlm.nih.gov/pmc/?term=",
|
457 |
+
"https://www.gutenberg.org/ebooks/search/?query=",
|
458 |
+
"https://openlibrary.org/search?q=",
|
459 |
+
"https://doaj.org/search/articles?ref=homepage&q=",
|
460 |
+
"https://www.ted.com/search?q=",
|
461 |
+
"https://en.citizendium.org/wiki?search=",
|
462 |
+
"https://www.jstor.org/action/doBasicSearch?Query=",
|
463 |
+
"https://archive.org/search.php?query=",
|
464 |
+
"https://search.scielo.org/?q=",
|
465 |
+
"https://paperswithcode.com/search?q=",
|
466 |
+
"https://www.reddit.com/search/?q=",
|
467 |
+
"https://huggingface.co/models?search=",
|
468 |
+
"https://huggingface.co/datasets?search=",
|
469 |
+
"https://machinelearningmastery.com/?s=",
|
470 |
+
"https://www.kaggle.com/search?q=",
|
471 |
+
"https://towardsdatascience.com/search?q=",
|
472 |
+
"https://github.com/search?q=",
|
473 |
+
"https://stackoverflow.com/search?q=",
|
474 |
+
"https://www.youtube.com/results?search_query=",
|
475 |
+
"https://www.slideshare.net/search/slideshow?searchfrom=header&q="
|
476 |
+
]
|
477 |
+
|
478 |
+
# Initialize Knowledge Base
|
479 |
+
self.knowledge_base_path = knowledge_base_path
|
480 |
+
self.knowledge_base = []
|
481 |
+
self.kb_embeddings = None
|
482 |
+
self.load_knowledge_base()
|
483 |
+
|
484 |
+
# Additional Features for State Representation
|
485 |
+
self.additional_features = ['image_count', 'script_count', 'css_count']
|
486 |
+
|
487 |
+
def save(self, filename):
|
488 |
+
"""Save the entire agent state."""
|
489 |
+
state = {
|
490 |
+
'worker_model': self.worker_model.state_dict(),
|
491 |
+
'manager_model': self.manager.model.state_dict(),
|
492 |
+
'worker_optimizer': self.worker_optimizer.state_dict(),
|
493 |
+
'manager_optimizer': self.manager.optimizer.state_dict(),
|
494 |
+
'epsilon': self.epsilon
|
495 |
+
}
|
496 |
+
torch.save(state, filename)
|
497 |
+
logger.info(f"Saved agent state to {filename}")
|
498 |
+
|
499 |
+
def load(self, filename):
|
500 |
+
"""Load the entire agent state."""
|
501 |
+
state = torch.load(filename, map_location=self.device)
|
502 |
+
self.worker_model.load_state_dict(state['worker_model'])
|
503 |
+
self.manager.model.load_state_dict(state['manager_model'])
|
504 |
+
self.worker_optimizer.load_state_dict(state['worker_optimizer'])
|
505 |
+
self.manager.optimizer.load_state_dict(state['manager_optimizer'])
|
506 |
+
self.epsilon = state['epsilon']
|
507 |
+
logger.info(f"Loaded agent state from {filename}")
|
508 |
+
|
509 |
+
# ==========================
|
510 |
+
# Text Generation
|
511 |
+
# ==========================
|
512 |
+
|
513 |
+
def generate_text(self, prompt):
|
514 |
+
# Use the RAGSummarizer to generate text
|
515 |
+
chunks = self.summarizer.split_into_chunks(prompt)
|
516 |
+
embeddings = self.summarizer.get_embeddings(chunks)
|
517 |
+
relevant_chunks = self.summarizer.retrieve_relevant_chunks(query=prompt, chunks=chunks, embeddings=embeddings)
|
518 |
+
generated_text = self.summarizer.generate_summary(prompt, relevant_chunks)
|
519 |
+
return generated_text
|
520 |
+
|
521 |
+
# ==========================
|
522 |
+
# Knowledge Base Management
|
523 |
+
# ==========================
|
524 |
+
|
525 |
+
def load_knowledge_base(self):
|
526 |
+
if not os.path.exists(self.knowledge_base_path):
|
527 |
+
logger.warning(f"Knowledge base file {self.knowledge_base_path} does not exist. Initializing empty KB.")
|
528 |
+
self.knowledge_base = []
|
529 |
+
self.kb_embeddings = torch.tensor([]).to(self.device)
|
530 |
+
return
|
531 |
+
|
532 |
+
with open(self.knowledge_base_path, 'r', encoding='utf-8') as f:
|
533 |
+
self.knowledge_base = json.load(f)
|
534 |
+
|
535 |
+
if self.knowledge_base:
|
536 |
+
texts = [doc['content'] for doc in self.knowledge_base]
|
537 |
+
self.kb_embeddings = self.embedding_model.encode(texts, convert_to_tensor=True)
|
538 |
+
logger.info(f"Loaded {len(self.knowledge_base)} documents into the knowledge base.")
|
539 |
+
else:
|
540 |
+
self.kb_embeddings = torch.tensor([]).to(self.device)
|
541 |
+
logger.info("Knowledge base is empty.")
|
542 |
+
|
543 |
+
def save_knowledge_base(self):
|
544 |
+
with open(self.knowledge_base_path, 'w', encoding='utf-8') as f:
|
545 |
+
json.dump(self.knowledge_base, f, indent=2)
|
546 |
+
logger.info(f"Knowledge base saved with {len(self.knowledge_base)} documents.")
|
547 |
+
|
548 |
+
def add_document_to_kb(self, title, content, metadata=None):
|
549 |
+
document = {
|
550 |
+
"title": title,
|
551 |
+
"content": content,
|
552 |
+
"metadata": metadata or {}
|
553 |
+
}
|
554 |
+
self.knowledge_base.append(document)
|
555 |
+
# Update embeddings
|
556 |
+
new_embedding = self.embedding_model.encode([content], convert_to_tensor=True).to(self.device)
|
557 |
+
if self.kb_embeddings.numel() == 0:
|
558 |
+
self.kb_embeddings = new_embedding
|
559 |
+
else:
|
560 |
+
self.kb_embeddings = torch.cat([self.kb_embeddings, new_embedding], dim=0)
|
561 |
+
# Save to knowledge base
|
562 |
+
self.save_knowledge_base()
|
563 |
+
logger.info(f"Added new document to knowledge base: {title}")
|
564 |
+
|
565 |
+
def retrieve_from_kb(self, query, top_k=5):
|
566 |
+
if not self.knowledge_base:
|
567 |
+
logger.warning("Knowledge base is empty. No documents to retrieve.")
|
568 |
+
return []
|
569 |
+
|
570 |
+
query_embedding = self.embedding_model.encode([query], convert_to_tensor=True).to(self.device)
|
571 |
+
|
572 |
+
if self.kb_embeddings is None or self.kb_embeddings.numel() == 0:
|
573 |
+
logger.warning("Knowledge base embeddings are empty. No documents to retrieve.")
|
574 |
+
return []
|
575 |
+
|
576 |
+
if query_embedding.size(1) != self.kb_embeddings.size(1):
|
577 |
+
logger.error("Dimension mismatch between query embedding and KB embeddings.")
|
578 |
+
return []
|
579 |
+
|
580 |
+
cosine_scores = cosine_similarity(query_embedding.cpu().numpy(), self.kb_embeddings.cpu().numpy())[0]
|
581 |
+
top_indices = cosine_scores.argsort()[-top_k:][::-1]
|
582 |
+
|
583 |
+
# Ensure indices are within the knowledge_base length
|
584 |
+
top_indices = [idx for idx in top_indices if idx < len(self.knowledge_base)]
|
585 |
+
|
586 |
+
retrieved_docs = []
|
587 |
+
for idx in top_indices:
|
588 |
+
doc = self.knowledge_base[idx]
|
589 |
+
doc['score'] = cosine_scores[idx]
|
590 |
+
retrieved_docs.append(doc)
|
591 |
+
|
592 |
+
logger.info(f"Retrieved top {len(retrieved_docs)} documents from Knowledge Base for the query.")
|
593 |
+
return retrieved_docs
|
594 |
+
|
595 |
+
# ==========================
|
596 |
+
# RAG Integration
|
597 |
+
# ==========================
|
598 |
+
|
599 |
+
def retrieve_from_web(self, query, top_k=5):
|
600 |
+
logger.info(f"Performing web search for query: {query}")
|
601 |
+
mcts_iterations = self.calculate_mcts_iterations(np.zeros(self.state_size, dtype=np.float32))
|
602 |
+
self.mcts = MCTS(initial_state=query, num_simulations=mcts_iterations)
|
603 |
+
|
604 |
+
try:
|
605 |
+
new_query = yield self.mcts.run()
|
606 |
+
logger.debug(f"New query from MCTS: {new_query}")
|
607 |
+
# Select search sites
|
608 |
+
search_sites = self.select_search_sites(new_query)
|
609 |
+
results = yield self.mcts.web_search(new_query, search_sites)
|
610 |
+
logger.debug(f"Web search completed. Found {len(results)} results")
|
611 |
+
return results[:top_k] if results else []
|
612 |
+
except Exception as e:
|
613 |
+
logger.error(f"Error during MCTS or web search: {str(e)}", exc_info=True)
|
614 |
+
return []
|
615 |
+
|
616 |
+
def combine_documents(self, kb_docs, web_docs):
|
617 |
+
combined = kb_docs + web_docs
|
618 |
+
logger.info(f"Combined {len(kb_docs)} KB documents and {len(web_docs)} Web documents.")
|
619 |
+
return combined
|
620 |
+
|
621 |
+
def save_llm_training_data(self, query, content, summary=None, link=None, title=None):
|
622 |
+
data = {
|
623 |
+
"query": query,
|
624 |
+
"search_result": {
|
625 |
+
"link": link,
|
626 |
+
"title": title
|
627 |
+
},
|
628 |
+
"content": content,
|
629 |
+
"description": summary
|
630 |
+
}
|
631 |
+
|
632 |
+
os.makedirs("llm_training_data", exist_ok=True)
|
633 |
+
file_path = "llm_training_data/llm_training_data.jsonl"
|
634 |
+
|
635 |
+
# Append the new data as a new line in the JSONL file
|
636 |
+
with open(file_path, 'a', encoding='utf-8') as f:
|
637 |
+
json.dump(data, f)
|
638 |
+
f.write('\n')
|
639 |
+
|
640 |
+
logger.info(f"Appended LLM training data to {file_path}")
|
641 |
+
|
642 |
+
# ==========================
|
643 |
+
# Hierarchical RL Integration
|
644 |
+
# ==========================
|
645 |
+
|
646 |
+
def remember_manager(self, state, option, reward, next_state, done, td_error):
|
647 |
+
self.manager.remember(state, option, reward, next_state, done, td_error)
|
648 |
+
|
649 |
+
def remember_worker(self, state, action, reward, next_state, done):
|
650 |
+
self.worker_memory.add(reward, (state, action, reward, next_state, done))
|
651 |
+
|
652 |
+
# ==========================
|
653 |
+
# Action Selection and Execution
|
654 |
+
# ==========================
|
655 |
+
|
656 |
+
def act_manager(self, state):
|
657 |
+
option = self.manager.act(state)
|
658 |
+
return option
|
659 |
+
|
660 |
+
def act_worker(self, state):
|
661 |
+
action = self.worker_model.act(state, epsilon=self.epsilon)
|
662 |
+
return action
|
663 |
+
|
664 |
+
# ==========================
|
665 |
+
# Replay Methods
|
666 |
+
# ==========================
|
667 |
+
|
668 |
+
def replay_manager(self, batch_size=32, beta=0.4):
|
669 |
+
self.manager.replay(batch_size, beta)
|
670 |
+
|
671 |
+
def replay_worker(self, batch_size=32, beta=0.4):
|
672 |
+
result = self.worker_memory.replay(batch_size, beta)
|
673 |
+
if result is None:
|
674 |
+
return
|
675 |
+
batch, idxs, weights = result
|
676 |
+
if len(self.worker_memory.tree.data) >= batch_size:
|
677 |
+
batch, idxs, weights = self.worker_memory.sample(batch_size, beta)
|
678 |
+
states, actions, rewards, next_states, dones = zip(*batch)
|
679 |
+
|
680 |
+
states = torch.FloatTensor(states).to(self.worker_model.lstm.weight.device)
|
681 |
+
next_states = torch.FloatTensor(next_states).to(self.worker_model.lstm.weight.device)
|
682 |
+
actions = torch.LongTensor(actions).unsqueeze(1).to(self.worker_model.lstm.weight.device)
|
683 |
+
rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.worker_model.lstm.weight.device)
|
684 |
+
dones = torch.FloatTensor(dones).unsqueeze(1).to(self.worker_model.lstm.weight.device)
|
685 |
+
weights = torch.FloatTensor(weights).unsqueeze(1).to(self.worker_model.lstm.weight.device)
|
686 |
+
|
687 |
+
# Current Q values
|
688 |
+
current_q_values, _ = self.worker_model(states)
|
689 |
+
current_q_values = current_q_values.gather(1, actions)
|
690 |
+
|
691 |
+
# Target Q values
|
692 |
+
with torch.no_grad():
|
693 |
+
next_q_values, _ = self.worker_target_model(next_states)
|
694 |
+
max_next_q_values = next_q_values.max(1)[0].unsqueeze(1)
|
695 |
+
target_q_values = rewards + (self.gamma * max_next_q_values * (1 - dones))
|
696 |
+
|
697 |
+
# Compute TD errors
|
698 |
+
td_errors = target_q_values - current_q_values
|
699 |
+
|
700 |
+
# Compute loss with importance-sampling weights
|
701 |
+
loss = (td_errors.pow(2) * weights).mean()
|
702 |
+
|
703 |
+
# Optimize the model
|
704 |
+
self.worker_optimizer.zero_grad()
|
705 |
+
loss.backward()
|
706 |
+
torch.nn.utils.clip_grad_norm_(self.worker_model.parameters(), max_norm=1.0)
|
707 |
+
self.worker_optimizer.step()
|
708 |
+
self.worker_scheduler.step(loss.item())
|
709 |
+
|
710 |
+
# Update priorities
|
711 |
+
td_errors_np = td_errors.detach().cpu().numpy().squeeze()
|
712 |
+
for idx, td_error in zip(idxs, td_errors_np):
|
713 |
+
self.worker_memory.update(idx, np.abs(td_error))
|
714 |
+
|
715 |
+
# Decay epsilon
|
716 |
+
if self.epsilon > self.epsilon_min:
|
717 |
+
self.epsilon *= self.epsilon_decay
|
718 |
+
logger.debug(f"Updated epsilon to: {self.epsilon}")
|
719 |
+
|
720 |
+
# ==========================
|
721 |
+
# Load and Save Models
|
722 |
+
# ==========================
|
723 |
+
|
724 |
+
def load_worker_model(self, name):
|
725 |
+
self.worker_model.load_state_dict(torch.load(name, map_location=self.device))
|
726 |
+
logger.info(f"Loaded worker model weights from {name}")
|
727 |
+
|
728 |
+
def save_worker_model(self, name):
|
729 |
+
torch.save(self.worker_model.state_dict(), name)
|
730 |
+
logger.info(f"Saved worker model weights to {name}")
|
731 |
+
|
732 |
+
def load_manager_model(self, name):
|
733 |
+
self.manager.model.load_state_dict(torch.load(name, map_location=self.device))
|
734 |
+
self.manager.update_target_model()
|
735 |
+
logger.info(f"Loaded manager model weights from {name}")
|
736 |
+
|
737 |
+
def save_manager_model(self, name):
|
738 |
+
torch.save(self.manager.model.state_dict(), name)
|
739 |
+
logger.info(f"Saved manager model weights to {name}")
|
740 |
+
|
741 |
+
# ==========================
|
742 |
+
# Update Target Models
|
743 |
+
# ==========================
|
744 |
+
|
745 |
+
def update_worker_target_model(self):
|
746 |
+
self.worker_target_model.load_state_dict(self.worker_model.state_dict())
|
747 |
+
logger.info("Updated worker target model with current model weights")
|
748 |
+
|
749 |
+
def update_manager_target_model(self):
|
750 |
+
self.manager.update_target_model()
|
751 |
+
logger.info("Updated manager target model with current model weights")
|
752 |
+
|
753 |
+
# ==========================
|
754 |
+
# Feature Extraction
|
755 |
+
# ==========================
|
756 |
+
|
757 |
+
def extract_features(self, content, query):
|
758 |
+
content = truncate_text(content)
|
759 |
+
query = truncate_text(query)
|
760 |
+
soup = BeautifulSoup(content, 'html.parser')
|
761 |
+
text = soup.get_text()
|
762 |
+
word_count = len(text.split())
|
763 |
+
link_count = len(soup.find_all('a'))
|
764 |
+
header_count = len(soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']))
|
765 |
+
|
766 |
+
# Calculate semantic similarity
|
767 |
+
text_embedding = self.embedding_model.encode([text], convert_to_tensor=True).to(self.device)
|
768 |
+
query_embedding = self.embedding_model.encode([query], convert_to_tensor=True).to(self.device)
|
769 |
+
semantic_similarity = cosine_similarity(text_embedding.cpu().numpy(), query_embedding.cpu().numpy())[0][0]
|
770 |
+
|
771 |
+
# Additional Features
|
772 |
+
image_count = len(soup.find_all('img'))
|
773 |
+
script_count = len(soup.find_all('script'))
|
774 |
+
css_count = len(soup.find_all('link', rel='stylesheet'))
|
775 |
+
|
776 |
+
return np.array([word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count])
|
777 |
+
|
778 |
+
# ==========================
|
779 |
+
# Reward Calculation
|
780 |
+
# ==========================
|
781 |
+
|
782 |
+
def calculate_reward(self, content, query):
|
783 |
+
try:
|
784 |
+
ranked_results = train_ranking_model(query, [{'content': content}])
|
785 |
+
logger.debug(f"Ranked results: {ranked_results}")
|
786 |
+
if ranked_results and isinstance(ranked_results[0], dict) and 'predicted_score' in ranked_results[0]:
|
787 |
+
reward = max(ranked_results[0]['predicted_score'], 0)
|
788 |
+
logger.debug(f"Calculated reward: {reward}")
|
789 |
+
return reward
|
790 |
+
else:
|
791 |
+
logger.warning(f"Invalid ranked results: {ranked_results}")
|
792 |
+
return 0
|
793 |
+
except Exception as e:
|
794 |
+
logger.error(f"Error in calculate_reward: {str(e)}", exc_info=True)
|
795 |
+
return 0
|
796 |
+
|
797 |
+
# ==========================
|
798 |
+
# Search Site Selection
|
799 |
+
# ==========================
|
800 |
+
|
801 |
+
def select_search_sites(self, query, num_sites=5):
|
802 |
+
# Select top sites based on past performance for this query
|
803 |
+
site_scores = {}
|
804 |
+
for (site, q), score in self.site_performance.items():
|
805 |
+
if q == query:
|
806 |
+
site_scores[site] = site_scores.get(site, 0) + score
|
807 |
+
if site_scores:
|
808 |
+
sorted_sites = sorted(site_scores.items(), key=lambda x: x[1], reverse=True)
|
809 |
+
top_sites = [site for site, score in sorted_sites[:num_sites]]
|
810 |
+
else:
|
811 |
+
# If no past data, select random sites
|
812 |
+
top_sites = random.sample(self.all_search_sites, num_sites)
|
813 |
+
# Construct full URLs with query
|
814 |
+
search_sites = [site + query for site in top_sites]
|
815 |
+
return search_sites
|
816 |
+
|
817 |
+
# ==========================
|
818 |
+
# Search Method with HRL
|
819 |
+
# ==========================
|
820 |
+
|
821 |
+
@defer.inlineCallbacks
|
822 |
+
def search(self, query, max_steps=2):
|
823 |
+
logger.info(f"Starting search for query: {query}")
|
824 |
+
state = np.zeros(self.state_size, dtype=np.float32)
|
825 |
+
total_reward = 0
|
826 |
+
content = ""
|
827 |
+
done = False
|
828 |
+
results = None
|
829 |
+
|
830 |
+
try:
|
831 |
+
# High-Level: Manager selects an option
|
832 |
+
option = self.act_manager(state)
|
833 |
+
logger.debug(f"Manager selected option: {option}")
|
834 |
+
|
835 |
+
# Execute the selected option
|
836 |
+
if option == 0: # Search Option
|
837 |
+
logger.debug("Executing Search Option")
|
838 |
+
results = yield self.retrieve_from_web(query)
|
839 |
+
if results:
|
840 |
+
content = results[0]['content']
|
841 |
+
site = urlparse(results[0]['link']).netloc
|
842 |
+
self.save_llm_training_data(
|
843 |
+
query,
|
844 |
+
content,
|
845 |
+
summary=results[0].get('summary'),
|
846 |
+
link=results[0].get('link'),
|
847 |
+
title=results[0].get('title')
|
848 |
+
)
|
849 |
+
self.add_document_to_kb(title=results[0].get('title', 'No Title'), content=content, metadata=results[0].get('meta', {}))
|
850 |
+
next_state = self.extract_features(content, query)
|
851 |
+
reward = self.calculate_reward(content, query)
|
852 |
+
logger.debug(f"Extracted features: {next_state}, Reward: {reward}")
|
853 |
+
# Update site performance
|
854 |
+
key = (site, query)
|
855 |
+
self.site_performance[key] = self.site_performance.get(key, 0) + reward
|
856 |
+
|
857 |
+
# Remember Manager's experience
|
858 |
+
self.remember_manager(state, option, reward, next_state, done, td_error=reward)
|
859 |
+
|
860 |
+
# Remember Worker's experience
|
861 |
+
self.remember_worker(state, 0, reward, next_state, done)
|
862 |
+
|
863 |
+
state = next_state.astype(np.float32)
|
864 |
+
total_reward += reward
|
865 |
+
|
866 |
+
else:
|
867 |
+
reward = -1
|
868 |
+
logger.warning(f"No results for query: {query}")
|
869 |
+
# Remember Manager's experience
|
870 |
+
self.remember_manager(state, option, reward, state, True, td_error=reward)
|
871 |
+
|
872 |
+
elif option == 1: # Summarize Option
|
873 |
+
logger.debug("Executing Summarize Option")
|
874 |
+
if content:
|
875 |
+
summary = self.summarizer.generate_summary(content, query)
|
876 |
+
self.save_llm_training_data(
|
877 |
+
query,
|
878 |
+
content,
|
879 |
+
summary=summary,
|
880 |
+
link=results[0].get('link') if results else None,
|
881 |
+
title=results[0].get('title') if results else None
|
882 |
+
)
|
883 |
+
reward = self.calculate_reward(summary, query)
|
884 |
+
next_state = self.extract_features(summary, query)
|
885 |
+
logger.info(f"Summary:\n{summary}")
|
886 |
+
logger.info(f"Summarized content. Reward: {reward}")
|
887 |
+
|
888 |
+
# Remember Manager's experience
|
889 |
+
self.remember_manager(state, option, reward, next_state, done, td_error=reward)
|
890 |
+
|
891 |
+
# Remember Worker's experience
|
892 |
+
self.remember_worker(state, 1, reward, next_state, done)
|
893 |
+
|
894 |
+
state = next_state.astype(np.float32)
|
895 |
+
total_reward += reward
|
896 |
+
else:
|
897 |
+
reward = -1
|
898 |
+
logger.warning("No content to summarize")
|
899 |
+
# Remember Manager's experience
|
900 |
+
self.remember_manager(state, option, reward, state, True, td_error=reward)
|
901 |
+
|
902 |
+
elif option == 2: # RAG-based Generation Option
|
903 |
+
logger.debug("Executing RAG-based Generation Option")
|
904 |
+
kb_docs = self.retrieve_from_kb(query, top_k=5)
|
905 |
+
web_docs = [] # Assuming web_docs are already retrieved
|
906 |
+
combined_docs = self.combine_documents(kb_docs, web_docs)
|
907 |
+
generated_output = self.generate_rag_response(query, combined_docs)
|
908 |
+
logger.info(f"Generated Output:\n{generated_output}")
|
909 |
+
self.save_llm_training_data(
|
910 |
+
query,
|
911 |
+
generated_output,
|
912 |
+
summary=None,
|
913 |
+
link=None,
|
914 |
+
title="RAG-generated response"
|
915 |
+
)
|
916 |
+
reward = self.calculate_reward(generated_output, query)
|
917 |
+
next_state = self.extract_features(generated_output, query)
|
918 |
+
|
919 |
+
# Remember Manager's experience
|
920 |
+
self.remember_manager(state, option, reward, next_state, done, td_error=reward)
|
921 |
+
|
922 |
+
# Remember Worker's experience
|
923 |
+
self.remember_worker(state, 2, reward, next_state, done)
|
924 |
+
|
925 |
+
state = next_state.astype(np.float32)
|
926 |
+
total_reward += reward
|
927 |
+
|
928 |
+
else:
|
929 |
+
logger.warning(f"Unknown option selected by Manager: {option}")
|
930 |
+
|
931 |
+
# Perform replay for both Manager and Worker
|
932 |
+
self.replay_manager(batch_size=32, beta=0.4)
|
933 |
+
self.replay_worker(batch_size=32, beta=0.4)
|
934 |
+
|
935 |
+
# Update target models periodically
|
936 |
+
self.update_worker_target_model()
|
937 |
+
self.update_manager_target_model()
|
938 |
+
|
939 |
+
logger.info(f"Search completed. Total reward: {total_reward}")
|
940 |
+
defer.returnValue(total_reward)
|
941 |
+
except Exception as e:
|
942 |
+
logger.error(f"Error during search: {str(e)}", exc_info=True)
|
943 |
+
defer.returnValue(-1) # Return a negative reward on error
|
944 |
+
|
945 |
+
# ==========================
|
946 |
+
# Summarization Method
|
947 |
+
# ==========================
|
948 |
+
|
949 |
+
def summarize(self, content, query):
|
950 |
+
chunks = self.summarizer.split_into_chunks(content)
|
951 |
+
embeddings = self.summarizer.get_embeddings(chunks)
|
952 |
+
relevant_chunks = self.summarizer.retrieve_relevant_chunks(query, chunks, embeddings)
|
953 |
+
summary = self.summarizer.generate_summary(query, relevant_chunks)
|
954 |
+
|
955 |
+
# Save RAG data
|
956 |
+
self.summarizer.save_rag_data(query, chunks, embeddings)
|
957 |
+
|
958 |
+
return summary
|
959 |
+
|
960 |
+
# ==========================
|
961 |
+
# MCTS Iterations Calculation
|
962 |
+
# ==========================
|
963 |
+
|
964 |
+
def calculate_mcts_iterations(self, state):
|
965 |
+
# Calculate MCTS iterations based on state complexity
|
966 |
+
base_iterations = 2
|
967 |
+
complexity_factor = np.mean(state) / 100 # Normalize state values
|
968 |
+
iterations = int(base_iterations * (1 + complexity_factor))
|
969 |
+
max_iterations = 5 # Set a reasonable maximum
|
970 |
+
return min(iterations, max_iterations)
|
971 |
+
|
972 |
+
# ==========================
|
973 |
+
# RAG-based Response Generation
|
974 |
+
# ==========================
|
975 |
+
|
976 |
+
def generate_rag_response(self, query, combined_docs):
|
977 |
+
if not combined_docs:
|
978 |
+
logger.warning("No documents available for RAG-based generation.")
|
979 |
+
return "I'm sorry, I couldn't find any relevant information."
|
980 |
+
|
981 |
+
# Prepare context for the generator
|
982 |
+
context = "\n\n".join([f"Title: {doc.get('title', 'No Title')}\nContent: {doc.get('content', '')}" for doc in combined_docs])
|
983 |
+
prompt = f"Query: {query}\n\nContext:\n{context}\n\nAnswer:"
|
984 |
+
|
985 |
+
# Check cache first
|
986 |
+
cache_key = hashlib.md5(prompt.encode()).hexdigest()
|
987 |
+
cached_response = self.summarizer.cache.get(cache_key)
|
988 |
+
if cached_response:
|
989 |
+
logger.debug("Using cached RAG response.")
|
990 |
+
return cached_response
|
991 |
+
|
992 |
+
# Generate response
|
993 |
+
input_ids = self.summarizer.tokenizer.encode(prompt, return_tensors='pt').to(self.summarizer.device)
|
994 |
+
try:
|
995 |
+
output = self.summarizer.model.generate(
|
996 |
+
input_ids,
|
997 |
+
max_length=input_ids.shape[1] + self.summarizer.max_length,
|
998 |
+
num_return_sequences=1,
|
999 |
+
no_repeat_ngram_size=2,
|
1000 |
+
top_k=50,
|
1001 |
+
top_p=0.95,
|
1002 |
+
temperature=0.7,
|
1003 |
+
early_stopping=True
|
1004 |
+
)
|
1005 |
+
except Exception as e:
|
1006 |
+
logger.error(f"Error during RAG response generation: {str(e)}")
|
1007 |
+
return "RAG response generation failed."
|
1008 |
+
|
1009 |
+
response = self.summarizer.tokenizer.decode(output[0], skip_special_tokens=True)
|
1010 |
+
answer = response.split("Answer:")[-1].strip()
|
1011 |
+
|
1012 |
+
# Cache the response
|
1013 |
+
self.summarizer.cache.put(cache_key, answer)
|
1014 |
+
self.summarizer.save_persistent_cache()
|
1015 |
+
return answer
|
1016 |
+
|
1017 |
+
# ==========================
|
1018 |
+
# Manager and Worker Interaction
|
1019 |
+
# ==========================
|
1020 |
+
|
1021 |
+
def select_option(self, option):
|
1022 |
+
|
1023 |
+
"""
|
1024 |
+
Define the mapping of options to their corresponding actions.
|
1025 |
+
"""
|
1026 |
+
# This can be expanded based on the number of options
|
1027 |
+
option_actions = {
|
1028 |
+
0: self.perform_search,
|
1029 |
+
1: self.perform_summarization,
|
1030 |
+
2: self.perform_rag_generation
|
1031 |
+
}
|
1032 |
+
action = option_actions.get(option, None)
|
1033 |
+
if action:
|
1034 |
+
return action
|
1035 |
+
else:
|
1036 |
+
logger.error(f"No action defined for option: {option}")
|
1037 |
+
return None
|
1038 |
+
|
1039 |
+
def perform_search(self, query):
|
1040 |
+
"""
|
1041 |
+
Perform the search action.
|
1042 |
+
"""
|
1043 |
+
# Implementation is handled in the 'search' method
|
1044 |
+
pass
|
1045 |
+
|
1046 |
+
def perform_summarization(self, content, query):
|
1047 |
+
"""
|
1048 |
+
Perform the summarization action.
|
1049 |
+
"""
|
1050 |
+
# Implementation is handled in the 'summarize' method
|
1051 |
+
pass
|
1052 |
+
|
1053 |
+
def perform_rag_generation(self, query, combined_docs):
|
1054 |
+
"""
|
1055 |
+
Perform the RAG-based generation action.
|
1056 |
+
"""
|
1057 |
+
# Implementation is handled in the 'generate_rag_response' method
|
1058 |
+
pass
|
1059 |
+
|
1060 |
+
# ==========================
|
1061 |
+
# LRUCache Class
|
1062 |
+
# ==========================
|
1063 |
+
|
1064 |
+
class LRUCache:
|
1065 |
+
def __init__(self, capacity):
|
1066 |
+
self.cache = OrderedDict()
|
1067 |
+
self.capacity = capacity
|
1068 |
+
|
1069 |
+
def get(self, key):
|
1070 |
+
if key not in self.cache:
|
1071 |
+
return None
|
1072 |
+
self.cache.move_to_end(key)
|
1073 |
+
return self.cache[key]
|
1074 |
+
|
1075 |
+
def put(self, key, value):
|
1076 |
+
if key in self.cache:
|
1077 |
+
self.cache.move_to_end(key)
|
1078 |
+
self.cache[key] = value
|
1079 |
+
if len(self.cache) > self.capacity:
|
1080 |
+
self.cache.popitem(last=False)
|
1081 |
+
|
1082 |
+
|
lightbulb.py
ADDED
@@ -0,0 +1,1696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import math
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.optim as optim
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import copy
|
10 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
11 |
+
from torch.cuda.amp import autocast, GradScaler
|
12 |
+
from datasets import load_dataset
|
13 |
+
from transformers import AutoTokenizer
|
14 |
+
|
15 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
|
17 |
+
def parse_args():
|
18 |
+
parser = argparse.ArgumentParser(description='Train or Inference with World Model and Tree of Thought.')
|
19 |
+
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name or path')
|
20 |
+
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name from HuggingFace Datasets')
|
21 |
+
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
|
22 |
+
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
|
23 |
+
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs')
|
24 |
+
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length')
|
25 |
+
parser.add_argument('--mcts_iterations', type=int, default=3, help='Number of MCTS Iterations')
|
26 |
+
parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS')
|
27 |
+
parser.add_argument('--accumulation_steps', type=int, default=4, help='Gradient accumulation steps')
|
28 |
+
parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
|
29 |
+
parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
|
30 |
+
parser.add_argument('--alpha', type=float, default=0.1, help='Entropy regularization weight')
|
31 |
+
parser.add_argument('--beta', type=float, default=0.1, help='Variance regularization weight')
|
32 |
+
parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm for clipping')
|
33 |
+
parser.add_argument('--save_dir', type=str, default='./models', help='Directory to save the models')
|
34 |
+
parser.add_argument('--temperature', type=float, default=1.0, help='Temperature parameter for entropy and variance')
|
35 |
+
parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Mode: train or inference')
|
36 |
+
parser.add_argument('--inference_mode', type=str, choices=['world_model', 'without_world_model', 'world_model_tree_of_thought'], default='world_model_tree_of_thought', help='Inference mode')
|
37 |
+
parser.add_argument('--query', type=str, default='', help='Input query for inference')
|
38 |
+
parser.add_argument('--train_mode', type=str, choices=['world_model', 'language_model'], default='world_model', help='Train world model or language model only')
|
39 |
+
|
40 |
+
# Use parse_known_args to ignore unknown arguments
|
41 |
+
args, unknown = parser.parse_known_args()
|
42 |
+
return args
|
43 |
+
|
44 |
+
def load_data(args, tokenizer):
|
45 |
+
# Load the dataset
|
46 |
+
dataset = load_dataset(args.dataset_name, args.dataset_config)
|
47 |
+
|
48 |
+
# Ensure the tokenizer has a padding token
|
49 |
+
if tokenizer.pad_token is None:
|
50 |
+
tokenizer.pad_token = tokenizer.eos_token
|
51 |
+
|
52 |
+
def tokenize_function(examples):
|
53 |
+
return tokenizer(examples['text'], truncation=True, max_length=args.max_length)
|
54 |
+
|
55 |
+
tokenized_datasets = dataset.map(
|
56 |
+
tokenize_function,
|
57 |
+
batched=True,
|
58 |
+
num_proc=4,
|
59 |
+
remove_columns=dataset['train'].column_names,
|
60 |
+
)
|
61 |
+
|
62 |
+
# Build inputs and labels for language modeling
|
63 |
+
block_size = args.max_length
|
64 |
+
|
65 |
+
def group_texts(examples):
|
66 |
+
# Concatenate all texts
|
67 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
68 |
+
total_length = len(concatenated_examples['input_ids'])
|
69 |
+
# We drop the small remainder
|
70 |
+
total_length = (total_length // block_size) * block_size
|
71 |
+
# Split by chunks of block_size
|
72 |
+
result = {
|
73 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
74 |
+
for k, t in concatenated_examples.items()
|
75 |
+
}
|
76 |
+
result['labels'] = result['input_ids'].copy()
|
77 |
+
return result
|
78 |
+
|
79 |
+
lm_datasets = tokenized_datasets.map(
|
80 |
+
group_texts,
|
81 |
+
batched=True,
|
82 |
+
num_proc=4,
|
83 |
+
)
|
84 |
+
|
85 |
+
# Create DataLoader
|
86 |
+
train_dataset = lm_datasets['train']
|
87 |
+
eval_dataset = lm_datasets['validation'] if 'validation' in lm_datasets else lm_datasets['test']
|
88 |
+
|
89 |
+
def data_collator(data):
|
90 |
+
return {
|
91 |
+
'input_ids': torch.tensor([f['input_ids'] for f in data], dtype=torch.long),
|
92 |
+
'labels': torch.tensor([f['labels'] for f in data], dtype=torch.long)
|
93 |
+
}
|
94 |
+
|
95 |
+
train_loader = DataLoader(
|
96 |
+
train_dataset,
|
97 |
+
shuffle=True,
|
98 |
+
batch_size=args.batch_size,
|
99 |
+
collate_fn=data_collator,
|
100 |
+
pin_memory=True, # Speeds up transfer to GPU
|
101 |
+
num_workers=4
|
102 |
+
)
|
103 |
+
eval_loader = DataLoader(
|
104 |
+
eval_dataset,
|
105 |
+
shuffle=False,
|
106 |
+
batch_size=args.batch_size,
|
107 |
+
collate_fn=data_collator,
|
108 |
+
pin_memory=True,
|
109 |
+
num_workers=4
|
110 |
+
)
|
111 |
+
|
112 |
+
return train_loader, eval_loader
|
113 |
+
|
114 |
+
def save_all_models(transformer_model, representation_network, dynamics_network, prediction_network, action_encoder, save_dir, epoch):
|
115 |
+
"""
|
116 |
+
Save all models to the specified directory.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
transformer_model (nn.Module): Transformer model.
|
120 |
+
representation_network (nn.Module): Representation network.
|
121 |
+
dynamics_network (nn.Module): Dynamics network.
|
122 |
+
prediction_network (nn.Module): Prediction network.
|
123 |
+
action_encoder (nn.Module): Action encoder.
|
124 |
+
save_dir (str): Directory to save the models.
|
125 |
+
epoch (int): Current epoch number.
|
126 |
+
"""
|
127 |
+
os.makedirs(save_dir, exist_ok=True)
|
128 |
+
|
129 |
+
torch.save(transformer_model.state_dict(), os.path.join(save_dir, f'transformer_model_epoch_{epoch}.pt'))
|
130 |
+
torch.save(representation_network.state_dict(), os.path.join(save_dir, f'representation_network_epoch_{epoch}.pt'))
|
131 |
+
torch.save(dynamics_network.state_dict(), os.path.join(save_dir, f'dynamics_network_epoch_{epoch}.pt'))
|
132 |
+
torch.save(prediction_network.state_dict(), os.path.join(save_dir, f'prediction_network_epoch_{epoch}.pt'))
|
133 |
+
torch.save(action_encoder.state_dict(), os.path.join(save_dir, f'action_encoder_epoch_{epoch}.pt'))
|
134 |
+
|
135 |
+
print(f"All models saved for epoch {epoch}.")
|
136 |
+
|
137 |
+
class RotaryPositionalEncoding(nn.Module):
|
138 |
+
def __init__(self, d_model):
|
139 |
+
super(RotaryPositionalEncoding, self).__init__()
|
140 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
|
141 |
+
self.register_buffer('inv_freq', inv_freq)
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
seq_len, batch_size, _ = x.size()
|
145 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
146 |
+
sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
|
147 |
+
sin = sinusoid_inp.sin().unsqueeze(1) # (seq_len, 1, d_model/2)
|
148 |
+
cos = sinusoid_inp.cos().unsqueeze(1) # (seq_len, 1, d_model/2)
|
149 |
+
|
150 |
+
x1 = x[..., 0::2]
|
151 |
+
x2 = x[..., 1::2]
|
152 |
+
|
153 |
+
# Apply rotation
|
154 |
+
x_rotated = torch.zeros_like(x)
|
155 |
+
x_rotated[..., 0::2] = x1 * cos - x2 * sin
|
156 |
+
x_rotated[..., 1::2] = x1 * sin + x2 * cos
|
157 |
+
|
158 |
+
return x_rotated
|
159 |
+
|
160 |
+
class MultiHeadAttention(nn.Module):
|
161 |
+
def __init__(self, d_model, num_heads):
|
162 |
+
super(MultiHeadAttention, self).__init__()
|
163 |
+
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
164 |
+
self.d_k = d_model // num_heads
|
165 |
+
self.num_heads = num_heads
|
166 |
+
self.linear_q = nn.Linear(d_model, d_model)
|
167 |
+
self.linear_k = nn.Linear(d_model, d_model)
|
168 |
+
self.linear_v = nn.Linear(d_model, d_model)
|
169 |
+
self.linear_out = nn.Linear(d_model, d_model)
|
170 |
+
|
171 |
+
def forward(self, query, key, value, mask=None):
|
172 |
+
batch_size = query.size(0)
|
173 |
+
query = self.linear_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
174 |
+
key = self.linear_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
175 |
+
value = self.linear_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
|
176 |
+
|
177 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
|
178 |
+
if mask is not None:
|
179 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
180 |
+
attn = F.softmax(scores, dim=-1)
|
181 |
+
output = torch.matmul(attn, value)
|
182 |
+
|
183 |
+
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
|
184 |
+
return self.linear_out(output)
|
185 |
+
|
186 |
+
class MoE(nn.Module):
|
187 |
+
def __init__(self, d_model, num_experts, d_ff, top_k=2, dropout=0.1):
|
188 |
+
super(MoE, self).__init__()
|
189 |
+
self.num_experts = num_experts
|
190 |
+
self.top_k = top_k
|
191 |
+
self.experts = nn.ModuleList([
|
192 |
+
nn.Sequential(
|
193 |
+
nn.Linear(d_model, d_ff),
|
194 |
+
nn.GELU() if i % 2 == 0 else nn.SiLU(),
|
195 |
+
nn.Linear(d_ff, d_model)
|
196 |
+
)
|
197 |
+
for i in range(num_experts)
|
198 |
+
])
|
199 |
+
self.gate = nn.Linear(d_model, num_experts)
|
200 |
+
self.dropout = nn.Dropout(dropout)
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
batch_size, seq_len, d_model = x.size()
|
204 |
+
# Compute gating scores
|
205 |
+
gate_scores = self.gate(x) # (batch_size, seq_len, num_experts)
|
206 |
+
top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1) # (batch_size, seq_len, top_k)
|
207 |
+
top_k_scores = F.softmax(top_k_scores, dim=-1) # (batch_size, seq_len, top_k)
|
208 |
+
|
209 |
+
# Initialize output
|
210 |
+
output = torch.zeros_like(x)
|
211 |
+
|
212 |
+
# Flatten batch and sequence dimensions
|
213 |
+
x_flat = x.view(-1, d_model) # (batch_size * seq_len, d_model)
|
214 |
+
output_flat = output.view(-1, d_model)
|
215 |
+
top_k_indices_flat = top_k_indices.view(-1, self.top_k) # (batch_size * seq_len, top_k)
|
216 |
+
top_k_scores_flat = top_k_scores.view(-1, self.top_k) # (batch_size * seq_len, top_k)
|
217 |
+
|
218 |
+
for k in range(self.top_k):
|
219 |
+
expert_idx_flat = top_k_indices_flat[:, k] # (batch_size * seq_len)
|
220 |
+
expert_scores_flat = top_k_scores_flat[:, k] # (batch_size * seq_len)
|
221 |
+
for e in range(self.num_experts):
|
222 |
+
mask = (expert_idx_flat == e) # Boolean mask
|
223 |
+
if mask.any():
|
224 |
+
x_masked = x_flat[mask] # Select tokens for expert e
|
225 |
+
expert_output = self.experts[e](x_masked) # Apply expert e
|
226 |
+
output_flat[mask] += expert_scores_flat[mask].unsqueeze(-1) * expert_output
|
227 |
+
|
228 |
+
output = output_flat.view(batch_size, seq_len, d_model)
|
229 |
+
return self.dropout(output)
|
230 |
+
|
231 |
+
class TransformerBlock(nn.Module):
|
232 |
+
def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1, top_k=2):
|
233 |
+
super(TransformerBlock, self).__init__()
|
234 |
+
self.self_attention = MultiHeadAttention(d_model, num_heads)
|
235 |
+
self.norm1 = nn.LayerNorm(d_model)
|
236 |
+
self.cross_attention = MultiHeadAttention(d_model, num_heads)
|
237 |
+
self.norm2 = nn.LayerNorm(d_model)
|
238 |
+
self.moe = MoE(d_model, num_experts, d_ff, top_k, dropout)
|
239 |
+
self.norm3 = nn.LayerNorm(d_model)
|
240 |
+
|
241 |
+
def forward(self, x, mask=None, enc_output=None, enc_mask=None):
|
242 |
+
# Self-attention
|
243 |
+
attn_output = self.self_attention(x, x, x, mask)
|
244 |
+
x = self.norm1(x + attn_output)
|
245 |
+
# Cross-attention (only in decoder)
|
246 |
+
if enc_output is not None:
|
247 |
+
cross_attn_output = self.cross_attention(x, enc_output, enc_output, enc_mask)
|
248 |
+
x = self.norm2(x + cross_attn_output)
|
249 |
+
# Feedforward/MoE
|
250 |
+
moe_output = self.moe(x)
|
251 |
+
return self.norm3(x + moe_output)
|
252 |
+
|
253 |
+
class Transformer(nn.Module):
|
254 |
+
def __init__(self, input_dim, d_model, num_heads, num_layers, d_ff, num_experts, output_dim, dropout=0.1, top_k=2):
|
255 |
+
super(Transformer, self).__init__()
|
256 |
+
self.embedding = nn.Embedding(input_dim, d_model, padding_idx=input_dim - 1)
|
257 |
+
self.rotary_positional_encoding = RotaryPositionalEncoding(d_model)
|
258 |
+
self.encoder_layers = nn.ModuleList(
|
259 |
+
[TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
|
260 |
+
)
|
261 |
+
self.decoder_layers = nn.ModuleList(
|
262 |
+
[TransformerBlock(d_model, num_heads, d_ff, num_experts, dropout, top_k) for _ in range(num_layers)]
|
263 |
+
)
|
264 |
+
self.output_layer = nn.Linear(d_model, output_dim)
|
265 |
+
self.d_model = d_model
|
266 |
+
|
267 |
+
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
|
268 |
+
# Encoder
|
269 |
+
src = self.embedding(src) * math.sqrt(self.d_model)
|
270 |
+
src = src.transpose(0, 1) # (batch_size, seq_len, d_model) -> (seq_len, batch_size, d_model)
|
271 |
+
src = self.rotary_positional_encoding(src)
|
272 |
+
src = src.transpose(0, 1) # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
|
273 |
+
for layer in self.encoder_layers:
|
274 |
+
src = layer(src, src_mask)
|
275 |
+
|
276 |
+
# Decoder
|
277 |
+
tgt = self.embedding(tgt) * math.sqrt(self.d_model)
|
278 |
+
tgt = tgt.transpose(0, 1)
|
279 |
+
tgt = self.rotary_positional_encoding(tgt)
|
280 |
+
tgt = tgt.transpose(0, 1)
|
281 |
+
for layer in self.decoder_layers:
|
282 |
+
tgt = layer(tgt, tgt_mask, src, src_mask)
|
283 |
+
output = self.output_layer(tgt)
|
284 |
+
return output
|
285 |
+
|
286 |
+
def generate(self, src, tokenizer, max_length=20, temperature=1.0):
|
287 |
+
"""
|
288 |
+
Generate sequences using differentiable sampling (Gumbel-Softmax).
|
289 |
+
|
290 |
+
Args:
|
291 |
+
src (torch.Tensor): Source input tensor of shape (batch_size, seq_len)
|
292 |
+
tokenizer (transformers.PreTrainedTokenizer): Tokenizer to access special tokens
|
293 |
+
max_length (int): Maximum length of the generated sequence
|
294 |
+
temperature (float): Temperature parameter for Gumbel-Softmax
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
torch.Tensor: Generated sequences of shape (batch_size, max_length)
|
298 |
+
torch.Tensor: Entropy values for each time step
|
299 |
+
torch.Tensor: Variance values for each time step
|
300 |
+
"""
|
301 |
+
batch_size = src.size(0)
|
302 |
+
|
303 |
+
# Encode the source
|
304 |
+
src_enc = self.embedding(src) * math.sqrt(self.d_model)
|
305 |
+
src_enc = src_enc.transpose(0, 1)
|
306 |
+
src_enc = self.rotary_positional_encoding(src_enc)
|
307 |
+
src_enc = src_enc.transpose(0, 1)
|
308 |
+
for layer in self.encoder_layers:
|
309 |
+
src_enc = layer(src_enc)
|
310 |
+
|
311 |
+
# Initialize decoder input with <sos> tokens
|
312 |
+
tgt_seq = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long, device=src.device)
|
313 |
+
entropies = []
|
314 |
+
variances = []
|
315 |
+
|
316 |
+
for _ in range(max_length):
|
317 |
+
tgt_emb = self.embedding(tgt_seq) * math.sqrt(self.d_model)
|
318 |
+
tgt_emb = tgt_emb.transpose(0, 1)
|
319 |
+
tgt_emb = self.rotary_positional_encoding(tgt_emb)
|
320 |
+
tgt_emb = tgt_emb.transpose(0, 1)
|
321 |
+
tgt_dec = tgt_emb
|
322 |
+
for layer in self.decoder_layers:
|
323 |
+
tgt_dec = layer(tgt_dec, None, src_enc, None)
|
324 |
+
output = self.output_layer(tgt_dec) # (batch_size, seq_len, vocab_size)
|
325 |
+
logits = output[:, -1, :] # Get logits for the last time step
|
326 |
+
|
327 |
+
# Compute token probabilities
|
328 |
+
probs = F.softmax(logits / temperature, dim=-1) # (batch_size, vocab_size)
|
329 |
+
|
330 |
+
# Compute entropy
|
331 |
+
entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1) # (batch_size)
|
332 |
+
entropies.append(entropy)
|
333 |
+
|
334 |
+
# Sample token using Gumbel-Softmax
|
335 |
+
gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + 1e-9) + 1e-9)
|
336 |
+
y = (logits + gumbel_noise) / temperature
|
337 |
+
y = F.softmax(y, dim=-1) # (batch_size, vocab_size)
|
338 |
+
|
339 |
+
# Compute variance
|
340 |
+
variance = torch.var(y, dim=-1) # (batch_size)
|
341 |
+
variances.append(variance)
|
342 |
+
|
343 |
+
# Get token indices (argmax for hard selection)
|
344 |
+
next_tokens = torch.argmax(y, dim=-1, keepdim=True) # (batch_size, 1)
|
345 |
+
tgt_seq = torch.cat([tgt_seq, next_tokens], dim=1)
|
346 |
+
|
347 |
+
# Stack entropies and variances
|
348 |
+
entropies = torch.stack(entropies, dim=1) # (batch_size, max_length)
|
349 |
+
variances = torch.stack(variances, dim=1) # (batch_size, max_length)
|
350 |
+
|
351 |
+
return tgt_seq[:, 1:], entropies, variances # Exclude the initial <sos> token
|
352 |
+
|
353 |
+
# Objective Functions
|
354 |
+
|
355 |
+
class InfoNCE_Loss(nn.Module):
|
356 |
+
def __init__(self, temperature=0.07):
|
357 |
+
super(InfoNCE_Loss, self).__init__()
|
358 |
+
self.temperature = temperature
|
359 |
+
self.cross_entropy = nn.CrossEntropyLoss()
|
360 |
+
|
361 |
+
def forward(self, z_i, z_j):
|
362 |
+
"""
|
363 |
+
Args:
|
364 |
+
z_i (torch.Tensor): Flattened representations from view i, shape (2n, embed_dim)
|
365 |
+
z_j (torch.Tensor): Flattened representations from view j, shape (2n, embed_dim)
|
366 |
+
|
367 |
+
Returns:
|
368 |
+
torch.Tensor: InfoNCE loss
|
369 |
+
"""
|
370 |
+
n = z_i.size(0)
|
371 |
+
z = torch.cat([z_i, z_j], dim=0) # Shape: (2n, embed_dim)
|
372 |
+
|
373 |
+
z = F.normalize(z, dim=1)
|
374 |
+
similarity_matrix = torch.matmul(z, z.T) # Shape: (2n, 2n)
|
375 |
+
|
376 |
+
# Create a mask to exclude self-similarity
|
377 |
+
mask = torch.eye(2 * n, device=z.device, dtype=torch.bool)
|
378 |
+
similarity_matrix = similarity_matrix.masked_fill(mask, -1e4) # Use a manageable negative value
|
379 |
+
|
380 |
+
# Create labels for contrastive learning
|
381 |
+
labels = torch.arange(n, device=z.device)
|
382 |
+
labels = torch.cat([labels + n, labels], dim=0) # Shape: (2n,)
|
383 |
+
|
384 |
+
# Apply temperature scaling
|
385 |
+
similarity_matrix /= self.temperature
|
386 |
+
|
387 |
+
# Compute cross-entropy loss
|
388 |
+
loss = self.cross_entropy(similarity_matrix, labels)
|
389 |
+
return loss
|
390 |
+
|
391 |
+
class CovarianceRegularization(nn.Module):
|
392 |
+
def __init__(self, lambda_reg=1e-3):
|
393 |
+
super(CovarianceRegularization, self).__init__()
|
394 |
+
self.lambda_reg = lambda_reg
|
395 |
+
|
396 |
+
def forward(self, embeddings):
|
397 |
+
"""
|
398 |
+
Args:
|
399 |
+
embeddings (torch.Tensor): Embedding tensor, shape (batch_size, embed_dim)
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
torch.Tensor: Covariance regularization loss
|
403 |
+
"""
|
404 |
+
batch_size, embed_dim = embeddings.size()
|
405 |
+
mean = embeddings.mean(dim=0)
|
406 |
+
embeddings_centered = embeddings - mean
|
407 |
+
cov = (embeddings_centered.T @ embeddings_centered) / (batch_size - 1)
|
408 |
+
cov_loss = torch.sum(cov ** 2) - torch.sum(torch.diag(cov) ** 2)
|
409 |
+
return self.lambda_reg * cov_loss
|
410 |
+
|
411 |
+
class DynamicsPerformanceLoss(nn.Module):
|
412 |
+
def __init__(self, lambda_var=1e-3):
|
413 |
+
super(DynamicsPerformanceLoss, self).__init__()
|
414 |
+
self.lambda_var = lambda_var
|
415 |
+
|
416 |
+
def forward(self, true_next_state, predicted_next_state):
|
417 |
+
"""
|
418 |
+
Args:
|
419 |
+
true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
|
420 |
+
predicted_next_state (torch.Tensor): Predicted next state, shape (batch_size, state_dim)
|
421 |
+
|
422 |
+
Returns:
|
423 |
+
torch.Tensor: Dynamics performance loss
|
424 |
+
"""
|
425 |
+
mse_loss = F.mse_loss(predicted_next_state, true_next_state)
|
426 |
+
variance_loss = torch.var(predicted_next_state, dim=0).mean()
|
427 |
+
return mse_loss + self.lambda_var * variance_loss
|
428 |
+
|
429 |
+
class ThoughtConsistencyLoss(nn.Module):
|
430 |
+
def __init__(self):
|
431 |
+
super(ThoughtConsistencyLoss, self).__init__()
|
432 |
+
|
433 |
+
def forward(self, true_next_state, perturbed_next_state):
|
434 |
+
"""
|
435 |
+
Args:
|
436 |
+
true_next_state (torch.Tensor): Ground truth next state, shape (batch_size, state_dim)
|
437 |
+
perturbed_next_state (torch.Tensor): Perturbed next state, shape (batch_size, state_dim)
|
438 |
+
|
439 |
+
Returns:
|
440 |
+
torch.Tensor: Thought-consistency loss
|
441 |
+
"""
|
442 |
+
return F.mse_loss(true_next_state, perturbed_next_state)
|
443 |
+
|
444 |
+
class PolicyValueJointLoss(nn.Module):
|
445 |
+
def __init__(self, lambda_value=0.5):
|
446 |
+
super(PolicyValueJointLoss, self).__init__()
|
447 |
+
self.lambda_value = lambda_value
|
448 |
+
self.cross_entropy = nn.CrossEntropyLoss()
|
449 |
+
self.mse_loss = nn.MSELoss()
|
450 |
+
|
451 |
+
def forward(self, policy_logits, true_policy, value_pred, true_value):
|
452 |
+
"""
|
453 |
+
Args:
|
454 |
+
policy_logits (torch.Tensor): Logits from the policy network, shape (batch_size * seq_len, num_actions)
|
455 |
+
true_policy (torch.Tensor): Ground truth policy, shape (batch_size * seq_len, num_actions)
|
456 |
+
value_pred (torch.Tensor): Predicted values, shape (batch_size * seq_len)
|
457 |
+
true_value (torch.Tensor): Ground truth values, shape (batch_size * seq_len)
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
torch.Tensor: Combined policy and value loss
|
461 |
+
"""
|
462 |
+
policy_logits = policy_logits.view(-1, policy_logits.size(-1))
|
463 |
+
true_policy = true_policy.view(-1, true_policy.size(-1))
|
464 |
+
value_pred = value_pred.view(-1)
|
465 |
+
true_value = true_value.view(-1)
|
466 |
+
|
467 |
+
policy_loss = self.cross_entropy(policy_logits, true_policy.argmax(dim=1))
|
468 |
+
value_loss = self.mse_loss(value_pred, true_value)
|
469 |
+
return policy_loss + self.lambda_value * value_loss
|
470 |
+
|
471 |
+
class ActionDiversityReward(nn.Module):
|
472 |
+
def __init__(self, lambda_div=1e-3):
|
473 |
+
super(ActionDiversityReward, self).__init__()
|
474 |
+
self.lambda_div = lambda_div
|
475 |
+
|
476 |
+
def forward(self, action_embeddings):
|
477 |
+
"""
|
478 |
+
Args:
|
479 |
+
action_embeddings (torch.Tensor): Embeddings of actions, shape (batch_size, embed_dim)
|
480 |
+
|
481 |
+
Returns:
|
482 |
+
torch.Tensor: Action diversity loss
|
483 |
+
"""
|
484 |
+
similarity_matrix = F.cosine_similarity(action_embeddings.unsqueeze(1), action_embeddings.unsqueeze(0), dim=2)
|
485 |
+
# Zero out self-similarity
|
486 |
+
similarity_matrix = similarity_matrix - torch.eye(similarity_matrix.size(0)).to(action_embeddings.device)
|
487 |
+
diversity_loss = torch.sum(similarity_matrix ** 2)
|
488 |
+
return self.lambda_div * diversity_loss
|
489 |
+
|
490 |
+
class ExpectedThoughtValueLoss(nn.Module):
|
491 |
+
def __init__(self):
|
492 |
+
super(ExpectedThoughtValueLoss, self).__init__()
|
493 |
+
|
494 |
+
def forward(self, mcts_best_values):
|
495 |
+
"""
|
496 |
+
Args:
|
497 |
+
mcts_best_values (torch.Tensor): Best values from MCTS, shape (batch_size)
|
498 |
+
|
499 |
+
Returns:
|
500 |
+
torch.Tensor: ETV loss
|
501 |
+
"""
|
502 |
+
return -mcts_best_values.mean()
|
503 |
+
|
504 |
+
class ExplorationRegularization(nn.Module):
|
505 |
+
def __init__(self, lambda_expl=1e-3):
|
506 |
+
super(ExplorationRegularization, self).__init__()
|
507 |
+
self.lambda_expl = lambda_expl
|
508 |
+
|
509 |
+
def forward(self, visit_counts):
|
510 |
+
"""
|
511 |
+
Args:
|
512 |
+
visit_counts (torch.Tensor): Visit counts for actions, shape (batch_size, num_actions)
|
513 |
+
|
514 |
+
Returns:
|
515 |
+
torch.Tensor: Exploration regularization loss
|
516 |
+
"""
|
517 |
+
reward = torch.sum(1.0 / (visit_counts + 1), dim=-1)
|
518 |
+
return self.lambda_expl * reward.mean()
|
519 |
+
|
520 |
+
class KL_DivergenceLoss(nn.Module):
|
521 |
+
def __init__(self):
|
522 |
+
super(KL_DivergenceLoss, self).__init__()
|
523 |
+
|
524 |
+
def forward(self, old_policy, new_policy):
|
525 |
+
"""
|
526 |
+
Args:
|
527 |
+
old_policy (torch.Tensor): Old policy probabilities, shape (batch_size, num_actions)
|
528 |
+
new_policy (torch.Tensor): New policy probabilities, shape (batch_size, num_actions)
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
torch.Tensor: KL divergence loss
|
532 |
+
"""
|
533 |
+
kl_div = F.kl_div(new_policy.log(), old_policy, reduction='batchmean')
|
534 |
+
return kl_div
|
535 |
+
|
536 |
+
# MuZero Components
|
537 |
+
|
538 |
+
class ActionEncoder(nn.Module):
|
539 |
+
def __init__(self, action_vocab_size, embed_dim):
|
540 |
+
super(ActionEncoder, self).__init__()
|
541 |
+
self.embedding = nn.Embedding(action_vocab_size, embed_dim)
|
542 |
+
|
543 |
+
def forward(self, action_indices):
|
544 |
+
"""
|
545 |
+
Args:
|
546 |
+
action_indices (torch.Tensor): Tensor of shape (batch_size, seq_len)
|
547 |
+
|
548 |
+
Returns:
|
549 |
+
torch.Tensor: Encoded actions of shape (batch_size, seq_len, embed_dim)
|
550 |
+
"""
|
551 |
+
return self.embedding(action_indices)
|
552 |
+
|
553 |
+
class RepresentationNetwork(nn.Module):
|
554 |
+
def __init__(self, vocab_dim, d_model, state_dim):
|
555 |
+
super(RepresentationNetwork, self).__init__()
|
556 |
+
self.proj = nn.Linear(vocab_dim, d_model) # Project from vocab_dim to d_model
|
557 |
+
self.linear = nn.Linear(d_model, state_dim) # Project from d_model to state_dim
|
558 |
+
self.norm = nn.LayerNorm(state_dim)
|
559 |
+
|
560 |
+
def forward(self, transformer_output):
|
561 |
+
"""
|
562 |
+
Args:
|
563 |
+
transformer_output (torch.Tensor): Shape (batch_size, seq_len, vocab_dim)
|
564 |
+
|
565 |
+
Returns:
|
566 |
+
torch.Tensor: Encoded state of shape (batch_size, seq_len, state_dim)
|
567 |
+
"""
|
568 |
+
# First project down from vocab_dim to d_model
|
569 |
+
projected_output = self.proj(transformer_output)
|
570 |
+
# Then project down from d_model to state_dim
|
571 |
+
state = self.linear(projected_output)
|
572 |
+
state = self.norm(state)
|
573 |
+
return state
|
574 |
+
|
575 |
+
class DynamicsNetwork(nn.Module):
|
576 |
+
def __init__(self, state_dim, action_dim, hidden_dim):
|
577 |
+
super(DynamicsNetwork, self).__init__()
|
578 |
+
self.rms_norm = nn.LayerNorm(state_dim)
|
579 |
+
self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
|
580 |
+
self.activation = nn.GELU()
|
581 |
+
self.fc2 = nn.Linear(hidden_dim, state_dim)
|
582 |
+
|
583 |
+
def forward(self, state, action):
|
584 |
+
"""
|
585 |
+
Args:
|
586 |
+
state (torch.Tensor): Current state, shape (batch_size, seq_len, state_dim)
|
587 |
+
action (torch.Tensor): Action embedding, shape (batch_size, seq_len, action_dim)
|
588 |
+
|
589 |
+
Returns:
|
590 |
+
torch.Tensor: Predicted next state, shape (batch_size, seq_len, state_dim)
|
591 |
+
"""
|
592 |
+
norm_state = self.rms_norm(state)
|
593 |
+
combined = torch.cat([norm_state, action], dim=-1)
|
594 |
+
hidden = self.activation(self.fc1(combined))
|
595 |
+
next_state = self.fc2(hidden)
|
596 |
+
return next_state
|
597 |
+
|
598 |
+
class PredictionNetwork(nn.Module):
|
599 |
+
def __init__(self, state_dim, action_vocab_size, value_dim):
|
600 |
+
super(PredictionNetwork, self).__init__()
|
601 |
+
self.state_dim = state_dim
|
602 |
+
self.rms_norm = nn.LayerNorm(state_dim)
|
603 |
+
self.policy_head = nn.Linear(state_dim, action_vocab_size) # Output size is action_vocab_size
|
604 |
+
self.value_head = nn.Linear(state_dim, value_dim)
|
605 |
+
|
606 |
+
def forward(self, state):
|
607 |
+
"""
|
608 |
+
Args:
|
609 |
+
state (torch.Tensor): State representation, shape (batch_size, seq_len, state_dim)
|
610 |
+
Returns:
|
611 |
+
Tuple[torch.Tensor, torch.Tensor]: Policy logits and value estimates
|
612 |
+
"""
|
613 |
+
norm_state = self.rms_norm(state)
|
614 |
+
policy_logits = self.policy_head(norm_state) # Shape: (batch_size, seq_len, action_vocab_size)
|
615 |
+
value_estimates = self.value_head(norm_state).squeeze(-1) # Shape: (batch_size, seq_len)
|
616 |
+
return policy_logits, value_estimates
|
617 |
+
|
618 |
+
|
619 |
+
# Tree of Thought Components
|
620 |
+
|
621 |
+
class ThoughtNode:
|
622 |
+
def __init__(self, name):
|
623 |
+
self.name = name
|
624 |
+
self.children = []
|
625 |
+
self.parent = None
|
626 |
+
|
627 |
+
def add_child(self, child_node):
|
628 |
+
child_node.parent = self
|
629 |
+
self.children.append(child_node)
|
630 |
+
|
631 |
+
# Function to build the Tree of Thought from your detailed structure
|
632 |
+
def build_tree_of_thought():
|
633 |
+
# Create the root node
|
634 |
+
root = ThoughtNode('Problem-Solving Process')
|
635 |
+
|
636 |
+
# Level 1 nodes
|
637 |
+
problem_identification = ThoughtNode('Problem Identification')
|
638 |
+
problem_analysis = ThoughtNode('Problem Analysis')
|
639 |
+
solution_generation = ThoughtNode('Solution Generation')
|
640 |
+
implementation = ThoughtNode('Implementation')
|
641 |
+
evaluation_adjustment = ThoughtNode('Evaluation and Adjustment')
|
642 |
+
|
643 |
+
root.add_child(problem_identification)
|
644 |
+
root.add_child(problem_analysis)
|
645 |
+
root.add_child(solution_generation)
|
646 |
+
root.add_child(implementation)
|
647 |
+
root.add_child(evaluation_adjustment)
|
648 |
+
|
649 |
+
# Problem Identification children
|
650 |
+
B1 = ThoughtNode('Define the Problem')
|
651 |
+
B2 = ThoughtNode('Identify Stakeholders')
|
652 |
+
B3 = ThoughtNode('Determine Constraints')
|
653 |
+
B4 = ThoughtNode('Recognize Problem Type')
|
654 |
+
B5 = ThoughtNode('Historical Context')
|
655 |
+
problem_identification.add_child(B1)
|
656 |
+
problem_identification.add_child(B2)
|
657 |
+
problem_identification.add_child(B3)
|
658 |
+
problem_identification.add_child(B4)
|
659 |
+
problem_identification.add_child(B5)
|
660 |
+
|
661 |
+
# Define the Problem children
|
662 |
+
B1a = ThoughtNode('Problem Statement Formulation')
|
663 |
+
B1b = ThoughtNode('Scope Definition')
|
664 |
+
B1c = ThoughtNode('Objective Setting')
|
665 |
+
B1.add_child(B1a)
|
666 |
+
B1.add_child(B1b)
|
667 |
+
B1.add_child(B1c)
|
668 |
+
|
669 |
+
# Identify Stakeholders children
|
670 |
+
B2a = ThoughtNode('Stakeholder Mapping')
|
671 |
+
B2b = ThoughtNode('Interest and Influence Analysis')
|
672 |
+
B2c = ThoughtNode('Engagement Strategy')
|
673 |
+
B2.add_child(B2a)
|
674 |
+
B2.add_child(B2b)
|
675 |
+
B2.add_child(B2c)
|
676 |
+
|
677 |
+
# Determine Constraints children
|
678 |
+
B3a = ThoughtNode('Resource Limitations')
|
679 |
+
B3b = ThoughtNode('Time Constraints')
|
680 |
+
B3c = ThoughtNode('Legal and Regulatory Constraints')
|
681 |
+
B3.add_child(B3a)
|
682 |
+
B3.add_child(B3b)
|
683 |
+
B3.add_child(B3c)
|
684 |
+
|
685 |
+
# Recognize Problem Type children
|
686 |
+
B4a = ThoughtNode('Simple vs Complex')
|
687 |
+
B4b = ThoughtNode('Known vs Unknown')
|
688 |
+
B4c = ThoughtNode('Tame vs Wicked Problems')
|
689 |
+
B4.add_child(B4a)
|
690 |
+
B4.add_child(B4b)
|
691 |
+
B4.add_child(B4c)
|
692 |
+
|
693 |
+
# Historical Context children
|
694 |
+
B5a = ThoughtNode('Previous Attempts')
|
695 |
+
B5b = ThoughtNode('Lessons Learned')
|
696 |
+
B5c = ThoughtNode('Environmental Factors')
|
697 |
+
B5.add_child(B5a)
|
698 |
+
B5.add_child(B5b)
|
699 |
+
B5.add_child(B5c)
|
700 |
+
|
701 |
+
# Problem Analysis children
|
702 |
+
C1 = ThoughtNode('Root Cause Analysis')
|
703 |
+
C2 = ThoughtNode('System Mapping')
|
704 |
+
C3 = ThoughtNode('Data Collection')
|
705 |
+
C4 = ThoughtNode('Impact Assessment')
|
706 |
+
C5 = ThoughtNode('Theoretical Framework')
|
707 |
+
problem_analysis.add_child(C1)
|
708 |
+
problem_analysis.add_child(C2)
|
709 |
+
problem_analysis.add_child(C3)
|
710 |
+
problem_analysis.add_child(C4)
|
711 |
+
problem_analysis.add_child(C5)
|
712 |
+
|
713 |
+
# Root Cause Analysis children
|
714 |
+
C1a = ThoughtNode('5 Whys Technique')
|
715 |
+
C1b = ThoughtNode('Fishbone Diagram')
|
716 |
+
C1c = ThoughtNode('Pareto Analysis')
|
717 |
+
C1.add_child(C1a)
|
718 |
+
C1.add_child(C1b)
|
719 |
+
C1.add_child(C1c)
|
720 |
+
|
721 |
+
# System Mapping children
|
722 |
+
C2a = ThoughtNode('Causal Loop Diagrams')
|
723 |
+
C2b = ThoughtNode('Stock and Flow Models')
|
724 |
+
C2c = ThoughtNode('Network Analysis')
|
725 |
+
C2.add_child(C2a)
|
726 |
+
C2.add_child(C2b)
|
727 |
+
C2.add_child(C2c)
|
728 |
+
|
729 |
+
# Data Collection children
|
730 |
+
C3a = ThoughtNode('Quantitative Data')
|
731 |
+
C3b = ThoughtNode('Qualitative Data')
|
732 |
+
C3c = ThoughtNode('Data Validation')
|
733 |
+
C3.add_child(C3a)
|
734 |
+
C3.add_child(C3b)
|
735 |
+
C3.add_child(C3c)
|
736 |
+
|
737 |
+
# Quantitative Data children
|
738 |
+
C3a1 = ThoughtNode('Surveys and Questionnaires')
|
739 |
+
C3a2 = ThoughtNode('Experimental Data')
|
740 |
+
C3a3 = ThoughtNode('Big Data Analytics')
|
741 |
+
C3a.add_child(C3a1)
|
742 |
+
C3a.add_child(C3a2)
|
743 |
+
C3a.add_child(C3a3)
|
744 |
+
|
745 |
+
# Qualitative Data children
|
746 |
+
C3b1 = ThoughtNode('Interviews')
|
747 |
+
C3b2 = ThoughtNode('Focus Groups')
|
748 |
+
C3b3 = ThoughtNode('Observational Studies')
|
749 |
+
C3b.add_child(C3b1)
|
750 |
+
C3b.add_child(C3b2)
|
751 |
+
C3b.add_child(C3b3)
|
752 |
+
|
753 |
+
# Data Validation children
|
754 |
+
C3c1 = ThoughtNode('Statistical Validation')
|
755 |
+
C3c2 = ThoughtNode('Cross-Validation')
|
756 |
+
C3c3 = ThoughtNode('Expert Review')
|
757 |
+
C3c.add_child(C3c1)
|
758 |
+
C3c.add_child(C3c2)
|
759 |
+
C3c.add_child(C3c3)
|
760 |
+
|
761 |
+
# Impact Assessment children
|
762 |
+
C4a = ThoughtNode('Environmental Impact')
|
763 |
+
C4b = ThoughtNode('Social Impact')
|
764 |
+
C4c = ThoughtNode('Economic Impact')
|
765 |
+
C4.add_child(C4a)
|
766 |
+
C4.add_child(C4b)
|
767 |
+
C4.add_child(C4c)
|
768 |
+
|
769 |
+
# Theoretical Framework children
|
770 |
+
C5a = ThoughtNode('Literature Review')
|
771 |
+
C5b = ThoughtNode('Conceptual Modeling')
|
772 |
+
C5c = ThoughtNode('Hypothesis Formation')
|
773 |
+
C5.add_child(C5a)
|
774 |
+
C5.add_child(C5b)
|
775 |
+
C5.add_child(C5c)
|
776 |
+
|
777 |
+
# Solution Generation children
|
778 |
+
D1 = ThoughtNode('Creative Problem Solving')
|
779 |
+
D2 = ThoughtNode('Analytical Approach')
|
780 |
+
D3 = ThoughtNode('Mathematical Computation')
|
781 |
+
D4 = ThoughtNode('Decision Making')
|
782 |
+
solution_generation.add_child(D1)
|
783 |
+
solution_generation.add_child(D2)
|
784 |
+
solution_generation.add_child(D3)
|
785 |
+
solution_generation.add_child(D4)
|
786 |
+
|
787 |
+
# Action Planning, Resource Allocation, Change Management children (implementation phase)
|
788 |
+
E1 = ThoughtNode('Action Planning')
|
789 |
+
E2 = ThoughtNode('Resource Allocation')
|
790 |
+
E3 = ThoughtNode('Change Management')
|
791 |
+
implementation.add_child(E1)
|
792 |
+
implementation.add_child(E2)
|
793 |
+
implementation.add_child(E3)
|
794 |
+
|
795 |
+
# Verification, Performance Metrics, Feedback Loops, Continuous Improvement children (evaluation phase)
|
796 |
+
F1 = ThoughtNode('Verification')
|
797 |
+
F2 = ThoughtNode('Performance Metrics')
|
798 |
+
F3 = ThoughtNode('Feedback Loops')
|
799 |
+
F4 = ThoughtNode('Continuous Improvement')
|
800 |
+
evaluation_adjustment.add_child(F1)
|
801 |
+
evaluation_adjustment.add_child(F2)
|
802 |
+
evaluation_adjustment.add_child(F3)
|
803 |
+
evaluation_adjustment.add_child(F4)
|
804 |
+
|
805 |
+
# Cross-Cutting Considerations children
|
806 |
+
G = ThoughtNode('Cross-Cutting Considerations')
|
807 |
+
root.add_child(G)
|
808 |
+
|
809 |
+
# Cross-Cutting Considerations children
|
810 |
+
G1 = ThoughtNode('Ethical Framework')
|
811 |
+
G2 = ThoughtNode('Stakeholder Management')
|
812 |
+
G3 = ThoughtNode('Interdisciplinary Connections')
|
813 |
+
G4 = ThoughtNode('Technological Integration')
|
814 |
+
G5 = ThoughtNode('Emotional Intelligence')
|
815 |
+
G6 = ThoughtNode('Collaborative Problem Solving')
|
816 |
+
G7 = ThoughtNode('Computational Considerations') # Assuming H was intended as G7
|
817 |
+
G8 = ThoughtNode('Order of Operations') # Assuming I was intended as G8
|
818 |
+
G9 = ThoughtNode('Critical Thinking') # Assuming J was intended as G9
|
819 |
+
G10 = ThoughtNode('Future Perspective') # Assuming K was intended as G10
|
820 |
+
G11 = ThoughtNode('Learning and Adaptation') # Assuming L was intended as G11
|
821 |
+
G.add_child(G1)
|
822 |
+
G.add_child(G2)
|
823 |
+
G.add_child(G3)
|
824 |
+
G.add_child(G4)
|
825 |
+
G.add_child(G5)
|
826 |
+
G.add_child(G6)
|
827 |
+
G.add_child(G7)
|
828 |
+
G.add_child(G8)
|
829 |
+
G.add_child(G9)
|
830 |
+
G.add_child(G10)
|
831 |
+
G.add_child(G11)
|
832 |
+
|
833 |
+
# Ethical Framework children
|
834 |
+
G1a = ThoughtNode('Value-based Decision Making')
|
835 |
+
G1b = ThoughtNode('Long-term Consequences')
|
836 |
+
G1.add_child(G1a)
|
837 |
+
G1.add_child(G1b)
|
838 |
+
|
839 |
+
# Value-based Decision Making children
|
840 |
+
G1a1 = ThoughtNode('Ethical Theories Application')
|
841 |
+
G1a2 = ThoughtNode('Moral Dilemma Resolution')
|
842 |
+
G1a.add_child(G1a1)
|
843 |
+
G1a.add_child(G1a2)
|
844 |
+
|
845 |
+
# Long-term Consequences children
|
846 |
+
G1b1 = ThoughtNode('Sustainability Assessment')
|
847 |
+
G1b2 = ThoughtNode('Intergenerational Impact')
|
848 |
+
G1b.add_child(G1b1)
|
849 |
+
G1b.add_child(G1b2)
|
850 |
+
|
851 |
+
# Stakeholder Management children
|
852 |
+
G2a = ThoughtNode('Direct Stakeholders')
|
853 |
+
G2b = ThoughtNode('Indirect Stakeholders')
|
854 |
+
G2c = ThoughtNode('Conflicting Interests')
|
855 |
+
G2.add_child(G2a)
|
856 |
+
G2.add_child(G2b)
|
857 |
+
G2.add_child(G2c)
|
858 |
+
|
859 |
+
# Conflicting Interests children
|
860 |
+
G2c1 = ThoughtNode('Negotiation Strategies')
|
861 |
+
G2c2 = ThoughtNode('Conflict Resolution Techniques')
|
862 |
+
G2c.add_child(G2c1)
|
863 |
+
G2c.add_child(G2c2)
|
864 |
+
|
865 |
+
# Interdisciplinary Connections children
|
866 |
+
G3a = ThoughtNode('Related Fields')
|
867 |
+
G3b = ThoughtNode('Cross-disciplinary Impact')
|
868 |
+
G3.add_child(G3a)
|
869 |
+
G3.add_child(G3b)
|
870 |
+
|
871 |
+
# Related Fields children
|
872 |
+
G3a1 = ThoughtNode('Cross-domain Knowledge Transfer')
|
873 |
+
G3a2 = ThoughtNode('Interdisciplinary Collaboration')
|
874 |
+
G3a.add_child(G3a1)
|
875 |
+
G3a.add_child(G3a2)
|
876 |
+
|
877 |
+
# Cross-disciplinary Impact children
|
878 |
+
G3b1 = ThoughtNode('Synergy Identification')
|
879 |
+
G3b2 = ThoughtNode('Holistic Impact Assessment')
|
880 |
+
G3b.add_child(G3b1)
|
881 |
+
G3b.add_child(G3b2)
|
882 |
+
|
883 |
+
# Technological Integration children
|
884 |
+
G4a = ThoughtNode('AI-assisted Problem Solving')
|
885 |
+
G4b = ThoughtNode('Data-driven Insights')
|
886 |
+
G4c = ThoughtNode('Digital Collaboration Tools')
|
887 |
+
G4.add_child(G4a)
|
888 |
+
G4.add_child(G4b)
|
889 |
+
G4.add_child(G4c)
|
890 |
+
|
891 |
+
# AI-assisted Problem Solving children
|
892 |
+
G4a1 = ThoughtNode('Machine Learning Models')
|
893 |
+
G4a2 = ThoughtNode('Natural Language Processing')
|
894 |
+
G4a.add_child(G4a1)
|
895 |
+
G4a.add_child(G4a2)
|
896 |
+
|
897 |
+
# Data-driven Insights children
|
898 |
+
G4b1 = ThoughtNode('Big Data Analytics')
|
899 |
+
G4b2 = ThoughtNode('Predictive Modeling')
|
900 |
+
G4b.add_child(G4b1)
|
901 |
+
G4b.add_child(G4b2)
|
902 |
+
|
903 |
+
# Digital Collaboration Tools children
|
904 |
+
G4c1 = ThoughtNode('Project Management Platforms')
|
905 |
+
G4c2 = ThoughtNode('Virtual Reality Collaboration')
|
906 |
+
G4c.add_child(G4c1)
|
907 |
+
G4c.add_child(G4c2)
|
908 |
+
|
909 |
+
# Emotional Intelligence children
|
910 |
+
G5a = ThoughtNode('Self-Awareness')
|
911 |
+
G5b = ThoughtNode('Empathy')
|
912 |
+
G5c = ThoughtNode('Stress Management')
|
913 |
+
G5.add_child(G5a)
|
914 |
+
G5.add_child(G5b)
|
915 |
+
G5.add_child(G5c)
|
916 |
+
|
917 |
+
# Self-Awareness children
|
918 |
+
G5a1 = ThoughtNode('Emotional Recognition')
|
919 |
+
G5a2 = ThoughtNode('Personal Bias Identification')
|
920 |
+
G5a.add_child(G5a1)
|
921 |
+
G5a.add_child(G5a2)
|
922 |
+
|
923 |
+
# Empathy children
|
924 |
+
G5b1 = ThoughtNode('Perspective Taking')
|
925 |
+
G5b2 = ThoughtNode('Active Listening')
|
926 |
+
G5b.add_child(G5b1)
|
927 |
+
G5b.add_child(G5b2)
|
928 |
+
|
929 |
+
# Stress Management children
|
930 |
+
G5c1 = ThoughtNode('Mindfulness Techniques')
|
931 |
+
G5c2 = ThoughtNode('Resilience Building')
|
932 |
+
G5c.add_child(G5c1)
|
933 |
+
G5c.add_child(G5c2)
|
934 |
+
|
935 |
+
# Collaborative Problem Solving children
|
936 |
+
G6a = ThoughtNode('Team Dynamics')
|
937 |
+
G6b = ThoughtNode('Communication Strategies')
|
938 |
+
G6c = ThoughtNode('Conflict Resolution')
|
939 |
+
G6.add_child(G6a)
|
940 |
+
G6.add_child(G6b)
|
941 |
+
G6.add_child(G6c)
|
942 |
+
|
943 |
+
# Team Dynamics children
|
944 |
+
G6a1 = ThoughtNode('Team Formation Strategies')
|
945 |
+
G6a2 = ThoughtNode('Role Assignment')
|
946 |
+
G6a.add_child(G6a1)
|
947 |
+
G6a.add_child(G6a2)
|
948 |
+
|
949 |
+
# Communication Strategies children
|
950 |
+
G6b1 = ThoughtNode('Clear Messaging')
|
951 |
+
G6b2 = ThoughtNode('Feedback Mechanisms')
|
952 |
+
G6b.add_child(G6b1)
|
953 |
+
G6b.add_child(G6b2)
|
954 |
+
|
955 |
+
# Conflict Resolution children
|
956 |
+
G6c1 = ThoughtNode('Mediation Techniques')
|
957 |
+
G6c2 = ThoughtNode('Consensus Building')
|
958 |
+
G6c.add_child(G6c1)
|
959 |
+
G6c.add_child(G6c2)
|
960 |
+
|
961 |
+
# Computational Considerations children
|
962 |
+
G7a = ThoughtNode('CPU Operations')
|
963 |
+
G7b = ThoughtNode('GPU Parallelization')
|
964 |
+
G7c = ThoughtNode('Floating-Point Precision')
|
965 |
+
G7.add_child(G7a)
|
966 |
+
G7.add_child(G7b)
|
967 |
+
G7.add_child(G7c)
|
968 |
+
|
969 |
+
# CPU Operations children
|
970 |
+
G7a1 = ThoughtNode('Instruction Set Architecture')
|
971 |
+
G7a2 = ThoughtNode('Pipelining and Parallelism')
|
972 |
+
G7a.add_child(G7a1)
|
973 |
+
G7a.add_child(G7a2)
|
974 |
+
|
975 |
+
# GPU Parallelization children
|
976 |
+
G7b1 = ThoughtNode('CUDA Programming')
|
977 |
+
G7b2 = ThoughtNode('OpenCL Framework')
|
978 |
+
G7b.add_child(G7b1)
|
979 |
+
G7b.add_child(G7b2)
|
980 |
+
|
981 |
+
# Floating-Point Precision children
|
982 |
+
G7c1 = ThoughtNode('IEEE 754 Standard')
|
983 |
+
G7c2 = ThoughtNode('Error Propagation Analysis')
|
984 |
+
G7c.add_child(G7c1)
|
985 |
+
G7c.add_child(G7c2)
|
986 |
+
|
987 |
+
# Order of Operations children
|
988 |
+
G8a = ThoughtNode('Parentheses')
|
989 |
+
G8b = ThoughtNode('Exponents')
|
990 |
+
G8c = ThoughtNode('Multiplication and Division')
|
991 |
+
G8d = ThoughtNode('Addition and Subtraction')
|
992 |
+
G8.add_child(G8a)
|
993 |
+
G8.add_child(G8b)
|
994 |
+
G8.add_child(G8c)
|
995 |
+
G8.add_child(G8d)
|
996 |
+
|
997 |
+
# Critical Thinking children
|
998 |
+
G9a = ThoughtNode('Assumptions Questioning')
|
999 |
+
G9b = ThoughtNode('Bias Recognition')
|
1000 |
+
G9.add_child(G9a)
|
1001 |
+
G9.add_child(G9b)
|
1002 |
+
|
1003 |
+
# Assumptions Questioning children
|
1004 |
+
G9a1 = ThoughtNode('Socratic Questioning')
|
1005 |
+
G9a2 = ThoughtNode('Devil\'s Advocate Approach')
|
1006 |
+
G9a.add_child(G9a1)
|
1007 |
+
G9a.add_child(G9a2)
|
1008 |
+
|
1009 |
+
# Bias Recognition children
|
1010 |
+
G9b1 = ThoughtNode('Cognitive Bias Identification')
|
1011 |
+
G9b2 = ThoughtNode('Debiasing Techniques')
|
1012 |
+
G9b.add_child(G9b1)
|
1013 |
+
G9b.add_child(G9b2)
|
1014 |
+
|
1015 |
+
# Future Perspective children
|
1016 |
+
G10a = ThoughtNode('Short-term Projections')
|
1017 |
+
G10b = ThoughtNode('Long-term Scenarios')
|
1018 |
+
G10c = ThoughtNode('Potential Impacts')
|
1019 |
+
G10.add_child(G10a)
|
1020 |
+
G10.add_child(G10b)
|
1021 |
+
G10.add_child(G10c)
|
1022 |
+
|
1023 |
+
# Short-term Projections children
|
1024 |
+
G10a1 = ThoughtNode('Trend Analysis')
|
1025 |
+
G10a2 = ThoughtNode('Scenario Planning')
|
1026 |
+
G10a.add_child(G10a1)
|
1027 |
+
G10a.add_child(G10a2)
|
1028 |
+
|
1029 |
+
# Long-term Scenarios children
|
1030 |
+
G10b1 = ThoughtNode('Futures Wheel')
|
1031 |
+
G10b2 = ThoughtNode('Backcasting')
|
1032 |
+
G10b.add_child(G10b1)
|
1033 |
+
G10b.add_child(G10b2)
|
1034 |
+
|
1035 |
+
# Potential Impacts children
|
1036 |
+
G10c1 = ThoughtNode('Risk Assessment')
|
1037 |
+
G10c2 = ThoughtNode('Opportunity Identification')
|
1038 |
+
G10c.add_child(G10c1)
|
1039 |
+
G10c.add_child(G10c2)
|
1040 |
+
|
1041 |
+
# Learning and Adaptation children
|
1042 |
+
G11a = ThoughtNode('Reflective Practice')
|
1043 |
+
G11b = ThoughtNode('Knowledge Transfer')
|
1044 |
+
G11c = ThoughtNode('Adaptive Problem Solving')
|
1045 |
+
G11.add_child(G11a)
|
1046 |
+
G11.add_child(G11b)
|
1047 |
+
G11.add_child(G11c)
|
1048 |
+
|
1049 |
+
# Reflective Practice children
|
1050 |
+
G11a1 = ThoughtNode('After Action Review')
|
1051 |
+
G11a2 = ThoughtNode('Learning Journals')
|
1052 |
+
G11a.add_child(G11a1)
|
1053 |
+
G11a.add_child(G11a2)
|
1054 |
+
|
1055 |
+
# Knowledge Transfer children
|
1056 |
+
G11b1 = ThoughtNode('Best Practice Documentation')
|
1057 |
+
G11b2 = ThoughtNode('Mentoring Programs')
|
1058 |
+
G11b.add_child(G11b1)
|
1059 |
+
G11b.add_child(G11b2)
|
1060 |
+
|
1061 |
+
# Adaptive Problem Solving children
|
1062 |
+
G11c1 = ThoughtNode('Iterative Approaches')
|
1063 |
+
G11c2 = ThoughtNode('Flexibility in Methodology')
|
1064 |
+
G11c.add_child(G11c1)
|
1065 |
+
G11c.add_child(G11c2)
|
1066 |
+
|
1067 |
+
return root
|
1068 |
+
|
1069 |
+
def traverse_tree(node, action_list):
|
1070 |
+
if node.name not in action_list:
|
1071 |
+
action_list.append(node.name)
|
1072 |
+
for child in node.children:
|
1073 |
+
traverse_tree(child, action_list)
|
1074 |
+
|
1075 |
+
class MCTSNode:
|
1076 |
+
__slots__ = [
|
1077 |
+
'state',
|
1078 |
+
'parent',
|
1079 |
+
'action',
|
1080 |
+
'children',
|
1081 |
+
'visit_count',
|
1082 |
+
'value_sum',
|
1083 |
+
'prior',
|
1084 |
+
'cached_policy',
|
1085 |
+
'cached_value',
|
1086 |
+
'thought_node' # Added to keep track of the current thought node
|
1087 |
+
]
|
1088 |
+
|
1089 |
+
def __init__(self, state, thought_node, parent=None, action=None):
|
1090 |
+
self.state = state
|
1091 |
+
self.thought_node = thought_node # Reference to the ThoughtNode
|
1092 |
+
self.parent = parent
|
1093 |
+
self.action = action
|
1094 |
+
self.children = {}
|
1095 |
+
self.visit_count = 0
|
1096 |
+
self.value_sum = 0.0
|
1097 |
+
self.prior = 0.0
|
1098 |
+
self.cached_policy = None
|
1099 |
+
self.cached_value = None
|
1100 |
+
|
1101 |
+
def expand(self, priors):
|
1102 |
+
"""
|
1103 |
+
Expand the node by adding all valid child nodes from the thought tree.
|
1104 |
+
|
1105 |
+
Args:
|
1106 |
+
priors (dict): A dictionary mapping action names to prior probabilities.
|
1107 |
+
"""
|
1108 |
+
for child_thought_node in self.thought_node.children:
|
1109 |
+
action = child_thought_node.name # Action name
|
1110 |
+
if action not in self.children:
|
1111 |
+
# Assume batch size of 1 for individual nodes
|
1112 |
+
child_state = self.state.apply_action(action)
|
1113 |
+
child_node = MCTSNode(
|
1114 |
+
state=child_state,
|
1115 |
+
thought_node=child_thought_node,
|
1116 |
+
parent=self,
|
1117 |
+
action=action
|
1118 |
+
)
|
1119 |
+
child_node.prior = priors.get(action, 1.0 / len(self.thought_node.children)) # Default prior if not provided
|
1120 |
+
self.children[action] = child_node
|
1121 |
+
|
1122 |
+
def is_leaf(self):
|
1123 |
+
return len(self.children) == 0
|
1124 |
+
|
1125 |
+
def ucb_score(self, total_visits, exploration_constant=math.sqrt(2)):
|
1126 |
+
if self.visit_count == 0:
|
1127 |
+
return float('inf')
|
1128 |
+
avg_value = self.value_sum / self.visit_count
|
1129 |
+
exploration_term = exploration_constant * self.prior * math.sqrt(total_visits) / (1 + self.visit_count)
|
1130 |
+
return avg_value + exploration_term
|
1131 |
+
|
1132 |
+
class MCTS:
|
1133 |
+
def __init__(self, prediction_network, dynamics_network, action_encoder, num_iterations=10, exploration_constant=math.sqrt(2)):
|
1134 |
+
self.prediction_network = prediction_network
|
1135 |
+
self.dynamics_network = dynamics_network
|
1136 |
+
self.action_encoder = action_encoder
|
1137 |
+
self.num_iterations = num_iterations
|
1138 |
+
self.exploration_constant = exploration_constant
|
1139 |
+
self.cache = {}
|
1140 |
+
|
1141 |
+
def search(self, root_state):
|
1142 |
+
"""
|
1143 |
+
Perform MCTS starting from the root state.
|
1144 |
+
|
1145 |
+
Args:
|
1146 |
+
root_state (State): The root state from which to start the search.
|
1147 |
+
|
1148 |
+
Returns:
|
1149 |
+
str: The best action to take from the root state.
|
1150 |
+
"""
|
1151 |
+
root_node = MCTSNode(state=root_state, thought_node=root_state.thought_node)
|
1152 |
+
|
1153 |
+
for _ in range(self.num_iterations):
|
1154 |
+
node = self.select(root_node)
|
1155 |
+
value = self.evaluate(node)
|
1156 |
+
self.backpropagate(node, value)
|
1157 |
+
|
1158 |
+
best_action = self.best_action(root_node)
|
1159 |
+
return best_action
|
1160 |
+
|
1161 |
+
def select(self, node):
|
1162 |
+
while not node.is_leaf():
|
1163 |
+
total_visits = sum(child.visit_count for child in node.children.values())
|
1164 |
+
_, node = max(
|
1165 |
+
node.children.items(),
|
1166 |
+
key=lambda item: item[1].ucb_score(total_visits, self.exploration_constant)
|
1167 |
+
)
|
1168 |
+
return node
|
1169 |
+
|
1170 |
+
def evaluate(self, node):
|
1171 |
+
# Use the prediction network to get policy and value estimates
|
1172 |
+
state_representation = node.state.representation # Shape: (batch_size=1, seq_len, state_dim)
|
1173 |
+
policy_logits, value_estimate = self.prediction_network(state_representation)
|
1174 |
+
value_estimate = value_estimate.item() # Convert tensor to scalar
|
1175 |
+
|
1176 |
+
# Convert policy logits to probabilities
|
1177 |
+
policy_probs = F.softmax(policy_logits, dim=-1).squeeze(0) # Shape: (seq_len, action_vocab_size)
|
1178 |
+
# For simplicity, use the last time step's policy
|
1179 |
+
policy_probs = policy_probs[-1] # Shape: (action_vocab_size,)
|
1180 |
+
|
1181 |
+
# Map policy probabilities to the actions available from the current thought node
|
1182 |
+
priors = {}
|
1183 |
+
for child in node.thought_node.children:
|
1184 |
+
action_name = child.name
|
1185 |
+
action_idx = action_to_index.get(action_name, None)
|
1186 |
+
if action_idx is not None and action_idx < policy_probs.size(0):
|
1187 |
+
priors[action_name] = policy_probs[action_idx].item()
|
1188 |
+
else:
|
1189 |
+
priors[action_name] = 1.0 / len(node.thought_node.children) # Uniform prior if not found
|
1190 |
+
|
1191 |
+
# Expand the node
|
1192 |
+
node.expand(priors)
|
1193 |
+
|
1194 |
+
return value_estimate
|
1195 |
+
|
1196 |
+
def backpropagate(self, node, value):
|
1197 |
+
while node is not None:
|
1198 |
+
node.visit_count += 1
|
1199 |
+
node.value_sum += value
|
1200 |
+
node = node.parent
|
1201 |
+
|
1202 |
+
def best_action(self, root_node):
|
1203 |
+
# Select the child with the highest visit count
|
1204 |
+
best_child = max(root_node.children.values(), key=lambda n: n.visit_count)
|
1205 |
+
return best_child.action
|
1206 |
+
|
1207 |
+
class State:
|
1208 |
+
def __init__(self, representation, dynamics_network, action_encoder, thought_node):
|
1209 |
+
"""
|
1210 |
+
Args:
|
1211 |
+
representation (torch.Tensor): Encoded state representation, shape (batch_size, seq_len, state_dim)
|
1212 |
+
dynamics_network (nn.Module): The Dynamics Network to predict next states
|
1213 |
+
action_encoder (nn.Module): The Action Encoder to encode actions
|
1214 |
+
thought_node (ThoughtNode): The current node in the Tree of Thought
|
1215 |
+
"""
|
1216 |
+
self.representation = representation # Shape: (batch_size, seq_len, state_dim)
|
1217 |
+
self.dynamics_network = dynamics_network
|
1218 |
+
self.action_encoder = action_encoder
|
1219 |
+
self.thought_node = thought_node # Current position in the Tree of Thought
|
1220 |
+
|
1221 |
+
def apply_action(self, action):
|
1222 |
+
"""
|
1223 |
+
Apply an action to the current state to get a new state.
|
1224 |
+
|
1225 |
+
Args:
|
1226 |
+
action (str): The action to apply (the name of the ThoughtNode)
|
1227 |
+
|
1228 |
+
Returns:
|
1229 |
+
State: The new state after applying the action
|
1230 |
+
"""
|
1231 |
+
# Find the corresponding child node in the thought tree
|
1232 |
+
next_thought_node = None
|
1233 |
+
for child in self.thought_node.children:
|
1234 |
+
if child.name == action:
|
1235 |
+
next_thought_node = child
|
1236 |
+
break
|
1237 |
+
if next_thought_node is None:
|
1238 |
+
raise ValueError(f"Action '{action}' is not valid from the current thought node.")
|
1239 |
+
|
1240 |
+
# Encode action
|
1241 |
+
action_index = torch.tensor([[action_to_index[action]]], device=self.representation.device)
|
1242 |
+
action_embedding = self.action_encoder(action_index)
|
1243 |
+
|
1244 |
+
# Predict the next state using the Dynamics Network
|
1245 |
+
next_state_representation = self.dynamics_network(self.representation, action_embedding)
|
1246 |
+
|
1247 |
+
return State(
|
1248 |
+
representation=next_state_representation,
|
1249 |
+
dynamics_network=self.dynamics_network,
|
1250 |
+
action_encoder=self.action_encoder,
|
1251 |
+
thought_node=next_thought_node
|
1252 |
+
)
|
1253 |
+
|
1254 |
+
class PPOAgent:
|
1255 |
+
def __init__(self, policy_network, optimizer, clip_epsilon=0.2, entropy_coef=0.01, value_coef=0.5):
|
1256 |
+
self.policy_network = policy_network
|
1257 |
+
self.optimizer = optimizer
|
1258 |
+
self.clip_epsilon = clip_epsilon
|
1259 |
+
self.entropy_coef = entropy_coef
|
1260 |
+
self.value_coef = value_coef
|
1261 |
+
|
1262 |
+
def compute_loss(self, states, old_log_probs, actions, returns, advantages):
|
1263 |
+
# Get policy logits and value estimates
|
1264 |
+
policy_logits, value_estimates = self.policy_network(states)
|
1265 |
+
batch_size, seq_len, num_actions = policy_logits.size()
|
1266 |
+
|
1267 |
+
# Flatten tensors using reshape
|
1268 |
+
policy_logits = policy_logits.reshape(-1, num_actions) # Shape: (batch_size * seq_len, num_actions)
|
1269 |
+
value_estimates = value_estimates.view(-1)
|
1270 |
+
actions = actions.reshape(-1) # Shape: (batch_size * seq_len)
|
1271 |
+
old_log_probs = old_log_probs.reshape(-1) # Shape: (batch_size * seq_len)
|
1272 |
+
returns = returns.view(-1)
|
1273 |
+
advantages = advantages.reshape(-1) # Shape: (batch_size * seq_len)
|
1274 |
+
|
1275 |
+
# Ensure value_estimates and returns are the same size
|
1276 |
+
if value_estimates.size() != returns.size():
|
1277 |
+
print(f"Shape mismatch: value_estimates shape: {value_estimates.size()}, returns shape: {returns.size()}")
|
1278 |
+
value_estimates = value_estimates[:returns.size(0)]
|
1279 |
+
|
1280 |
+
# Compute new log probabilities
|
1281 |
+
new_log_probs_all = F.log_softmax(policy_logits, dim=-1) # Shape: (batch_size * seq_len, num_actions)
|
1282 |
+
new_log_probs = new_log_probs_all.gather(1, actions.unsqueeze(-1)).squeeze(-1) # Shape: (batch_size * seq_len)
|
1283 |
+
|
1284 |
+
# Compute ratios
|
1285 |
+
ratios = torch.exp(new_log_probs - old_log_probs)
|
1286 |
+
|
1287 |
+
# PPO surrogate loss
|
1288 |
+
surr1 = ratios * advantages
|
1289 |
+
surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
|
1290 |
+
policy_loss = -torch.min(surr1, surr2).mean()
|
1291 |
+
|
1292 |
+
# Value loss
|
1293 |
+
value_loss = F.mse_loss(value_estimates, returns)
|
1294 |
+
|
1295 |
+
# Entropy loss
|
1296 |
+
entropy = -(new_log_probs * torch.exp(new_log_probs)).mean()
|
1297 |
+
|
1298 |
+
# Total loss
|
1299 |
+
total_loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
|
1300 |
+
return total_loss
|
1301 |
+
|
1302 |
+
|
1303 |
+
|
1304 |
+
def infer(query, world_model_components, root_thought_node, tokenizer, max_length=20, inference_mode='world_model'):
|
1305 |
+
"""
|
1306 |
+
Perform inference given a query, utilizing the Tree of Thought and MCTS.
|
1307 |
+
|
1308 |
+
Args:
|
1309 |
+
query (str): The input query or prompt.
|
1310 |
+
world_model_components (tuple): Tuple containing the model components.
|
1311 |
+
root_thought_node (ThoughtNode): The root node of the Tree of Thought.
|
1312 |
+
tokenizer (transformers.PreTrainedTokenizer): The tokenizer used.
|
1313 |
+
max_length (int): Maximum length for the generated sequence.
|
1314 |
+
inference_mode (str): Inference mode ('world_model', 'without_world_model', 'world_model_tree_of_thought')
|
1315 |
+
|
1316 |
+
Returns:
|
1317 |
+
List[str] or str: The sequence of actions (thoughts) selected or generated text.
|
1318 |
+
"""
|
1319 |
+
representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer = world_model_components
|
1320 |
+
|
1321 |
+
# Tokenize and encode the query
|
1322 |
+
input_ids = tokenizer.encode(query, return_tensors='pt').to(device)
|
1323 |
+
attention_mask = (input_ids != tokenizer.pad_token_id).long()
|
1324 |
+
|
1325 |
+
if inference_mode == 'without_world_model':
|
1326 |
+
# Directly use the transformer model to generate text
|
1327 |
+
with torch.no_grad():
|
1328 |
+
generated_ids, entropies, variances = model_transformer.generate(src=input_ids, tokenizer=tokenizer, max_length=max_length, temperature=args.temperature)
|
1329 |
+
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
1330 |
+
return generated_text
|
1331 |
+
|
1332 |
+
else:
|
1333 |
+
# Use the world model components
|
1334 |
+
with torch.no_grad():
|
1335 |
+
transformer_output = model_transformer(input_ids, input_ids)
|
1336 |
+
# Get the initial state representation
|
1337 |
+
initial_representation = representation_network(transformer_output) # Shape: (batch_size=1, seq_len, state_dim)
|
1338 |
+
initial_state = State(
|
1339 |
+
representation=initial_representation,
|
1340 |
+
dynamics_network=dynamics_network,
|
1341 |
+
action_encoder=action_encoder,
|
1342 |
+
thought_node=root_thought_node
|
1343 |
+
)
|
1344 |
+
|
1345 |
+
if inference_mode == 'world_model_tree_of_thought':
|
1346 |
+
# Use MCTS with Tree of Thought
|
1347 |
+
mcts = MCTS(prediction_network, dynamics_network, action_encoder, num_iterations=args.mcts_iterations, exploration_constant=args.mcts_exploration_constant)
|
1348 |
+
current_state = initial_state
|
1349 |
+
thought_sequence = []
|
1350 |
+
|
1351 |
+
for _ in range(max_length):
|
1352 |
+
best_action = mcts.search(current_state)
|
1353 |
+
thought_sequence.append(best_action)
|
1354 |
+
|
1355 |
+
# Apply the best action to get the next state
|
1356 |
+
current_state = current_state.apply_action(best_action)
|
1357 |
+
|
1358 |
+
# Check if we've reached a leaf node (no further actions)
|
1359 |
+
if len(current_state.thought_node.children) == 0:
|
1360 |
+
break
|
1361 |
+
|
1362 |
+
return thought_sequence
|
1363 |
+
else:
|
1364 |
+
# Use the world model without Tree of Thought
|
1365 |
+
# For simplicity, we will generate actions based on the prediction network
|
1366 |
+
policy_logits, _ = prediction_network(initial_state.representation)
|
1367 |
+
policy_probs = F.softmax(policy_logits, dim=-1)
|
1368 |
+
# Select actions with highest probabilities
|
1369 |
+
top_actions = torch.argmax(policy_probs, dim=-1)
|
1370 |
+
generated_actions = [index_to_action[idx.item()] for idx in top_actions[0]]
|
1371 |
+
return generated_actions
|
1372 |
+
|
1373 |
+
def train_epoch_world_model(world_model_components, train_loader, optimizer, scheduler, scaler, args, model_transformer, state_dim, embed_dim, input_dim):
|
1374 |
+
representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, _ = world_model_components
|
1375 |
+
representation_network.train()
|
1376 |
+
dynamics_network.train()
|
1377 |
+
prediction_network.train()
|
1378 |
+
action_encoder.train()
|
1379 |
+
ppo_agent.policy_network.train()
|
1380 |
+
|
1381 |
+
total_loss = 0.0
|
1382 |
+
optimizer.zero_grad()
|
1383 |
+
print(f"Starting World Model training epoch with {len(train_loader)}batches...")
|
1384 |
+
|
1385 |
+
for i, batch in enumerate(train_loader):
|
1386 |
+
print(f"Processing batch {i+1}/{len(train_loader)}...")
|
1387 |
+
|
1388 |
+
# Move batches to the device
|
1389 |
+
src_batch = batch['input_ids'].to(device)
|
1390 |
+
tgt_batch = batch['labels'].to(device)
|
1391 |
+
|
1392 |
+
with torch.amp.autocast(device_type='cuda'):
|
1393 |
+
print("Forward pass through Transformer (frozen)...")
|
1394 |
+
with torch.no_grad():
|
1395 |
+
transformer_output = model_transformer(src_batch, tgt_batch[:, :-1])
|
1396 |
+
|
1397 |
+
# World Model - Representation
|
1398 |
+
state_representation = representation_network(transformer_output) # On GPU
|
1399 |
+
|
1400 |
+
# For simplicity, let's assume true actions are provided (e.g., next tokens)
|
1401 |
+
true_actions = tgt_batch[:, :-1] # Shape: (batch_size, seq_len)
|
1402 |
+
action_sequences = true_actions
|
1403 |
+
|
1404 |
+
# Get action embeddings
|
1405 |
+
action_embeddings = action_encoder(action_sequences) # Shape: (batch_size, seq_len, embed_dim)
|
1406 |
+
|
1407 |
+
# Apply dynamics network
|
1408 |
+
predicted_next_state_batch = dynamics_network(state_representation, action_embeddings) # Shape: (batch_size, seq_len, state_dim)
|
1409 |
+
|
1410 |
+
# Prediction Network - Policy logits and value
|
1411 |
+
policy_logits, value_estimates = prediction_network(predicted_next_state_batch)
|
1412 |
+
# value_estimates now has shape (batch_size, seq_len)
|
1413 |
+
|
1414 |
+
# Define true_policy and true_value as placeholders on the GPU
|
1415 |
+
true_policy = F.one_hot(true_actions, num_classes=input_dim).float() # Shape: (batch_size, seq_len, input_dim)
|
1416 |
+
true_value = torch.zeros_like(value_estimates).to(device)
|
1417 |
+
|
1418 |
+
# Compute PPO loss
|
1419 |
+
actions_selected = true_actions # Shape: (batch_size, seq_len)
|
1420 |
+
old_log_probs = torch.zeros_like(actions_selected, dtype=torch.float32).to(device)
|
1421 |
+
returns = torch.zeros_like(actions_selected, dtype=torch.float32).to(device)
|
1422 |
+
advantages = torch.zeros_like(actions_selected, dtype=torch.float32).to(device)
|
1423 |
+
|
1424 |
+
# Compute PPO loss using states
|
1425 |
+
ppo_loss = ppo_agent.compute_loss(state_representation, old_log_probs, actions_selected, returns, advantages)
|
1426 |
+
|
1427 |
+
# Compute InfoNCE Loss
|
1428 |
+
z_i = state_representation.view(-1, state_dim) # Shape: (batch_size * seq_len, state_dim)
|
1429 |
+
z_j = F.dropout(z_i, p=0.1, training=True)
|
1430 |
+
info_nce = InfoNCE_Loss()(z_i, z_j)
|
1431 |
+
|
1432 |
+
# Compute other losses
|
1433 |
+
covariance = CovarianceRegularization()(predicted_next_state_batch.view(-1, predicted_next_state_batch.size(-1)))
|
1434 |
+
dynamics_loss = DynamicsPerformanceLoss()(state_representation, predicted_next_state_batch)
|
1435 |
+
perturbed_next_state = predicted_next_state_batch + torch.randn_like(predicted_next_state_batch) * 0.01
|
1436 |
+
thought_loss = ThoughtConsistencyLoss()(predicted_next_state_batch, perturbed_next_state)
|
1437 |
+
pv_loss = PolicyValueJointLoss()(policy_logits, true_policy, value_estimates.squeeze(-1), true_value.squeeze(-1))
|
1438 |
+
action_diversity = ActionDiversityReward()(action_embeddings.view(-1, embed_dim))
|
1439 |
+
mcts_best_values = torch.zeros(actions_selected.size(0)).to(device)
|
1440 |
+
etv = ExpectedThoughtValueLoss()(mcts_best_values)
|
1441 |
+
visit_counts = torch.ones(actions_selected.size(0), policy_logits.size(-1)).to(device)
|
1442 |
+
exploration = ExplorationRegularization()(visit_counts)
|
1443 |
+
old_policy = F.softmax(policy_logits.detach(), dim=-1)
|
1444 |
+
new_policy = F.softmax(policy_logits, dim=-1)
|
1445 |
+
kl_loss = KL_DivergenceLoss()(old_policy, new_policy)
|
1446 |
+
|
1447 |
+
# Total Loss
|
1448 |
+
loss = (
|
1449 |
+
ppo_loss +
|
1450 |
+
info_nce +
|
1451 |
+
covariance +
|
1452 |
+
dynamics_loss +
|
1453 |
+
thought_loss +
|
1454 |
+
pv_loss +
|
1455 |
+
action_diversity +
|
1456 |
+
etv +
|
1457 |
+
exploration +
|
1458 |
+
kl_loss
|
1459 |
+
)
|
1460 |
+
loss = loss / args.accumulation_steps
|
1461 |
+
|
1462 |
+
print("Backward pass...")
|
1463 |
+
scaler.scale(loss).backward()
|
1464 |
+
|
1465 |
+
if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
|
1466 |
+
print("Gradient clipping...")
|
1467 |
+
scaler.unscale_(optimizer)
|
1468 |
+
torch.nn.utils.clip_grad_norm_(
|
1469 |
+
[param for group in optimizer.param_groups for param in group['params']],
|
1470 |
+
args.max_grad_norm
|
1471 |
+
)
|
1472 |
+
|
1473 |
+
print("Optimizer step...")
|
1474 |
+
scaler.step(optimizer)
|
1475 |
+
scaler.update()
|
1476 |
+
|
1477 |
+
print("Zeroing gradients...")
|
1478 |
+
optimizer.zero_grad()
|
1479 |
+
|
1480 |
+
print("Updating learning rate...")
|
1481 |
+
scheduler.step()
|
1482 |
+
|
1483 |
+
total_loss += loss.item() * args.accumulation_steps
|
1484 |
+
print(f"Batch {i+1} completed. Current loss: {loss.item():.4f}")
|
1485 |
+
|
1486 |
+
avg_loss = total_loss / len(train_loader)
|
1487 |
+
print(f"World Model training epoch completed. Average loss: {avg_loss:.4f}")
|
1488 |
+
return avg_loss
|
1489 |
+
|
1490 |
+
|
1491 |
+
def train_epoch_language_model(model, train_loader, optimizer, scheduler, scaler, args):
|
1492 |
+
model.train()
|
1493 |
+
total_loss = 0.0
|
1494 |
+
optimizer.zero_grad()
|
1495 |
+
print(f"Starting Language Model training epoch with {len(train_loader)} batches...")
|
1496 |
+
|
1497 |
+
for i, batch in enumerate(train_loader):
|
1498 |
+
input_ids = batch['input_ids'].to(device)
|
1499 |
+
labels = batch['labels'].to(device)
|
1500 |
+
|
1501 |
+
with autocast():
|
1502 |
+
outputs = model(input_ids, input_ids)
|
1503 |
+
logits = outputs.view(-1, outputs.size(-1))
|
1504 |
+
labels = labels.view(-1)
|
1505 |
+
loss = F.cross_entropy(logits, labels, ignore_index=model.embedding.padding_idx)
|
1506 |
+
loss = loss / args.accumulation_steps
|
1507 |
+
|
1508 |
+
scaler.scale(loss).backward()
|
1509 |
+
|
1510 |
+
if (i + 1) % args.accumulation_steps == 0 or (i + 1) == len(train_loader):
|
1511 |
+
scaler.unscale_(optimizer)
|
1512 |
+
torch.nn.utils.clip_grad_norm_(
|
1513 |
+
[param for group in optimizer.param_groups for param in group['params']],
|
1514 |
+
args.max_grad_norm
|
1515 |
+
)
|
1516 |
+
scaler.step(optimizer)
|
1517 |
+
scaler.update()
|
1518 |
+
optimizer.zero_grad()
|
1519 |
+
scheduler.step()
|
1520 |
+
|
1521 |
+
total_loss += loss.item() * args.accumulation_steps
|
1522 |
+
print(f"Batch {i + 1} completed. Current loss: {loss.item():.4f}")
|
1523 |
+
|
1524 |
+
avg_loss = total_loss / len(train_loader)
|
1525 |
+
print(f"Language Model training epoch completed. Average loss: {avg_loss:.4f}")
|
1526 |
+
return avg_loss
|
1527 |
+
|
1528 |
+
|
1529 |
+
|
1530 |
+
def main():
|
1531 |
+
args = parse_args()
|
1532 |
+
print("Arguments parsed successfully.")
|
1533 |
+
|
1534 |
+
# Create save directory
|
1535 |
+
os.makedirs(args.save_dir, exist_ok=True)
|
1536 |
+
print(f"Save directory created: {args.save_dir}")
|
1537 |
+
|
1538 |
+
# Load tokenizer
|
1539 |
+
print("Loading tokenizer...")
|
1540 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
1541 |
+
if tokenizer.pad_token is None:
|
1542 |
+
tokenizer.pad_token = tokenizer.eos_token
|
1543 |
+
print("Tokenizer loaded successfully.")
|
1544 |
+
|
1545 |
+
# Define padding_idx and input dimension based on tokenizer
|
1546 |
+
padding_idx = tokenizer.pad_token_id
|
1547 |
+
input_dim = len(tokenizer)
|
1548 |
+
|
1549 |
+
# Initialize the Transformer model on GPU
|
1550 |
+
print("Initializing Transformer model...")
|
1551 |
+
model_transformer = Transformer(
|
1552 |
+
input_dim=input_dim,
|
1553 |
+
d_model=128,
|
1554 |
+
num_heads=4,
|
1555 |
+
num_layers=4,
|
1556 |
+
d_ff=256,
|
1557 |
+
num_experts=2,
|
1558 |
+
output_dim=input_dim,
|
1559 |
+
dropout=0.1,
|
1560 |
+
top_k=2
|
1561 |
+
).to(device)
|
1562 |
+
model_transformer.train()
|
1563 |
+
print("Transformer model initialized on device.")
|
1564 |
+
|
1565 |
+
# Define model parameters (adjusted for speed)
|
1566 |
+
d_model = 128
|
1567 |
+
state_dim = 128
|
1568 |
+
action_dim = d_model
|
1569 |
+
hidden_dim = 256
|
1570 |
+
vocab_dim = input_dim
|
1571 |
+
embed_dim = d_model
|
1572 |
+
|
1573 |
+
# Define World Model components
|
1574 |
+
representation_network = RepresentationNetwork(vocab_dim, d_model, state_dim).to(device)
|
1575 |
+
dynamics_network = DynamicsNetwork(state_dim, action_dim, hidden_dim).to(device)
|
1576 |
+
prediction_network = PredictionNetwork(state_dim, input_dim, 1).to(device)
|
1577 |
+
action_encoder = ActionEncoder(input_dim, action_dim).to(device)
|
1578 |
+
|
1579 |
+
# Initialize PPO Agent
|
1580 |
+
ppo_agent = PPOAgent(
|
1581 |
+
policy_network=prediction_network,
|
1582 |
+
optimizer=optim.AdamW(prediction_network.parameters(), lr=args.learning_rate),
|
1583 |
+
clip_epsilon=0.2,
|
1584 |
+
entropy_coef=0.01,
|
1585 |
+
value_coef=0.5
|
1586 |
+
)
|
1587 |
+
|
1588 |
+
# Bundle World Model components
|
1589 |
+
world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)
|
1590 |
+
|
1591 |
+
if args.mode == 'train':
|
1592 |
+
print("Loading and preprocessing data...")
|
1593 |
+
train_loader, eval_loader = load_data(args, tokenizer)
|
1594 |
+
print("Data loaded and preprocessed successfully.")
|
1595 |
+
|
1596 |
+
# Optimizer and Scheduler
|
1597 |
+
optimizer = optim.AdamW(
|
1598 |
+
list(representation_network.parameters()) +
|
1599 |
+
list(dynamics_network.parameters()) +
|
1600 |
+
list(prediction_network.parameters()) +
|
1601 |
+
list(action_encoder.parameters()),
|
1602 |
+
lr=args.learning_rate, weight_decay=args.weight_decay
|
1603 |
+
) if args.train_mode == 'world_model' else optim.AdamW(model_transformer.parameters(), lr=args.learning_rate)
|
1604 |
+
scheduler = CosineAnnealingLR(optimizer, T_max=args.num_epochs)
|
1605 |
+
scaler = GradScaler()
|
1606 |
+
|
1607 |
+
print(f"Starting {args.train_mode} training...")
|
1608 |
+
|
1609 |
+
for epoch in range(args.num_epochs):
|
1610 |
+
if args.train_mode == 'world_model':
|
1611 |
+
avg_loss = train_epoch_world_model(
|
1612 |
+
world_model_components,
|
1613 |
+
train_loader,
|
1614 |
+
optimizer,
|
1615 |
+
scheduler,
|
1616 |
+
scaler,
|
1617 |
+
args,
|
1618 |
+
model_transformer,
|
1619 |
+
state_dim,
|
1620 |
+
embed_dim,
|
1621 |
+
input_dim
|
1622 |
+
)
|
1623 |
+
else:
|
1624 |
+
avg_loss = train_epoch_language_model(
|
1625 |
+
model_transformer,
|
1626 |
+
train_loader,
|
1627 |
+
optimizer,
|
1628 |
+
scheduler,
|
1629 |
+
scaler,
|
1630 |
+
args
|
1631 |
+
)
|
1632 |
+
|
1633 |
+
print(f"{args.train_mode.capitalize()} training epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")
|
1634 |
+
|
1635 |
+
if args.train_mode == 'world_model':
|
1636 |
+
save_all_models(model_transformer, representation_network, dynamics_network, prediction_network, action_encoder, args.save_dir, epoch + 1)
|
1637 |
+
print(f"Models saved for epoch {epoch + 1}")
|
1638 |
+
else:
|
1639 |
+
torch.save(model_transformer.state_dict(), os.path.join(args.save_dir, f'language_model_epoch_{epoch + 1}.pt'))
|
1640 |
+
print(f"Language model saved for epoch {epoch + 1}")
|
1641 |
+
|
1642 |
+
print("Training completed.")
|
1643 |
+
|
1644 |
+
elif args.mode == 'inference':
|
1645 |
+
# Build Tree of Thought if needed
|
1646 |
+
tree_root = build_tree_of_thought()
|
1647 |
+
# Generate action list
|
1648 |
+
action_list = []
|
1649 |
+
traverse_tree(tree_root, action_list)
|
1650 |
+
|
1651 |
+
# Create mappings
|
1652 |
+
global action_to_index, index_to_action
|
1653 |
+
action_to_index = {action: idx for idx, action in enumerate(action_list)}
|
1654 |
+
index_to_action = {idx: action for action, idx in action_to_index.items()}
|
1655 |
+
action_vocab_size = len(action_list)
|
1656 |
+
|
1657 |
+
# Update action encoder and prediction network with new vocab size
|
1658 |
+
action_encoder = ActionEncoder(action_vocab_size, action_dim).to(device)
|
1659 |
+
prediction_network = PredictionNetwork(state_dim, action_vocab_size, 1).to(device)
|
1660 |
+
|
1661 |
+
# Load the saved models
|
1662 |
+
# Assuming the models are saved after training
|
1663 |
+
# You need to adjust the paths and epoch numbers as necessary
|
1664 |
+
model_transformer.load_state_dict(torch.load(os.path.join(args.save_dir, 'transformer_model_epoch_2.pt')))
|
1665 |
+
representation_network.load_state_dict(torch.load(os.path.join(args.save_dir, 'representation_network_epoch_2.pt')))
|
1666 |
+
dynamics_network.load_state_dict(torch.load(os.path.join(args.save_dir, 'dynamics_network_epoch_2.pt')))
|
1667 |
+
saved_state_dict = torch.load(os.path.join(args.save_dir, 'prediction_network_epoch_2.pt'))
|
1668 |
+
prediction_network.policy_head = nn.Linear(prediction_network.state_dim, 50257) # Update to match saved model size
|
1669 |
+
prediction_network.load_state_dict(saved_state_dict, strict=False)
|
1670 |
+
|
1671 |
+
# Resize `policy_head` back to 158 after loading
|
1672 |
+
prediction_network.policy_head = nn.Linear(prediction_network.state_dim, 158).to(device)
|
1673 |
+
|
1674 |
+
|
1675 |
+
|
1676 |
+
action_encoder.load_state_dict(torch.load(os.path.join(args.save_dir, 'action_encoder_epoch_2.pt')))
|
1677 |
+
|
1678 |
+
# Prepare the components
|
1679 |
+
world_model_components = (representation_network, dynamics_network, prediction_network, action_encoder, ppo_agent, model_transformer)
|
1680 |
+
|
1681 |
+
# Perform inference
|
1682 |
+
if not args.query:
|
1683 |
+
args.query = input("Please enter your query: ")
|
1684 |
+
|
1685 |
+
result = infer(args.query, world_model_components, tree_root, tokenizer, inference_mode=args.inference_mode)
|
1686 |
+
|
1687 |
+
if args.inference_mode == 'without_world_model':
|
1688 |
+
print("Generated Text:")
|
1689 |
+
print(result)
|
1690 |
+
else:
|
1691 |
+
print("Generated Thought Sequence:")
|
1692 |
+
for thought in result:
|
1693 |
+
print(thought)
|
1694 |
+
|
1695 |
+
if __name__ == '__main__':
|
1696 |
+
main()
|
main_menu.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# main_menu.py
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import sys
|
5 |
+
from train_agent import train_agent
|
6 |
+
from test_agent import TestAgent, run_test_session
|
7 |
+
from lightbulb import main as world_model_main
|
8 |
+
|
9 |
+
def parse_main_args():
|
10 |
+
parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
|
11 |
+
parser.add_argument('--task', type=str, choices=['train_llm_world', 'train_agent', 'test_agent'],
|
12 |
+
required=True, help='Choose task to execute: train_llm_world, train_agent, test_agent')
|
13 |
+
# Optional arguments for more granular control
|
14 |
+
parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
|
15 |
+
parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
|
16 |
+
parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
|
17 |
+
parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
|
18 |
+
parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
|
19 |
+
parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
|
20 |
+
parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
|
21 |
+
parser.add_argument('--query', type=str, default='', help='Query for the test_agent')
|
22 |
+
return parser.parse_args()
|
23 |
+
|
24 |
+
def main():
|
25 |
+
# Parse arguments for the main function
|
26 |
+
args = parse_main_args()
|
27 |
+
|
28 |
+
# Execute tasks based on user input
|
29 |
+
if args.task == 'train_llm_world':
|
30 |
+
print("Starting LLM and World Model Training...")
|
31 |
+
# Directly call the world model main function
|
32 |
+
sys.argv = ['lightbulb.py', '--mode', args.mode, '--model_name', args.model_name,
|
33 |
+
'--dataset_name', args.dataset_name, '--dataset_config', args.dataset_config,
|
34 |
+
'--batch_size', str(args.batch_size), '--num_epochs', str(args.num_epochs),
|
35 |
+
'--max_length', str(args.max_length)]
|
36 |
+
world_model_main()
|
37 |
+
|
38 |
+
elif args.task == 'train_agent':
|
39 |
+
print("Starting Agent Training...")
|
40 |
+
# Call the train_agent function from train_agent.py
|
41 |
+
from twisted.internet import reactor, task
|
42 |
+
d = task.deferLater(reactor, 0, train_agent)
|
43 |
+
d.addErrback(lambda failure: print(f"An error occurred: {failure}", exc_info=True))
|
44 |
+
d.addBoth(lambda _: reactor.stop())
|
45 |
+
reactor.run()
|
46 |
+
|
47 |
+
elif args.task == 'test_agent':
|
48 |
+
print("Starting Test Agent...")
|
49 |
+
test_agent = TestAgent()
|
50 |
+
if args.query:
|
51 |
+
# Directly process a single query
|
52 |
+
result = test_agent.process_query(args.query)
|
53 |
+
print("\nAgent's response:")
|
54 |
+
print(result)
|
55 |
+
else:
|
56 |
+
# Run the interactive session
|
57 |
+
reactor.callWhenRunning(run_test_session)
|
58 |
+
reactor.run()
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
main()
|
mcts.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mcts.py
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
from nltk.corpus import wordnet
|
5 |
+
from scrapy.crawler import CrawlerRunner
|
6 |
+
from scrapy.utils.log import configure_logging
|
7 |
+
from scrapy.utils.project import get_project_settings
|
8 |
+
from twisted.internet import reactor, defer
|
9 |
+
from scrapy import signals
|
10 |
+
import logging
|
11 |
+
from my_search_engine.my_search_engine.spiders.search_spider import SearchSpider
|
12 |
+
from sentence_transformers import SentenceTransformer, util
|
13 |
+
from ranking import train_ranking_model
|
14 |
+
import time
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
class MCTSNode:
|
19 |
+
def __init__(self, state, parent=None, action=None):
|
20 |
+
self.state = state
|
21 |
+
self.parent = parent
|
22 |
+
self.action = action
|
23 |
+
self.children = []
|
24 |
+
self.visits = 0
|
25 |
+
self.value = 0
|
26 |
+
self.ucb_score = float('inf')
|
27 |
+
|
28 |
+
def is_leaf(self):
|
29 |
+
return len(self.children) == 0
|
30 |
+
|
31 |
+
def add_child(self, child_state, action=None):
|
32 |
+
child_node = MCTSNode(child_state, parent=self, action=action)
|
33 |
+
self.children.append(child_node)
|
34 |
+
return child_node
|
35 |
+
|
36 |
+
def update(self, reward):
|
37 |
+
self.visits += 1
|
38 |
+
self.value += reward
|
39 |
+
if self.parent: # Only calculate UCB if not root
|
40 |
+
self.ucb_score = self.calculate_ucb()
|
41 |
+
|
42 |
+
def calculate_ucb(self, exploration_weight=1.41):
|
43 |
+
if self.visits == 0 or not self.parent:
|
44 |
+
return float('inf')
|
45 |
+
exploitation = self.value / self.visits
|
46 |
+
exploration = exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
|
47 |
+
return exploitation + exploration
|
48 |
+
|
49 |
+
class MCTS:
|
50 |
+
def __init__(self, initial_state, num_simulations=20, exploration_weight=1.41):
|
51 |
+
self.root = MCTSNode(initial_state)
|
52 |
+
self.num_simulations = num_simulations
|
53 |
+
self.exploration_weight = exploration_weight
|
54 |
+
self.query_model = SentenceTransformer('all-MiniLM-L6-v2')
|
55 |
+
self.results = []
|
56 |
+
self.crawler_runner = CrawlerRunner(get_project_settings())
|
57 |
+
self.initial_state = initial_state
|
58 |
+
self.num_iterations = 5
|
59 |
+
|
60 |
+
def select(self, node):
|
61 |
+
while not node.is_leaf():
|
62 |
+
if not node.children:
|
63 |
+
return node
|
64 |
+
node = max(node.children, key=lambda c: c.calculate_ucb(self.exploration_weight))
|
65 |
+
return node
|
66 |
+
|
67 |
+
def expand(self, node):
|
68 |
+
if node.visits == 0:
|
69 |
+
return node
|
70 |
+
possible_refinements = self.get_possible_refinements(node.state)
|
71 |
+
for refinement in possible_refinements:
|
72 |
+
node.add_child(refinement)
|
73 |
+
return random.choice(node.children) if node.children else node
|
74 |
+
|
75 |
+
def calculate_combined_reward(self, ranking_score, state):
|
76 |
+
state_length_reward = len(state) / 100
|
77 |
+
if state:
|
78 |
+
query_complexity = len(set(state.split())) / len(state.split())
|
79 |
+
else:
|
80 |
+
query_complexity = 0
|
81 |
+
semantic_similarity = self.calculate_semantic_similarity(state, self.root.state)
|
82 |
+
|
83 |
+
combined_reward = (
|
84 |
+
0.5 * ranking_score +
|
85 |
+
0.2 * state_length_reward +
|
86 |
+
0.2 * query_complexity +
|
87 |
+
0.1 * semantic_similarity
|
88 |
+
)
|
89 |
+
return combined_reward
|
90 |
+
|
91 |
+
def calculate_semantic_similarity(self, query1, query2):
|
92 |
+
embedding1 = self.query_model.encode(query1)
|
93 |
+
embedding2 = self.query_model.encode(query2)
|
94 |
+
return util.pytorch_cos_sim(embedding1, embedding2).item()
|
95 |
+
|
96 |
+
def backpropagate(self, node, reward):
|
97 |
+
while node is not None:
|
98 |
+
node.update(reward)
|
99 |
+
node = node.parent
|
100 |
+
|
101 |
+
def best_action(self):
|
102 |
+
if not self.root.children:
|
103 |
+
return self.root
|
104 |
+
|
105 |
+
def score(node):
|
106 |
+
if node.visits == 0:
|
107 |
+
return float('-inf')
|
108 |
+
return node.value / node.visits
|
109 |
+
|
110 |
+
return max(self.root.children, key=score)
|
111 |
+
|
112 |
+
def refine_query(self, query):
|
113 |
+
words = query.split()
|
114 |
+
refined_query = []
|
115 |
+
|
116 |
+
for word in words:
|
117 |
+
if word.lower() not in {"how", "to", "get", "an", "the", "and", "or", "of", "build"}:
|
118 |
+
synonyms = wordnet.synsets(word)
|
119 |
+
if synonyms:
|
120 |
+
synonym_words = [lemma.name() for lemma in synonyms[0].lemmas()
|
121 |
+
if len(lemma.name().split()) == 1 and word != lemma.name()]
|
122 |
+
if synonym_words:
|
123 |
+
refined_query.append(random.choice(synonym_words))
|
124 |
+
else:
|
125 |
+
refined_query.append(word)
|
126 |
+
else:
|
127 |
+
refined_query.append(word)
|
128 |
+
else:
|
129 |
+
refined_query.append(word)
|
130 |
+
|
131 |
+
possible_intent_keywords = ['guide', 'tutorial', 'LLM', 'language model', 'NLP', 'GPT']
|
132 |
+
refined_query.append(random.choice(possible_intent_keywords))
|
133 |
+
|
134 |
+
return ' '.join(refined_query)
|
135 |
+
|
136 |
+
def get_related_queries(self, query):
|
137 |
+
query_embedding = self.query_model.encode(query)
|
138 |
+
refined_query_variations = [query]
|
139 |
+
words_to_avoid = {'how', 'to', 'get'}
|
140 |
+
words = query.split()
|
141 |
+
|
142 |
+
for word in words:
|
143 |
+
if word.lower() not in words_to_avoid:
|
144 |
+
synonyms = wordnet.synsets(word)
|
145 |
+
if synonyms:
|
146 |
+
synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word]
|
147 |
+
if synonym_words:
|
148 |
+
refined_query = query.replace(word, random.choice(synonym_words))
|
149 |
+
refined_query_variations.append(refined_query)
|
150 |
+
|
151 |
+
refined_query_variations = list(set(refined_query_variations))
|
152 |
+
refined_query_embeddings = [self.query_model.encode(variation) for variation in refined_query_variations]
|
153 |
+
similarity_scores = util.pytorch_cos_sim(query_embedding, refined_query_embeddings).tolist()[0]
|
154 |
+
|
155 |
+
similarity_threshold = 0.8
|
156 |
+
filtered_queries = [variation for idx, variation in enumerate(refined_query_variations)
|
157 |
+
if similarity_scores[idx] > similarity_threshold]
|
158 |
+
|
159 |
+
return filtered_queries[:2] if filtered_queries else [query]
|
160 |
+
|
161 |
+
def get_possible_refinements(self, query):
|
162 |
+
refined_queries = self.get_related_queries(query)
|
163 |
+
return refined_queries + [self.refine_query(query)]
|
164 |
+
|
165 |
+
@defer.inlineCallbacks
|
166 |
+
def web_search(self, query, search_sites=None):
|
167 |
+
if not query.strip():
|
168 |
+
logger.error("Cannot perform web search with an empty query.")
|
169 |
+
defer.returnValue([])
|
170 |
+
|
171 |
+
logger.info(f"Starting web search for query: {query}")
|
172 |
+
configure_logging(install_root_handler=False)
|
173 |
+
logging.basicConfig(level=logging.INFO)
|
174 |
+
|
175 |
+
results = []
|
176 |
+
|
177 |
+
def crawler_results(item, response, spider):
|
178 |
+
logger.info(f"Received result: {item['title']}")
|
179 |
+
results.append(item)
|
180 |
+
|
181 |
+
try:
|
182 |
+
crawler = self.crawler_runner.create_crawler(SearchSpider)
|
183 |
+
crawler.signals.connect(crawler_results, signal=signals.item_scraped)
|
184 |
+
|
185 |
+
# Start crawling, passing query and search_sites to the spider
|
186 |
+
yield self.crawler_runner.crawl(crawler, query=query, search_sites=search_sites)
|
187 |
+
except Exception as e:
|
188 |
+
logger.error(f"Error during web search: {str(e)}")
|
189 |
+
defer.returnValue([])
|
190 |
+
|
191 |
+
logger.info(f"Web search completed. Found {len(results)} results.")
|
192 |
+
defer.returnValue(results)
|
193 |
+
|
194 |
+
@defer.inlineCallbacks
|
195 |
+
def run(self):
|
196 |
+
logger.info(f"Starting MCTS run with {self.num_iterations} iterations")
|
197 |
+
for i in range(self.num_iterations):
|
198 |
+
logger.debug(f"Iteration {i+1}/{self.num_iterations}")
|
199 |
+
leaf = self.select(self.root)
|
200 |
+
child = self.expand(leaf)
|
201 |
+
reward = yield self.simulate(child)
|
202 |
+
self.backpropagate(child, reward)
|
203 |
+
|
204 |
+
best_child = self.best_action()
|
205 |
+
logger.info(f"MCTS run completed. Best action: {best_child.state}")
|
206 |
+
defer.returnValue(best_child.state if best_child != self.root else self.root.state)
|
207 |
+
|
208 |
+
@defer.inlineCallbacks
|
209 |
+
def simulate(self, node):
|
210 |
+
query_results = yield self.web_search(node.state)
|
211 |
+
ranked_results = train_ranking_model(node.state, query_results)
|
212 |
+
|
213 |
+
if ranked_results:
|
214 |
+
top_score = ranked_results[0]['predicted_score']
|
215 |
+
else:
|
216 |
+
top_score = 0
|
217 |
+
|
218 |
+
reward = self.calculate_combined_reward(top_score, node.state)
|
219 |
+
defer.returnValue(reward)
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
|
my_search_engine/my_search_engine/__init__.py
ADDED
File without changes
|
my_search_engine/my_search_engine/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (165 Bytes). View file
|
|
my_search_engine/my_search_engine/__pycache__/items.cpython-312.pyc
ADDED
Binary file (799 Bytes). View file
|
|
my_search_engine/my_search_engine/items.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import scrapy
|
2 |
+
|
3 |
+
class MySearchEngineItem(scrapy.Item):
|
4 |
+
title = scrapy.Field()
|
5 |
+
link = scrapy.Field()
|
6 |
+
content = scrapy.Field()
|
7 |
+
score = scrapy.Field() # Will be set later during ranking (MCTS or NLP)
|
8 |
+
meta = scrapy.Field()
|
9 |
+
predicted_score = scrapy.Field()
|
10 |
+
summary = scrapy.Field()
|
my_search_engine/my_search_engine/middlewares.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# middlewares.py
|
2 |
+
|
3 |
+
import random
|
4 |
+
import logging
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
class RotateUserAgentMiddleware:
|
9 |
+
"""Middleware for rotating user agents to avoid detection."""
|
10 |
+
|
11 |
+
USER_AGENTS = [
|
12 |
+
# Chrome User Agents
|
13 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
14 |
+
" Chrome/93.0.4577.63 Safari/537.36",
|
15 |
+
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)"
|
16 |
+
" Chrome/93.0.4577.63 Safari/537.36",
|
17 |
+
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
18 |
+
" Chrome/93.0.4577.63 Safari/537.36",
|
19 |
+
# Firefox User Agents
|
20 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:92.0) Gecko/20100101 Firefox/92.0",
|
21 |
+
"Mozilla/5.0 (Macintosh; Intel Mac OS X 11.5; rv:92.0) Gecko/20100101 Firefox/92.0",
|
22 |
+
"Mozilla/5.0 (X11; Linux x86_64; rv:92.0) Gecko/20100101 Firefox/92.0",
|
23 |
+
# Safari User Agents
|
24 |
+
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Safari/605.1.15",
|
25 |
+
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.2 Mobile/15E148 Safari/604.1",
|
26 |
+
# Edge User Agents
|
27 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36 Edg/93.0.961.38",
|
28 |
+
# Opera User Agents
|
29 |
+
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36 OPR/78.0.4093.184",
|
30 |
+
# Mobile User Agents
|
31 |
+
"Mozilla/5.0 (Linux; Android 11; SM-G981B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.62 Mobile Safari/537.36",
|
32 |
+
"Mozilla/5.0 (iPhone; CPU iPhone OS 14_7_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) CriOS/93.0.4577.62 Mobile/15E148 Safari/604.1",
|
33 |
+
# Add more user agents from different browsers and devices
|
34 |
+
]
|
35 |
+
|
36 |
+
def process_request(self, request, spider):
|
37 |
+
"""Assign a random user agent to each request."""
|
38 |
+
user_agent = random.choice(self.USER_AGENTS)
|
39 |
+
request.headers['User-Agent'] = user_agent
|
40 |
+
logger.debug(f"Using User-Agent: {user_agent}")
|
41 |
+
|
42 |
+
# Optional: Proxy Middleware
|
43 |
+
class ProxyMiddleware:
|
44 |
+
"""Middleware for rotating proxies."""
|
45 |
+
|
46 |
+
PROXIES = [
|
47 |
+
# Add proxy URLs if using proxies
|
48 |
+
# 'http://proxy1.example.com:8000',
|
49 |
+
# 'http://proxy2.example.com:8031',
|
50 |
+
]
|
51 |
+
|
52 |
+
def process_request(self, request, spider):
|
53 |
+
if self.PROXIES:
|
54 |
+
proxy = random.choice(self.PROXIES)
|
55 |
+
request.meta['proxy'] = proxy
|
56 |
+
logger.debug(f"Using Proxy: {proxy}")
|
my_search_engine/my_search_engine/pipelines.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pipelines.py
|
2 |
+
|
3 |
+
import json
|
4 |
+
|
5 |
+
class SaveToJSONPipeline:
|
6 |
+
"""Pipeline that saves scraped data to a JSON file."""
|
7 |
+
|
8 |
+
def open_spider(self, spider):
|
9 |
+
"""Open the file when the spider starts."""
|
10 |
+
self.file = open('scraped_results.json', 'w', encoding='utf-8')
|
11 |
+
|
12 |
+
def close_spider(self, spider):
|
13 |
+
"""Close the file when the spider finishes."""
|
14 |
+
self.file.close()
|
15 |
+
|
16 |
+
def process_item(self, item, spider):
|
17 |
+
"""Write each scraped item to the JSON file."""
|
18 |
+
line = json.dumps(dict(item), ensure_ascii=False) + "\n"
|
19 |
+
self.file.write(line)
|
20 |
+
return item
|
21 |
+
|
22 |
+
|
23 |
+
class ContentCleanupPipeline:
|
24 |
+
"""Pipeline to clean up content by removing unnecessary whitespace."""
|
25 |
+
|
26 |
+
def process_item(self, item, spider):
|
27 |
+
"""Clean up content field."""
|
28 |
+
item['content'] = ' '.join(item['content'].split()) # Clean up content by removing extra spaces
|
29 |
+
return item
|
30 |
+
|
31 |
+
|
32 |
+
class DisplayResultsPipeline:
|
33 |
+
"""Pipeline that formats and prints the search results in a Google-like format."""
|
34 |
+
|
35 |
+
def open_spider(self, spider):
|
36 |
+
"""Initialize an empty results list when the spider starts."""
|
37 |
+
self.results = []
|
38 |
+
|
39 |
+
def process_item(self, item, spider):
|
40 |
+
"""Store the item in the results list."""
|
41 |
+
self.results.append({
|
42 |
+
'title': item['title'],
|
43 |
+
'summary': item['content'],
|
44 |
+
'link': item['link']
|
45 |
+
})
|
46 |
+
return item
|
47 |
+
|
48 |
+
def close_spider(self, spider):
|
49 |
+
"""Print out the formatted results when the spider finishes."""
|
50 |
+
print("\nTop 10 Related Links for the Search Query:")
|
51 |
+
for i, result in enumerate(self.results[:10], start=1):
|
52 |
+
print(f"{i}. {result['title']}\n {result['summary'][:200]}...\n {result['link']}\n")
|
53 |
+
|
my_search_engine/my_search_engine/settings.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# settings.py
|
2 |
+
|
3 |
+
# Scrapy configurations
|
4 |
+
BOT_NAME = 'my_search_engine'
|
5 |
+
SPIDER_MODULES = ['my_search_engine.spiders']
|
6 |
+
NEWSPIDER_MODULE = 'my_search_engine.spiders'
|
7 |
+
|
8 |
+
# Obey robots.txt rules
|
9 |
+
ROBOTSTXT_OBEY = True
|
10 |
+
|
11 |
+
# Configure maximum concurrent requests performed by Scrapy
|
12 |
+
CONCURRENT_REQUESTS = 16
|
13 |
+
CONCURRENT_REQUESTS_PER_DOMAIN = 1
|
14 |
+
|
15 |
+
# Configure a delay for requests to the same website
|
16 |
+
DOWNLOAD_DELAY = 2 # Fixed delay of 2 seconds
|
17 |
+
|
18 |
+
# Disable cookies (enabled by default)
|
19 |
+
COOKIES_ENABLED = False
|
20 |
+
|
21 |
+
# Enable AutoThrottle
|
22 |
+
AUTOTHROTTLE_ENABLED = True
|
23 |
+
AUTOTHROTTLE_START_DELAY = 1 # Initial download delay
|
24 |
+
AUTOTHROTTLE_MAX_DELAY = 10 # Maximum download delay in case of high latencies
|
25 |
+
AUTOTHROTTLE_TARGET_CONCURRENCY = 1.0 # Average number of requests Scrapy should be sending in parallel
|
26 |
+
|
27 |
+
# User Agent (default, only used if RotateUserAgentMiddleware fails)
|
28 |
+
USER_AGENT = "Mozilla/5.0 (compatible; MySearchEngine/1.0)"
|
29 |
+
|
30 |
+
# Downloader middlewares
|
31 |
+
DOWNLOADER_MIDDLEWARES = {
|
32 |
+
'my_search_engine.middlewares.RotateUserAgentMiddleware': 543,
|
33 |
+
# Uncomment the following line if using the ProxyMiddleware
|
34 |
+
# 'my_search_engine.middlewares.ProxyMiddleware': 544,
|
35 |
+
'scrapy.downloadermiddlewares.useragent.UserAgentMiddleware': None, # Disable default user agent middleware
|
36 |
+
}
|
37 |
+
|
38 |
+
# Item pipelines
|
39 |
+
ITEM_PIPELINES = {
|
40 |
+
'my_search_engine.pipelines.SaveToJSONPipeline': 300,
|
41 |
+
'my_search_engine.pipelines.ContentCleanupPipeline': 400,
|
42 |
+
'my_search_engine.pipelines.DisplayResultsPipeline': 200,
|
43 |
+
}
|
44 |
+
|
45 |
+
# Enable logging
|
46 |
+
LOG_ENABLED = True
|
47 |
+
LOG_LEVEL = 'INFO' # Set to 'DEBUG' to see all logs, including middleware logs
|
48 |
+
|
49 |
+
# Additional settings can be added below as needed
|
my_search_engine/my_search_engine/spiders/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This package will contain the spiders of your Scrapy project
|
2 |
+
#
|
3 |
+
# Please refer to the documentation for information on how to create and manage
|
4 |
+
# your spiders.
|
my_search_engine/my_search_engine/spiders/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (173 Bytes). View file
|
|
my_search_engine/my_search_engine/spiders/__pycache__/search_spider.cpython-312.pyc
ADDED
Binary file (11.2 kB). View file
|
|
my_search_engine/my_search_engine/spiders/search_spider.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# search_spider.py
|
2 |
+
import scrapy
|
3 |
+
from bs4 import BeautifulSoup
|
4 |
+
from my_search_engine.my_search_engine.items import MySearchEngineItem
|
5 |
+
import random
|
6 |
+
from urllib.parse import urlparse, urljoin
|
7 |
+
import traceback
|
8 |
+
import re
|
9 |
+
from twisted.internet.error import TCPTimedOutError, ConnectionRefusedError, TimeoutError
|
10 |
+
from scrapy.exceptions import CloseSpider
|
11 |
+
import logging
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
class SearchSpider(scrapy.Spider):
|
16 |
+
name = "search_spider"
|
17 |
+
allowed_domains = [] # To be set dynamically from search_sites
|
18 |
+
|
19 |
+
def __init__(self, query=None, search_sites=None, max_depth=2, max_links_per_page=3, *args, **kwargs):
|
20 |
+
super(SearchSpider, self).__init__(*args, **kwargs)
|
21 |
+
self.query = query
|
22 |
+
if not self.query:
|
23 |
+
raise CloseSpider("No search query provided")
|
24 |
+
self.max_depth = max_depth
|
25 |
+
self.max_links_per_page = max_links_per_page
|
26 |
+
if search_sites is None:
|
27 |
+
self.search_sites = [
|
28 |
+
f"https://en.wikibooks.org/w/index.php?search={self.query}",
|
29 |
+
f"https://en.wikiversity.org/w/index.php?search={self.query}",
|
30 |
+
f"https://commons.wikimedia.org/w/index.php?search={self.query}",
|
31 |
+
f"https://stackexchange.com/search?q={self.query}",
|
32 |
+
f"https://arxiv.org/search/?query={self.query}&searchtype=all",
|
33 |
+
f"https://www.ncbi.nlm.nih.gov/pmc/?term={self.query}",
|
34 |
+
f"https://www.gutenberg.org/ebooks/search/?query={self.query}",
|
35 |
+
f"https://openlibrary.org/search?q={self.query}",
|
36 |
+
f"https://doaj.org/search/articles?ref=homepage&q={self.query}",
|
37 |
+
f"https://www.ted.com/search?q={self.query}",
|
38 |
+
f"https://en.citizendium.org/wiki?search={self.query}",
|
39 |
+
f"https://www.jstor.org/action/doBasicSearch?Query={self.query}",
|
40 |
+
f"https://archive.org/search.php?query={self.query}",
|
41 |
+
f"https://search.scielo.org/?q={self.query}",
|
42 |
+
f"https://paperswithcode.com/search?q={self.query}",
|
43 |
+
f"https://www.reddit.com/search/?q={self.query}",
|
44 |
+
f"https://huggingface.co/models?search={self.query}",
|
45 |
+
f"https://huggingface.co/datasets?search={self.query}",
|
46 |
+
f"https://machinelearningmastery.com/?s={self.query}",
|
47 |
+
f"https://www.kaggle.com/search?q={self.query}",
|
48 |
+
f"https://towardsdatascience.com/search?q={self.query}",
|
49 |
+
f"https://github.com/search?q={self.query}",
|
50 |
+
f"https://stackoverflow.com/search?q={self.query}",
|
51 |
+
f"https://www.youtube.com/results?search_query={self.query}",
|
52 |
+
f"https://www.slideshare.net/search/slideshow?searchfrom=header&q={self.query}"
|
53 |
+
]
|
54 |
+
|
55 |
+
else:
|
56 |
+
self.search_sites = search_sites
|
57 |
+
|
58 |
+
|
59 |
+
def start_requests(self):
|
60 |
+
if not self.query:
|
61 |
+
logger.error("No search query provided in start_requests")
|
62 |
+
return
|
63 |
+
|
64 |
+
self.allowed_domains = list(set([urlparse(url).netloc for url in self.search_sites]))
|
65 |
+
logger.info(f"Starting requests for query: {self.query}")
|
66 |
+
|
67 |
+
for url in self.search_sites:
|
68 |
+
yield scrapy.Request(
|
69 |
+
url,
|
70 |
+
callback=self.parse,
|
71 |
+
meta={
|
72 |
+
'dont_retry': True,
|
73 |
+
'handle_httpstatus_list': [302, 403, 404, 420, 429, 500, 503],
|
74 |
+
'depth': 1 # Start at depth 1
|
75 |
+
},
|
76 |
+
errback=self.errback_httpbin
|
77 |
+
)
|
78 |
+
|
79 |
+
def parse(self, response):
|
80 |
+
depth = response.meta.get('depth', 1)
|
81 |
+
if depth > self.max_depth:
|
82 |
+
logger.debug(f"Reached max depth at {response.url}")
|
83 |
+
return
|
84 |
+
|
85 |
+
logger.info(f"Parsing response from {response.url} at depth {depth}")
|
86 |
+
|
87 |
+
try:
|
88 |
+
soup = BeautifulSoup(response.text, 'html.parser')
|
89 |
+
|
90 |
+
# Check for irrelevant or blocked content
|
91 |
+
if any(term in soup.text.lower() for term in ['captcha', 'verification', 'no items found', 'no results', 'access denied']):
|
92 |
+
logger.warning(f"Irrelevant page detected: {response.url}")
|
93 |
+
return
|
94 |
+
|
95 |
+
title = soup.find('title').get_text().strip() if soup.find('title') else 'No title'
|
96 |
+
meta_description = soup.find('meta', {'name': 'description'})
|
97 |
+
meta_description = meta_description['content'].strip() if meta_description else 'No description'
|
98 |
+
|
99 |
+
content = self.extract_main_content(soup)
|
100 |
+
summary = self.generate_summary(content, 200)
|
101 |
+
total_links = len(soup.find_all('a', href=True))
|
102 |
+
content_length = len(content.split())
|
103 |
+
|
104 |
+
if content_length < 100:
|
105 |
+
logger.info(f"Content too short ({content_length} words) for {response.url}")
|
106 |
+
return
|
107 |
+
|
108 |
+
item = MySearchEngineItem()
|
109 |
+
item['title'] = title
|
110 |
+
item['link'] = response.url
|
111 |
+
item['content'] = content
|
112 |
+
item['summary'] = summary
|
113 |
+
item['meta'] = {
|
114 |
+
'description': meta_description,
|
115 |
+
'total_links': total_links,
|
116 |
+
'content_length': content_length,
|
117 |
+
'domain': urlparse(response.url).netloc,
|
118 |
+
}
|
119 |
+
yield item
|
120 |
+
|
121 |
+
# Limit the number of links per page
|
122 |
+
links = soup.find_all('a', href=True)
|
123 |
+
random.shuffle(links)
|
124 |
+
links = links[:self.max_links_per_page] # Limit the number of links
|
125 |
+
|
126 |
+
for link in links:
|
127 |
+
href = link.get('href')
|
128 |
+
full_url = urljoin(response.url, href)
|
129 |
+
if self.is_valid_link(full_url):
|
130 |
+
logger.debug(f"Following link: {full_url}")
|
131 |
+
yield scrapy.Request(
|
132 |
+
url=full_url,
|
133 |
+
callback=self.parse,
|
134 |
+
meta={'depth': depth + 1},
|
135 |
+
errback=self.errback_httpbin
|
136 |
+
)
|
137 |
+
except Exception as e:
|
138 |
+
logger.error(f"Error parsing {response.url}: {str(e)}")
|
139 |
+
logger.error(traceback.format_exc())
|
140 |
+
|
141 |
+
def extract_main_content(self, soup):
|
142 |
+
for element in soup(['script', 'style', 'nav', 'header', 'footer']):
|
143 |
+
element.decompose()
|
144 |
+
|
145 |
+
main_content = soup.find('main') or soup.find('article') or soup.find('div', class_='content')
|
146 |
+
|
147 |
+
if main_content:
|
148 |
+
return ' '.join(main_content.stripped_strings)
|
149 |
+
|
150 |
+
paragraphs = soup.find_all('p')
|
151 |
+
return ' '.join([p.get_text().strip() for p in paragraphs])
|
152 |
+
|
153 |
+
def generate_summary(self, content, max_length=200):
|
154 |
+
sentences = re.split(r'(?<=[.!?])\s+', content)
|
155 |
+
summary = ""
|
156 |
+
for sentence in sentences:
|
157 |
+
if len(summary) + len(sentence) <= max_length:
|
158 |
+
summary += sentence + " "
|
159 |
+
else:
|
160 |
+
break
|
161 |
+
return summary.strip()
|
162 |
+
|
163 |
+
def is_valid_link(self, url):
|
164 |
+
parsed_url = urlparse(url)
|
165 |
+
return any(domain in parsed_url.netloc for domain in self.allowed_domains)
|
166 |
+
|
167 |
+
def errback_httpbin(self, failure):
|
168 |
+
logger.error(f"Error on {failure.request.url}: {str(failure.value)}")
|
169 |
+
logger.error(traceback.format_exc())
|
170 |
+
|
171 |
+
if failure.check(ConnectionRefusedError):
|
172 |
+
logger.warning(f"Connection refused: {failure.request.url}")
|
173 |
+
elif failure.check(TimeoutError, TCPTimedOutError):
|
174 |
+
logger.warning(f"Timeout: {failure.request.url}")
|
175 |
+
else:
|
176 |
+
logger.error(f"Failed to process: {failure.request.url}")
|
my_search_engine/scrapy.cfg
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Automatically created by: scrapy startproject
|
2 |
+
#
|
3 |
+
# For more information about the [deploy] section see:
|
4 |
+
# https://scrapyd.readthedocs.io/en/latest/deploy.html
|
5 |
+
|
6 |
+
[settings]
|
7 |
+
default = my_search_engine.settings
|
8 |
+
|
9 |
+
[deploy]
|
10 |
+
#url = http://localhost:6800/
|
11 |
+
project = my_search_engine
|
ranking.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ranking.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.optim as optim
|
5 |
+
import pandas as pd
|
6 |
+
from sentence_transformers import SentenceTransformer, util
|
7 |
+
import numpy as np
|
8 |
+
from sklearn.preprocessing import MinMaxScaler
|
9 |
+
from collections import Counter
|
10 |
+
import re
|
11 |
+
import string
|
12 |
+
from collections import Counter
|
13 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
14 |
+
from nltk.corpus import stopwords
|
15 |
+
from nltk.stem import WordNetLemmatizer
|
16 |
+
from nltk.tokenize import word_tokenize
|
17 |
+
import spacy
|
18 |
+
|
19 |
+
def truncate_text(text, max_length=1024):
|
20 |
+
tokens = text.split()
|
21 |
+
if len(tokens) > max_length:
|
22 |
+
return ' '.join(tokens[:max_length])
|
23 |
+
return text
|
24 |
+
|
25 |
+
class RankingNN(nn.Module):
|
26 |
+
def __init__(self, input_size=7):
|
27 |
+
super(RankingNN, self).__init__()
|
28 |
+
self.fc1 = nn.Linear(input_size, 64)
|
29 |
+
self.fc2 = nn.Linear(64, 32)
|
30 |
+
self.fc3 = nn.Linear(32, 16)
|
31 |
+
self.fc4 = nn.Linear(16, 1)
|
32 |
+
self.dropout = nn.Dropout(0.2)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = torch.relu(self.fc1(x))
|
36 |
+
x = self.dropout(x)
|
37 |
+
x = torch.relu(self.fc2(x))
|
38 |
+
x = self.dropout(x)
|
39 |
+
x = torch.relu(self.fc3(x))
|
40 |
+
x = self.fc4(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
transformer_model = SentenceTransformer('all-MiniLM-L6-v2')
|
44 |
+
ranking_model = RankingNN()
|
45 |
+
optimizer = optim.Adam(ranking_model.parameters(), lr=0.001, weight_decay=1e-5)
|
46 |
+
criterion = nn.MSELoss()
|
47 |
+
scaler = MinMaxScaler()
|
48 |
+
|
49 |
+
# Download necessary resources
|
50 |
+
import nltk
|
51 |
+
nltk.download('punkt')
|
52 |
+
nltk.download('stopwords')
|
53 |
+
nltk.download('wordnet')
|
54 |
+
|
55 |
+
# Initialize resources
|
56 |
+
stop_words = set(stopwords.words('english'))
|
57 |
+
lemmatizer = WordNetLemmatizer()
|
58 |
+
nlp = spacy.load("en_core_web_sm") # Small model to keep compute low
|
59 |
+
|
60 |
+
def preprocess_text(text):
|
61 |
+
"""
|
62 |
+
Preprocess the input text by lowercasing, removing punctuation, and filtering out stopwords.
|
63 |
+
Lemmatization is applied as well.
|
64 |
+
"""
|
65 |
+
# Lowercase the text
|
66 |
+
text = text.lower()
|
67 |
+
|
68 |
+
# Remove punctuation using regex
|
69 |
+
text = re.sub(r'[' + string.punctuation + ']', ' ', text)
|
70 |
+
|
71 |
+
# Tokenize the text into words
|
72 |
+
words = word_tokenize(text)
|
73 |
+
|
74 |
+
# Lemmatize, filter out stopwords and non-alphabetic words
|
75 |
+
processed_words = [lemmatizer.lemmatize(word) for word in words if word.isalpha() and word not in stop_words]
|
76 |
+
|
77 |
+
return processed_words
|
78 |
+
|
79 |
+
def extract_named_entities(text):
|
80 |
+
"""
|
81 |
+
Extract named entities (e.g., people, organizations, locations) from the text.
|
82 |
+
"""
|
83 |
+
doc = nlp(text)
|
84 |
+
named_entities = [ent.text for ent in doc.ents if ent.label_ in {"PERSON", "ORG", "GPE", "LOC"}]
|
85 |
+
return named_entities
|
86 |
+
|
87 |
+
def extract_keywords_tfidf(corpus, text, n=5):
|
88 |
+
"""
|
89 |
+
Extract keywords from the text using TF-IDF, combined with Named Entity Recognition and lemmatization.
|
90 |
+
"""
|
91 |
+
# Preprocess the text and the entire corpus
|
92 |
+
preprocessed_texts = [' '.join(preprocess_text(doc)) for doc in corpus]
|
93 |
+
preprocessed_text = ' '.join(preprocess_text(text))
|
94 |
+
|
95 |
+
# Named entities extraction
|
96 |
+
named_entities = extract_named_entities(text)
|
97 |
+
|
98 |
+
# Use TF-IDF vectorizer to find the most important words
|
99 |
+
vectorizer = TfidfVectorizer(max_features=1000) # Keep it light, max 1000 features
|
100 |
+
X = vectorizer.fit_transform(preprocessed_texts)
|
101 |
+
|
102 |
+
# Get the feature names (i.e., the words)
|
103 |
+
feature_names = vectorizer.get_feature_names_out()
|
104 |
+
|
105 |
+
# Transform the current text into TF-IDF scores
|
106 |
+
response = vectorizer.transform([preprocessed_text])
|
107 |
+
tfidf_scores = zip(feature_names, response.toarray()[0])
|
108 |
+
|
109 |
+
# Sort by TF-IDF score
|
110 |
+
sorted_tfidf = sorted(tfidf_scores, key=lambda x: x[1], reverse=True)
|
111 |
+
|
112 |
+
# Combine top TF-IDF words with named entities for more richness
|
113 |
+
keywords = [word for word, score in sorted_tfidf[:n]]
|
114 |
+
combined_keywords = keywords + named_entities
|
115 |
+
|
116 |
+
return combined_keywords[:n]
|
117 |
+
|
118 |
+
def extract_keywords(text, corpus, n=5):
|
119 |
+
"""
|
120 |
+
Wrapper function that combines preprocessing, TF-IDF, and Named Entity Recognition to extract top N keywords.
|
121 |
+
"""
|
122 |
+
if not text.strip():
|
123 |
+
return []
|
124 |
+
|
125 |
+
# Extract keywords using the TF-IDF based approach
|
126 |
+
keywords = extract_keywords_tfidf(corpus, text, n)
|
127 |
+
|
128 |
+
# If no meaningful keywords are found, fallback to keyword frequency
|
129 |
+
if not keywords:
|
130 |
+
return extract_fallback_keywords(text, n)
|
131 |
+
|
132 |
+
return keywords
|
133 |
+
|
134 |
+
def extract_fallback_keywords(text, n=5):
|
135 |
+
"""
|
136 |
+
Fallback method to extract keywords based on word frequency in case TF-IDF or NER fails.
|
137 |
+
"""
|
138 |
+
words = preprocess_text(text)
|
139 |
+
word_freq = Counter(words)
|
140 |
+
return [word for word, _ in word_freq.most_common(n)]
|
141 |
+
|
142 |
+
def calculate_keyword_overlap(query_keywords, result_keywords):
|
143 |
+
if len(query_keywords) == 0:
|
144 |
+
return 0 # No keywords in query, so overlap is 0
|
145 |
+
return len(set(query_keywords) & set(result_keywords)) / len(query_keywords)
|
146 |
+
|
147 |
+
def train_ranking_model(query, results, corpus=None, epochs=1):
|
148 |
+
query = truncate_text(query)
|
149 |
+
if not results:
|
150 |
+
print("No results available. Skipping training.")
|
151 |
+
return []
|
152 |
+
|
153 |
+
if corpus is None:
|
154 |
+
# If no corpus is provided, use results as a fallback
|
155 |
+
corpus = [truncate_text(result['content']) for result in results if 'content' in result]
|
156 |
+
|
157 |
+
query_embedding = transformer_model.encode(query)
|
158 |
+
query_keywords = extract_keywords(query, corpus)
|
159 |
+
|
160 |
+
training_data = []
|
161 |
+
target_scores = []
|
162 |
+
|
163 |
+
for result in results:
|
164 |
+
# Truncate content
|
165 |
+
content = truncate_text(result['content'])
|
166 |
+
content_embedding = transformer_model.encode(content)
|
167 |
+
|
168 |
+
# Handle missing 'title' and 'meta' fields with default values, and truncate
|
169 |
+
title = truncate_text(result.get('title', ''))
|
170 |
+
title_embedding = transformer_model.encode(title)
|
171 |
+
|
172 |
+
meta_description = truncate_text(result.get('meta', {}).get('description', ''))
|
173 |
+
meta_description_embedding = transformer_model.encode(meta_description)
|
174 |
+
|
175 |
+
content_similarity = util.pytorch_cos_sim(query_embedding, content_embedding).item()
|
176 |
+
title_similarity = util.pytorch_cos_sim(query_embedding, title_embedding).item()
|
177 |
+
meta_description_similarity = util.pytorch_cos_sim(query_embedding, meta_description_embedding).item()
|
178 |
+
|
179 |
+
# Handle missing metadata by providing default values
|
180 |
+
content_length = result.get('meta', {}).get('content_length', 0)
|
181 |
+
total_links = result.get('meta', {}).get('total_links', 0)
|
182 |
+
|
183 |
+
result_keywords = extract_keywords(content, corpus)
|
184 |
+
keyword_overlap = calculate_keyword_overlap(query_keywords, result_keywords)
|
185 |
+
domain_authority = get_domain_authority(result.get('link', ''))
|
186 |
+
|
187 |
+
features = [
|
188 |
+
content_similarity, title_similarity, meta_description_similarity,
|
189 |
+
content_length, total_links, keyword_overlap, domain_authority
|
190 |
+
]
|
191 |
+
|
192 |
+
training_data.append(features)
|
193 |
+
|
194 |
+
target_score = (0.4 * content_similarity + 0.3 * title_similarity +
|
195 |
+
0.2 * meta_description_similarity + 0.1 * keyword_overlap)
|
196 |
+
target_scores.append(target_score)
|
197 |
+
|
198 |
+
# Normalize features
|
199 |
+
training_data = scaler.fit_transform(training_data)
|
200 |
+
training_data_tensor = torch.tensor(training_data, dtype=torch.float32)
|
201 |
+
target_scores_tensor = torch.tensor(target_scores, dtype=torch.float32).unsqueeze(1)
|
202 |
+
|
203 |
+
# Training loop
|
204 |
+
for epoch in range(epochs):
|
205 |
+
optimizer.zero_grad()
|
206 |
+
predicted_scores = ranking_model(training_data_tensor)
|
207 |
+
loss = criterion(predicted_scores, target_scores_tensor)
|
208 |
+
loss.backward()
|
209 |
+
optimizer.step()
|
210 |
+
|
211 |
+
if (epoch + 1) % 5 == 0:
|
212 |
+
print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
|
213 |
+
|
214 |
+
# Predict final scores and rank results
|
215 |
+
with torch.no_grad():
|
216 |
+
final_scores = ranking_model(training_data_tensor).squeeze().tolist()
|
217 |
+
|
218 |
+
# Ensure final_scores is always a list
|
219 |
+
if isinstance(final_scores, float):
|
220 |
+
final_scores = [final_scores]
|
221 |
+
|
222 |
+
for result, score in zip(results, final_scores):
|
223 |
+
result['predicted_score'] = score
|
224 |
+
|
225 |
+
ranked_results = sorted(results, key=lambda x: x['predicted_score'], reverse=True)
|
226 |
+
return ranked_results
|
227 |
+
|
228 |
+
def get_domain_authority(url):
|
229 |
+
# Placeholder function - replace with actual domain authority data if available
|
230 |
+
high_authority_domains = ['arxiv.org', 'ncbi.nlm.nih.gov', 'nature.com', 'science.org']
|
231 |
+
medium_authority_domains = ['wikipedia.org', 'stackexchange.com', 'github.com']
|
232 |
+
|
233 |
+
for domain in high_authority_domains:
|
234 |
+
if domain in url:
|
235 |
+
return 1.0
|
236 |
+
for domain in medium_authority_domains:
|
237 |
+
if domain in url:
|
238 |
+
return 0.7
|
239 |
+
return 0.5
|
test_agent.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# test_agent.py
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from twisted.internet import reactor, defer, threads
|
5 |
+
from agent import AutonomousWebAgent
|
6 |
+
from ToTSearch import ToTSearch
|
7 |
+
|
8 |
+
# Configure logging
|
9 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
10 |
+
|
11 |
+
# Initialize the logger
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
# Suppress detailed logs for some libraries (like Scrapy or Transformers)
|
15 |
+
logging.getLogger('scrapy').setLevel(logging.ERROR)
|
16 |
+
logging.getLogger('transformers').setLevel(logging.ERROR)
|
17 |
+
logging.getLogger('twisted').setLevel(logging.ERROR)
|
18 |
+
|
19 |
+
import warnings
|
20 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
21 |
+
|
22 |
+
|
23 |
+
class TestAgent:
|
24 |
+
def __init__(self):
|
25 |
+
# Initialize the AutonomousWebAgent
|
26 |
+
state_size = 7 # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
|
27 |
+
action_size = 3 # 0: Click Link, 1: Summarize, 2: RAG Generate
|
28 |
+
num_options = 3 # 0: Search, 1: Summarize, 2: RAG Generate
|
29 |
+
|
30 |
+
self.agent = AutonomousWebAgent(
|
31 |
+
state_size=state_size,
|
32 |
+
action_size=action_size,
|
33 |
+
num_options=num_options,
|
34 |
+
hidden_size=64,
|
35 |
+
learning_rate=0.001,
|
36 |
+
gamma=0.99,
|
37 |
+
epsilon=1.0,
|
38 |
+
epsilon_decay=0.995,
|
39 |
+
epsilon_min=0.01,
|
40 |
+
knowledge_base_path='knowledge_base.json'
|
41 |
+
)
|
42 |
+
|
43 |
+
# Initialize ToTSearch with the agent
|
44 |
+
self.tot_search = ToTSearch(self.agent)
|
45 |
+
|
46 |
+
# Few-shot examples for Tree of Thoughts
|
47 |
+
self.few_shot_examples = [
|
48 |
+
{
|
49 |
+
"query": "What are the effects of climate change on biodiversity?",
|
50 |
+
"thoughts": [
|
51 |
+
"Loss of habitats due to rising sea levels and changing temperatures",
|
52 |
+
"Disruption of ecosystems and food chains",
|
53 |
+
"Increased extinction rates for vulnerable species"
|
54 |
+
],
|
55 |
+
"answer": "Climate change significantly impacts biodiversity through habitat loss, ecosystem disruption, and increased extinction rates. Rising temperatures and sea levels alter habitats, forcing species to adapt or migrate. This disrupts established ecosystems and food chains. Species unable to adapt quickly face a higher risk of extinction, particularly those with specialized habitats or limited ranges."
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"query": "How can we promote sustainable energy adoption?",
|
59 |
+
"thoughts": [
|
60 |
+
"Government policies and incentives",
|
61 |
+
"Public awareness and education campaigns",
|
62 |
+
"Technological advancements and cost reduction"
|
63 |
+
],
|
64 |
+
"answer": "Promoting sustainable energy adoption requires a multi-faceted approach. Government policies and incentives can encourage both businesses and individuals to switch to renewable sources. Public awareness and education campaigns help people understand the importance and benefits of sustainable energy. Continued technological advancements and cost reductions make sustainable energy more accessible and economically viable for widespread adoption."
|
65 |
+
}
|
66 |
+
]
|
67 |
+
|
68 |
+
@defer.inlineCallbacks
|
69 |
+
def process_query(self, query, is_few_shot=False):
|
70 |
+
logger.info(f"Processing query: {query}")
|
71 |
+
try:
|
72 |
+
if is_few_shot:
|
73 |
+
few_shot_prompt = self.create_few_shot_prompt(query)
|
74 |
+
enhanced_query = f"{few_shot_prompt}\n\nQuery: {query}"
|
75 |
+
logger.debug(f"Enhanced query for few-shot learning: {enhanced_query[:100]}...")
|
76 |
+
final_answer = yield self.tot_search.search(enhanced_query)
|
77 |
+
else:
|
78 |
+
final_answer = yield self.tot_search.search(query)
|
79 |
+
|
80 |
+
logger.info(f"Final answer for '{query}':")
|
81 |
+
logger.info(final_answer)
|
82 |
+
|
83 |
+
yield self.agent.add_document_to_kb(title=f"ToT Search Result: {query}", content=final_answer)
|
84 |
+
|
85 |
+
yield self.agent.replay_worker(batch_size=32)
|
86 |
+
yield self.agent.replay_manager(batch_size=32)
|
87 |
+
|
88 |
+
return final_answer
|
89 |
+
except Exception as e:
|
90 |
+
logger.error(f"Error processing query '{query}': {str(e)}", exc_info=True)
|
91 |
+
return f"An error occurred: {str(e)}"
|
92 |
+
|
93 |
+
def create_few_shot_prompt(self, query):
|
94 |
+
prompt = "Here are some examples of how to approach queries using a Tree of Thoughts:\n\n"
|
95 |
+
for example in self.few_shot_examples:
|
96 |
+
prompt += f"Query: {example['query']}\n"
|
97 |
+
prompt += "Thoughts:\n"
|
98 |
+
for thought in example['thoughts']:
|
99 |
+
prompt += f"- {thought}\n"
|
100 |
+
prompt += f"Answer: {example['answer']}\n\n"
|
101 |
+
prompt += f"Now, let's approach the following query in a similar manner:\n\nQuery: {query}\n"
|
102 |
+
return prompt
|
103 |
+
|
104 |
+
def save_models(self):
|
105 |
+
self.agent.save_worker_model("worker_model_final.pth")
|
106 |
+
self.agent.save_manager_model("manager_model_final.pth")
|
107 |
+
logger.info("Agent models saved.")
|
108 |
+
|
109 |
+
|
110 |
+
def get_user_input():
|
111 |
+
return input("Enter your query (or 'quit' to exit): ")
|
112 |
+
|
113 |
+
|
114 |
+
@defer.inlineCallbacks
|
115 |
+
def run_test_session():
|
116 |
+
test_agent = TestAgent()
|
117 |
+
|
118 |
+
logger.info("Starting few-shot learning phase...")
|
119 |
+
for example in test_agent.few_shot_examples:
|
120 |
+
logger.info(f"Processing few-shot example: {example['query']}")
|
121 |
+
try:
|
122 |
+
yield test_agent.process_query(example['query'], is_few_shot=True)
|
123 |
+
except Exception as e:
|
124 |
+
logger.error(f"Error in few-shot learning: {str(e)}", exc_info=True)
|
125 |
+
|
126 |
+
logger.info("Few-shot learning phase completed. Starting interactive session.")
|
127 |
+
|
128 |
+
while True:
|
129 |
+
query = yield threads.deferToThread(get_user_input)
|
130 |
+
|
131 |
+
if query.lower() == 'quit':
|
132 |
+
break
|
133 |
+
|
134 |
+
try:
|
135 |
+
answer = yield test_agent.process_query(query)
|
136 |
+
print("\nAgent's response:")
|
137 |
+
print(answer)
|
138 |
+
print("\n" + "-"*50 + "\n")
|
139 |
+
except Exception as e:
|
140 |
+
logger.error(f"Error in interactive session: {str(e)}", exc_info=True)
|
141 |
+
|
142 |
+
test_agent.save_models()
|
143 |
+
reactor.stop()
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
reactor.callWhenRunning(run_test_session)
|
148 |
+
reactor.run()
|
train_agent.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# train_agent.py
|
2 |
+
|
3 |
+
from twisted.internet import reactor, defer, task
|
4 |
+
from agent import AutonomousWebAgent
|
5 |
+
import random
|
6 |
+
import logging
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
import codecs
|
10 |
+
|
11 |
+
# Configure logging
|
12 |
+
logging.basicConfig(level=logging.INFO,
|
13 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
14 |
+
handlers=[
|
15 |
+
logging.FileHandler("agent_training.log", encoding='utf-8'),
|
16 |
+
logging.StreamHandler(codecs.getwriter('utf-8')(sys.stdout.buffer))
|
17 |
+
])
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
# List of diverse queries
|
22 |
+
QUERIES = [
|
23 |
+
"machine learning", "climate change", "renewable energy", "artificial intelligence",
|
24 |
+
"quantum computing", "blockchain technology", "gene editing", "virtual reality",
|
25 |
+
"space exploration", "cybersecurity", "autonomous vehicles", "Internet of Things",
|
26 |
+
"3D printing", "nanotechnology", "bioinformatics", "augmented reality", "robotics",
|
27 |
+
"data science", "neural networks", "cloud computing", "edge computing", "5G technology",
|
28 |
+
"cryptocurrency", "natural language processing", "computer vision"
|
29 |
+
]
|
30 |
+
|
31 |
+
@defer.inlineCallbacks
|
32 |
+
def train_agent():
|
33 |
+
# Updated state_size to 7 to match the feature extraction in AutonomousWebAgent
|
34 |
+
state_size = 7 # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
|
35 |
+
action_size = 3 # 0: Click Link, 1: Summarize, 2: RAG Generate
|
36 |
+
num_options = 3 # 0: Search, 1: Summarize, 2: RAG Generate
|
37 |
+
|
38 |
+
# Initialize the AutonomousWebAgent with the required arguments
|
39 |
+
agent = AutonomousWebAgent(
|
40 |
+
state_size=state_size,
|
41 |
+
action_size=action_size,
|
42 |
+
num_options=num_options, # Added parameter for HRL
|
43 |
+
hidden_size=64,
|
44 |
+
learning_rate=0.001,
|
45 |
+
gamma=0.99,
|
46 |
+
epsilon=1.0,
|
47 |
+
epsilon_decay=0.995,
|
48 |
+
epsilon_min=0.01,
|
49 |
+
knowledge_base_path='knowledge_base.json'
|
50 |
+
)
|
51 |
+
logger.info(f"Initialized AutonomousWebAgent with state_size={state_size}, action_size={action_size}, num_options={num_options}")
|
52 |
+
|
53 |
+
num_episodes = 10 # Adjust as needed
|
54 |
+
total_training_reward = 0
|
55 |
+
start_time = time.time()
|
56 |
+
|
57 |
+
for episode in range(num_episodes):
|
58 |
+
query = random.choice(QUERIES)
|
59 |
+
logger.info(f"Starting episode {episode + 1}/{num_episodes} with query: {query}")
|
60 |
+
episode_start_time = time.time()
|
61 |
+
|
62 |
+
try:
|
63 |
+
# Initiate the search process
|
64 |
+
search_deferred = agent.search(query)
|
65 |
+
search_deferred.addTimeout(300, reactor) # 5-minute timeout
|
66 |
+
total_reward = yield search_deferred
|
67 |
+
total_training_reward += total_reward
|
68 |
+
episode_duration = time.time() - episode_start_time
|
69 |
+
logger.info(f"Episode {episode + 1}/{num_episodes}, Query: {query}, Total Reward: {total_reward}, Duration: {episode_duration:.2f} seconds")
|
70 |
+
except defer.TimeoutError:
|
71 |
+
logger.error(f"Episode {episode + 1} timed out")
|
72 |
+
total_reward = -1 # Assign a negative reward for timeout
|
73 |
+
total_training_reward += total_reward
|
74 |
+
except Exception as e:
|
75 |
+
logger.error(f"Error in episode {episode + 1}: {str(e)}", exc_info=True)
|
76 |
+
total_reward = -1 # Assign a negative reward for errors
|
77 |
+
total_training_reward += total_reward
|
78 |
+
|
79 |
+
# Update target models periodically
|
80 |
+
if (episode + 1) % 10 == 0:
|
81 |
+
logger.info(f"Updating target models at episode {episode + 1}")
|
82 |
+
agent.update_worker_target_model()
|
83 |
+
agent.update_manager_target_model()
|
84 |
+
agent.manager.update_target_model()
|
85 |
+
|
86 |
+
# Log overall progress
|
87 |
+
progress = (episode + 1) / num_episodes
|
88 |
+
elapsed_time = time.time() - start_time
|
89 |
+
estimated_total_time = elapsed_time / progress if progress > 0 else 0
|
90 |
+
remaining_time = estimated_total_time - elapsed_time
|
91 |
+
logger.info(f"Overall progress: {progress:.2%}, Elapsed time: {elapsed_time:.2f}s, Estimated remaining time: {remaining_time:.2f}s")
|
92 |
+
|
93 |
+
total_training_time = time.time() - start_time
|
94 |
+
average_reward = total_training_reward / num_episodes
|
95 |
+
logger.info(f"Training completed. Total reward: {total_training_reward}, Average reward per episode: {average_reward:.2f}")
|
96 |
+
logger.info(f"Total training time: {total_training_time:.2f} seconds")
|
97 |
+
logger.info("Saving models.")
|
98 |
+
|
99 |
+
# Save both Worker and Manager models
|
100 |
+
agent.save_worker_model("worker_model.pth")
|
101 |
+
agent.save_manager_model("manager_model.pth")
|
102 |
+
agent.save("web_agent_model.pth") # Assuming this saves additional components if needed
|
103 |
+
|
104 |
+
if reactor.running:
|
105 |
+
logger.info("Stopping reactor")
|
106 |
+
reactor.stop()
|
107 |
+
|
108 |
+
def main():
|
109 |
+
logger.info("Starting agent training")
|
110 |
+
d = task.deferLater(reactor, 0, train_agent)
|
111 |
+
d.addErrback(lambda failure: logger.error(f"An error occurred: {failure}", exc_info=True))
|
112 |
+
d.addBoth(lambda _: reactor.stop())
|
113 |
+
reactor.run()
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
main()
|