相对位置编码
原论文:
P. Shaw, J. Uszkoreit, and A. Vaswani, ‘Self-Attention with Relative Position Representations’, in Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers), New Orleans, Louisiana, 2018, pp. 464–468. doi: 10.18653/v1/N18-2074.
# 没有位置编码的Attention
设dz表示输入和输出特征的维度,i∈[1,n],n表示输入样本数量。输入向量组成矩阵X∈Rn×dz。
那么输出矩阵Z∈Rn×dz计算过程为:
Z=softmax(dz(XWQ)(XWK)T)(XWV)
其中WQ∈Rdmodel×dz、WK∈Rdmodel×dz、WV∈Rdmodel×dz
为了解释相对位置编码,将计算过程展开,输入向量xi∈Rdz:
Z=softmax⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛dz⎣⎢⎢⎢⎡x1⋮xn⎦⎥⎥⎥⎤⋅WQ⋅(⎣⎢⎢⎢⎡x1⋮xn⎦⎥⎥⎥⎤⋅WK)T⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞⋅⎣⎢⎢⎢⎡x1⋮xn⎦⎥⎥⎥⎤⋅WV=softmax⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛dz⎣⎢⎢⎢⎡x1WQ⋮xnWQ⎦⎥⎥⎥⎤⎣⎢⎢⎢⎡x1WK⋮xnWK⎦⎥⎥⎥⎤T⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞⎣⎢⎢⎢⎡x1WV⋮xnWV⎦⎥⎥⎥⎤=softmax⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛dz(⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡x1WQ⋮xiWQ⋮xnWQ⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤)[(x1WK)T,…,(xjWK)T…,(xnWK)T]⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡x1WV⋮xjWV⋮xnWV⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤=softmax(dz{xiWQ(xjWK)T}i,j∈[1,n])⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡x1WV⋮xjWV⋮xnWV⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤={softmax(dzxiWQ(xjWK)T)}i,j∈[1,n]⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡x1WV⋮xjWV⋮xnWV⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤=[j=1∑nsoftmax(dzxiWQ(xjWK)T)xjWV]i∈[1,n]=⎣⎢⎢⎢⎡z1⋮zn⎦⎥⎥⎥⎤
所以:
zi=j=1∑nsoftmax(dzxiWQ(xjWK)T)xjWV
# 加上相对位置编码的Attention
zi=j=1∑nsoftmax(dzxiWQ(xjWK+aijK)T)(xjWV+aijV)
其中,aijK和aijV就是序列中的元素i相对于元素j的相对位置编码。进一步,这两项位置编码值都取自长为2k+1的位置编码集:
wKwV=(w−kK,…,wkK)=(w−kV,…,wkV)
其取法如下:
aijKaijVclip(x,k)=wclipK(j−i,k)=wclipV(j−i,k)=max(−k,min(k,x))
其实就是根据ij差值去wK和wV里取值,并且设定差值最大为k,超过k的位置编码为固定值。
也挺好理解的。
# 相对位置编码的矩阵形式
Z=⎣⎢⎢⎢⎡z1⋮zn⎦⎥⎥⎥⎤=[j=1∑nsoftmax(dzxiWQ(xjWK+aijK)T)(xjWV+aijV)]i∈[1,n]=[j=1∑nsoftmax(dzxiWQ(xjWK)T+xiWQ(aijK)T)(xjWV+aijV)]i∈[1,n]
这么一看aijK和aijV都是向量,那由aijK和aijV组成的矩阵就是个Rn×n×dz的三维矩阵了,没法表示啊😂,还是算了吧
# 在代码中的体现
来一段图像RSTT的Python代码。RSTT是一种能同时进行插帧和超分辨率的Transformer:
class WindowAttention3D(nn.Module):
"""Window based multi-head self/cross attention (W-MSA/W-MCA) module with relative
position bias.
It supports both of shifted and non-shifted window.
"""
def __init__(self, dim, num_frames_q, num_frames_kv, window_size, num_heads,
qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
"""Initialization function.
Args:
dim (int): Number of input channels.
num_frames (int): Number of input frames.
window_size (tuple[int]): The size of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults to True.
qk_scale (float, optional): Override default qk scale of head_dim ** -0.5 if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight. Defaults to 0.0
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.0
"""
super().__init__()
self.dim = dim
self.num_frames_q = num_frames_q # D1
self.num_frames_kv = num_frames_kv # D2
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads # nH
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# 这个relative_position_bias_table一看就是位置编码
# 名字取得叫位置编码不说,还有很明显的2n-1计算,很显然是w
self.relative_position_bias_table = nn.Parameter( # TODO
torch.zeros((2 * num_frames_q - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*D-1 * 2*Wh-1 * 2*Ww-1, nH
# 这个Attention层中的相对位置编码是按窗口大小定的
# 仔细看看这里的2n-1计算是对于num_frames_q和window_size进行的,并且还把它们乘起来了。
# 那么可以推断这个位置编码包含多个维度:特征在不同帧上的相对位置、特征在图片上的相对位置(长宽两个维度)
# Get pair-wise relative position index for each token inside the window
coords_d_q = torch.arange(self.num_frames_q) # 从1数到Q的帧数
coords_d_kv = torch.arange(0, self.num_frames_q, int((self.num_frames_q + 1) // self.num_frames_kv)) # 从1数到KV的帧数
# 注意:在RSTT中,num_frames_q等于输出帧的数量,即输入样本进行插帧后的帧数;
# 而num_frames_kv来自于未插帧的原始数据,所以这里的coords_d_kv从1数到KV的帧数是跳着数的
coords_h = torch.arange(self.window_size[0]) # 从1数到窗口长度
coords_w = torch.arange(self.window_size[1]) # 从1数到窗口宽度
# 接下来meshgrid把上面的这几个数数的数组组成坐标表
coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h, coords_w])) # 3, D1, Wh, Ww
coords_kv = torch.stack(torch.meshgrid([coords_d_kv, coords_h, coords_w])) # 3, D2, Wh, Ww
# 然后拉平
coords_q_flatten = torch.flatten(coords_q, 1) # 3, D1*Wh*Ww
coords_kv_flatten = torch.flatten(coords_kv, 1) # 3, D2*Wh*Ww
# 接下来就是计算相对量,就是在w里取数时用的k值
relative_coords = coords_q_flatten[:, :, None] - coords_kv_flatten[:, None, :] # 3, D1*Wh*Ww, D2*Wh*Ww
# 这个None似乎可以让矩阵在运算时扩展,于是让(3, D1*Wh*Ww, None)与(3, None, D2*Wh*Ww)的计算结果变成(3, D1*Wh*Ww, D2*Wh*Ww)
# 于是现在relative_coords就是输入中每个值与其他所有值的index之差了
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # D1*Wh*Ww, D2*Wh*Ww, 3
# 现在relative_coords里面有负值,然后下面这一步就是让所有的index都从0开始
relative_coords[:, :, 0] += self.num_frames_q - 1 # Q矩阵中所有元素与KV矩阵中所有元素在“帧”维度上的距离
relative_coords[:, :, 1] += self.window_size[0] - 1 # Q矩阵中所有元素与KV矩阵中所有元素在“高”维度上的距离
relative_coords[:, :, 2] += self.window_size[1] - 1 # Q矩阵中所有元素与KV矩阵中所有元素在“宽”维度上的距离
# relative_coords[x, y, :]即表示KV矩阵上的元素x与Q矩阵上的元素y的“帧,高,宽”距离表示
# 现在relative_coords里面就是每个维度上输入矩阵元素的距离,然后接下来就是计算累计距离
relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) # 看来每一帧是最大的距离单位
relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 # window_size[1]是第二大的距离单位
# 没有relative_coords[:, :, 2]?因为self.window_size[0]是最小的距离单位啦
# 求和,得到最终的k值
relative_position_index = relative_coords.sum(-1) # D1*Wh*Ww, D2*Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
# 相对位置编码初始化完成
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, q, kv=None, mask=None):
"""Forward function.
Args:
q (torch.Tensor): (B*nW, D1*Wh*Ww, C)
kv (torch.Tensor): (B*nW, D2*Wh*Ww, C). Defaults to None.
mask (torch.Tensor, optional): Mask for shifted window attention (nW, D1*Wh*Ww, D2*Wh*Ww). Defaults to None.
Returns:
torch.Tensor: (B*nW, D1*Wh*Ww, C)
"""
kv = q if kv is None else kv
B_, N1, C = q.shape # N1 = D1*Wh*Ww, B_ = B*nW
B_, N2, C = kv.shape # N2 = D2*Wh*Ww, B_ = B*nW
q = self.q(q).reshape(B_, N1, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
kv = self.kv(kv).reshape(B_, N2, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = q[0], kv[0], kv[1] # B_, nH, N1(2), C
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # B_, nH, N1, N2
# 相对位置编码用法就非常简单了,直接按k值进w里取数然后与attantion相加就完了
# 这一步是取数
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
N1, N2, -1) # D1*Wh*Ww, D2*Wh*Ww, nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, D1*Wh*Ww, D2*Wh*Ww
# 这一步是相加
attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, D1*Wh*Ww, D2*Wh*Ww
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N1, N2) + mask.unsqueeze(1).unsqueeze(0) # B, nW, nH, D1*Wh*Ww, D2*Wh*Ww
attn = attn.view(-1, self.num_heads, N1, N2)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N1, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# 位置编号矩阵如何生成
输入:
U=QKV={qdQ,hQ,wQ},dQ∈[1,DQ],hQ∈[1,HQ],wQ∈[1,WQ]={udU,hU,wU},dU∈[1,DU],hU∈[1,HU],wU∈[1,WU]
其中D为帧数、H为图片长度、W为图片宽度
位置编号矩阵:
i=j=Ii,j={Ii,j}dQHQWQ+hQWQ+wQdUHUWU+hUWU+wU(⎣⎢⎢⎡dUhUwU⎦⎥⎥⎤−⎣⎢⎢⎡dQhQwQ⎦⎥⎥⎤+⎣⎢⎢⎡DQHQWQ⎦⎥⎥⎤)T⎣⎢⎢⎡(HU+HQ)(WU+WQ)WU+WQ1⎦⎥⎥⎤
其中,⎣⎢⎢⎡dUhUwU⎦⎥⎥⎤−⎣⎢⎢⎡dQhQwQ⎦⎥⎥⎤即表示了相对位置,⎣⎢⎢⎡DQHQWQ⎦⎥⎥⎤是为了让相对位置值从0开始,而⎣⎢⎢⎡(HU+HQ)(WU+WQ)WU+WQ1⎦⎥⎥⎤则把三元组形式的相对位置编码计算为一个标量值。
于是,Ii,j就把输入矩阵Q中的每个元素和U中的每个元素计算出了一个相对位置值。很显然,这个位置值最大为(DU+DQ)(HU+HQ)(WU+WQ),可以给矩阵Q和U中的任意两个元素都配上一个值。之后只需要根据Ii,j从一个长度为(DU+DQ)(HU+HQ)(WU+WQ)的列表中找相对位置编码值即可。