kevinwang676 commited on
Commit
d831168
·
verified ·
1 Parent(s): c2e4b4e

Create mcqa_bert.py

Browse files
Files changed (1) hide show
  1. mcqa_bert.py +26 -0
mcqa_bert.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mcqa_bert.py
2
+ # --------------------------------------------------
3
+ # Plain BertModel + single‑unit classification head
4
+ # --------------------------------------------------
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import BertModel
8
+
9
+
10
+ class MCQABERT(nn.Module):
11
+ def __init__(self, ckpt: str = "bert-base-uncased"):
12
+ super().__init__()
13
+ self.encoder = BertModel.from_pretrained(ckpt)
14
+ self.head = nn.Linear(self.encoder.config.hidden_size, 1)
15
+
16
+ # --------------------------------------------------
17
+
18
+ def forward(self, input_ids, attention_mask):
19
+ out = self.encoder(
20
+ input_ids=input_ids,
21
+ attention_mask=attention_mask,
22
+ return_dict=True,
23
+ )
24
+ cls_vec = out.last_hidden_state[:, 0] # [CLS]
25
+ logits = self.head(cls_vec).squeeze(-1) # (B)
26
+ return logits