sharmabhi committed
Commit 0829957 · 1 Parent(s): 1318f51

Upload 2 files

Files changed (2)
  1. dqn.py +93 -0
  2. mdp.py +86 -0
dqn.py ADDED
@@ -0,0 +1,93 @@
+ import numpy as np
+ import pandas as pd
+ import random
+ from sklearn.utils import shuffle
+ import torch
+ import torch.nn as nn
+ import torch.autograd as autograd
+ from torchcontrib.optim import SWA
+ from collections import deque
+
+ from preprocess import *
+
+ class DQN(nn.Module):
+
+     def __init__(self, input_dim, output_dim):
+         super(DQN, self).__init__()
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.fc = nn.Sequential(
+             nn.Linear(self.input_dim[0], 32),
+             nn.ReLU(),
+             nn.Linear(32, self.output_dim))
+
+     def forward(self, state):
+         return self.fc(state)
+
+ class DQNAgent:
+
+     def __init__(self, input_dim, dataset,
+                  learning_rate=3e-4,
+                  gamma=0.99,
+                  buffer=None,
+                  buffer_size=10000,
+                  tau=0.999,
+                  swa=False,
+                  pre_trained_model=None):
+         self.learning_rate = learning_rate
+         self.gamma = gamma
+         self.tau = tau
+         self.model = DQN(input_dim, 1)
+         if pre_trained_model:
+             self.model = pre_trained_model
+         base_opt = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
+         self.swa = swa
+         self.dataset = dataset
+         self.MSE_loss = nn.MSELoss()
+         self.replay_buffer = buffer
+         if swa:
+             # Stochastic Weight Averaging wrapper around the base Adam optimizer
+             self.optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
+         else:
+             self.optimizer = base_opt
+
+     def get_action(self, state, dataset=None):
+         # Score every remaining document for this state and return the one
+         # with the highest predicted return.
+         if dataset is None:
+             dataset = self.dataset
+         inputs = get_multiple_model_inputs(state, state.remaining, dataset)
+         model_inputs = autograd.Variable(torch.from_numpy(inputs).float().unsqueeze(0))
+         expected_returns = self.model.forward(model_inputs)
+         value, index = expected_returns.max(1)
+         return state.remaining[index[0]]
+
+     def compute_loss(self, batch, dataset, verbose=False):
+         # One-step Q-learning target: r + gamma * max_a' Q(s', a') for non-terminal states.
+         states, actions, rewards, next_states, dones = batch
+         model_inputs = np.array([get_model_inputs(states[i], actions[i], dataset)
+                                  for i in range(len(states))])
+         model_inputs = torch.FloatTensor(model_inputs)
+
+         rewards = torch.FloatTensor(rewards)
+         dones = torch.FloatTensor(dones)
+
+         curr_Q = self.model.forward(model_inputs)
+         model_inputs = np.array([get_model_inputs(next_states[i], actions[i], dataset)
+                                  for i in range(len(next_states))])
+         model_inputs = torch.FloatTensor(model_inputs)
+         next_Q = self.model.forward(model_inputs)
+         max_next_Q = torch.max(next_Q, 1)[0]
+         expected_Q = rewards.squeeze(1) + (1 - dones) * self.gamma * max_next_Q
+
+         if verbose:
+             print(curr_Q, expected_Q)
+         # curr_Q is (batch, 1); squeeze the feature dimension so both sides of the loss are (batch,)
+         loss = self.MSE_loss(curr_Q.squeeze(1), expected_Q.detach())
+         return loss
+
+     def update(self, batch_size, verbose=False):
+         batch = self.replay_buffer.sample(batch_size)
+         loss = self.compute_loss(batch, self.dataset, verbose)
+         train_loss = loss.float()
+         self.optimizer.zero_grad()
+         loss.backward()
+         self.optimizer.step()
+         if self.swa:
+             self.optimizer.swap_swa_sgd()
+         return train_loss
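
A minimal usage sketch, not part of this commit: it assumes the BasicBuffer and push_batch from mdp.py below, a DataFrame df with qid, doc_id, and rank columns, and a known feature dimension feature_dim for the inputs built by the preprocess helpers, which are not included in this upload.

# Hypothetical wiring; `df` and `feature_dim` are assumptions, not part of this commit.
from mdp import BasicBuffer
from dqn import DQNAgent

buffer = BasicBuffer(max_size=10000)
buffer.push_batch(df, n=100)            # fill the buffer with shuffled ranking episodes

agent = DQNAgent(input_dim=(feature_dim,), dataset=df, buffer=buffer)
for step in range(1000):
    loss = agent.update(batch_size=32)  # sample a batch and take one gradient step
    if step % 100 == 0:
        print(step, float(loss))
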
mdp.py ADDED
@@ -0,0 +1,86 @@
+ # State and Buffer Classes
+
+ import numpy as np
+ import pandas as pd
+ import random
+ from sklearn.utils import shuffle
+ import torch
+ import torch.nn as nn
+ import torch.autograd as autograd
+ from torchcontrib.optim import SWA
+ from collections import deque
+
+ from preprocess import *
+
+ def compute_reward(t, relevance):
+     """
+     Reward function for the MDP: the position-discounted gain
+     relevance / log2(t + 1), i.e. the per-position DCG term.
+     """
+     if t == 0:
+         return 0
+     return relevance / np.log2(t + 1)
+
+ class State:
+
+     def __init__(self, t, query, remaining):
+         self.t = t
+         self.qid = query  # useful for sorting the buffer
+         self.remaining = remaining
+
+     def pop(self):
+         return self.remaining.pop()
+
+     def initial(self):
+         return self.t == 0
+
+     def terminal(self):
+         return len(self.remaining) == 0
+
+ class BasicBuffer:
+
+     def __init__(self, max_size):
+         self.max_size = max_size
+         self.buffer = deque(maxlen=max_size)
+
+     def push(self, state, action, reward, next_state, done):
+         experience = (state, action, np.array([reward]), next_state, done)
+         self.buffer.append(experience)
+
+     def push_batch(self, df, n):
+         # Sample n random queries and push one shuffled ranking episode per query.
+         for i in range(n):
+             random_qid = random.choice(list(df["qid"]))
+             filtered_df = df.loc[df["qid"] == int(random_qid)].reset_index()
+             row_order = [x for x in range(len(filtered_df))]
+             X = [x[1]["doc_id"] for x in filtered_df.iterrows()]
+             random.shuffle(row_order)
+             for t, r in enumerate(row_order):
+                 cur_row = filtered_df.iloc[r]
+                 old_state = State(t, cur_row["qid"], X[:])
+                 action = cur_row["doc_id"]
+                 new_state = State(t + 1, cur_row["qid"], X[:])
+                 reward = compute_reward(t + 1, cur_row["rank"])
+                 self.push(old_state, action, reward, new_state, t + 1 == len(row_order))
+
+     def sample(self, batch_size):
+         state_batch = []
+         action_batch = []
+         reward_batch = []
+         next_state_batch = []
+         done_batch = []
+
+         batch = random.sample(self.buffer, batch_size)
+
+         for experience in batch:
+             state, action, reward, next_state, done = experience
+             state_batch.append(state)
+             action_batch.append(action)
+             reward_batch.append(reward)
+             next_state_batch.append(next_state)
+             done_batch.append(done)
+
+         return (state_batch, action_batch, reward_batch,
+                 next_state_batch, done_batch)
+
+     def __len__(self):
+         return len(self.buffer)
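
compute_reward above is the per-position DCG gain, relevance / log2(t + 1), so summing it over a full ranking gives that ranking's DCG. A small sketch with made-up relevance labels, purely for illustration:

from mdp import compute_reward

relevances = [3, 2, 0, 1]   # hypothetical relevance labels for a 4-document ranking
dcg = sum(compute_reward(t, rel) for t, rel in enumerate(relevances, start=1))
print(dcg)                  # 3/log2(2) + 2/log2(3) + 0/log2(4) + 1/log2(5) ≈ 4.69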