import torch
import torch.nn as nn
import torch.nn.functional as F
def split_heads(x, num_heads):
    """ Split heads
    :param x: A tensor with shape [batch, length, channels]
    :param num_heads: An integer
    :returns: A tensor with shape [batch, heads, length, channels / heads]
    """
    assert x.shape[-1] % num_heads == 0, str(x.shape)
    return x.reshape(x.shape[:-1] + (num_heads, x.shape[-1] // num_heads)).permute(0, 2, 1, 3)
def combine_heads(x):
    """ Combine heads
    :param x: A tensor with shape [batch, heads, length, channels]
    :returns: A tensor with shape [batch, length, heads * channels]
    """
    x = x.permute(0, 2, 1, 3)
    return x.reshape(x.shape[:-2] + (x.shape[-1] * x.shape[-2],))
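
# Illustrative shape check (an addition, not part of the original module;
# the helper name is hypothetical): split_heads followed by combine_heads
# should round-trip a [batch, length, channels] tensor exactly.
def _demo_split_combine():
    x = torch.randn(2, 5, 8)           # [batch=2, length=5, channels=8]
    h = split_heads(x, num_heads=4)    # [2, 4, 5, 2]
    y = combine_heads(h)               # back to [2, 5, 8]
    assert h.shape == (2, 4, 5, 2) and torch.equal(x, y)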
class SimpleAttention(nn.Module):
    def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
        super(SimpleAttention, self).__init__()
        # Project queries, keys and values to a common query_size dimension,
        # so the dot products in forward() are well defined.
        self.q_transform = nn.Linear(query_size, query_size, bias=False)
        self.k_transform = nn.Linear(key_size, query_size, bias=False)
        self.v_transform = nn.Linear(value_size, query_size, bias=False)
        self.output_transform = nn.Linear(query_size, query_size, bias=False)
        self.query_size = query_size
        self.key_size = key_size
        self.value_size = value_size
        self.num_heads = num_heads
    def forward(self, query, key, value, attn_mask=None, bias=None):
        q = self.q_transform(query)
        k = self.k_transform(key)
        v = self.v_transform(value)
        # Split into heads so each head attends independently; the original
        # version left num_heads unused and computed single-head attention.
        # Shapes become [batch, heads, length, query_size / heads].
        q = split_heads(q, self.num_heads)
        k = split_heads(k, self.num_heads)
        v = split_heads(v, self.num_heads)
        # Dot-product logits (no 1/sqrt(d) scaling is applied here):
        # [batch, heads, length_q, length_k]
        logits = torch.matmul(q, k.transpose(-2, -1))
        if bias is not None:
            logits = logits + bias
        if attn_mask is not None:
            # attn_mask holds 1 at positions to mask out; it must broadcast
            # against logits, e.g. shape [batch, 1, length_q, length_k].
            logits = logits + attn_mask * -1e9
        weights = F.softmax(logits, dim=-1)
        out = torch.matmul(weights, v)
        # Merge heads back to [batch, length_q, query_size].
        out = combine_heads(out)
        out = self.output_transform(out)
        return out, weights
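
# Minimal usage sketch (illustrative; the batch size, sequence length and
# head count below are assumptions, not part of the original module).
# Runs self-attention over a batch of random sequences.
if __name__ == "__main__":
    attn = SimpleAttention(query_size=192, key_size=192, value_size=192, num_heads=4)
    x = torch.randn(2, 10, 192)    # [batch, length, channels]
    out, weights = attn(x, x, x)   # self-attention: query = key = value
    print(out.shape)               # torch.Size([2, 10, 192])
    print(weights.shape)           # torch.Size([2, 4, 10, 10])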