Yin的笔记本

vuePress-theme-reco Howard Yin    2021 - 2025
Yin的笔记本 Yin的笔记本

Choose mode

  • dark
  • auto
  • light
Home
Category
  • CNCF
  • Docker
  • namespaces
  • Kubernetes
  • Kubernetes对象
  • Linux
  • MyIdeas
  • Revolution
  • WebRTC
  • 云计算
  • 人工智能
  • 分布式
  • 图像处理
  • 图形学
  • 微服务
  • 数学
  • OJ笔记
  • 博弈论
  • 形式语言与自动机
  • 数据库
  • 服务器运维
  • 编程语言
  • C
  • Git
  • Go
  • Java
  • JavaScript
  • Python
  • Nvidia
  • Shell
  • Tex
  • Rust
  • Vue
  • 视频编解码
  • 计算机网络
  • SDN
  • 论文笔记
  • 讨论
  • 边缘计算
  • 量子信息技术
Tag
TimeLine
About
查看源码
author-avatar

Howard Yin

304

Article

153

Tag

Home
Category
  • CNCF
  • Docker
  • namespaces
  • Kubernetes
  • Kubernetes对象
  • Linux
  • MyIdeas
  • Revolution
  • WebRTC
  • 云计算
  • 人工智能
  • 分布式
  • 图像处理
  • 图形学
  • 微服务
  • 数学
  • OJ笔记
  • 博弈论
  • 形式语言与自动机
  • 数据库
  • 服务器运维
  • 编程语言
  • C
  • Git
  • Go
  • Java
  • JavaScript
  • Python
  • Nvidia
  • Shell
  • Tex
  • Rust
  • Vue
  • 视频编解码
  • 计算机网络
  • SDN
  • 论文笔记
  • 讨论
  • 边缘计算
  • 量子信息技术
Tag
TimeLine
About
查看源码
  • 相对位置编码

    • 没有位置编码的Attention
      • 加上相对位置编码的Attention
        • 相对位置编码的矩阵形式
          • 在代码中的体现
            • 位置编号矩阵如何生成

            相对位置编码

            vuePress-theme-reco Howard Yin    2021 - 2025

            相对位置编码


            Howard Yin 2022-07-26 09:43:01 人工智能位置编码

            原论文:

            P. Shaw, J. Uszkoreit, and A. Vaswani, ‘Self-Attention with Relative Position Representations’, in Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 2 (Short Papers), New Orleans, Louisiana, 2018, pp. 464–468. doi: 10.18653/v1/N18-2074.

            # 没有位置编码的Attention

            设dzd_zdz​表示输入和输出特征的维度,i∈[1,n]i\in[1,n]i∈[1,n],nnn表示输入样本数量。输入向量组成矩阵X∈Rn×dzX\in\mathbb R^{n\times d_z}X∈Rn×dz​。

            那么输出矩阵Z∈Rn×dzZ\in\mathbb R^{n\times d_z}Z∈Rn×dz​计算过程为:

            Z=softmax((XWQ)(XWK)Tdz)(XWV)Z=softmax(\frac{(X W^Q)(X W^K)^T}{\sqrt{d_z}})(X W^V) Z=softmax(dz​​(XWQ)(XWK)T​)(XWV)

            其中WQ∈Rdmodel×dzW^Q\in\mathbb R^{d_{model}\times d_z}WQ∈Rdmodel​×dz​、WK∈Rdmodel×dzW^K\in\mathbb R^{d_{model}\times d_z}WK∈Rdmodel​×dz​、WV∈Rdmodel×dzW^V\in\mathbb R^{d_{model}\times d_z}WV∈Rdmodel​×dz​

            为了解释相对位置编码,将计算过程展开,输入向量xi∈Rdz\bm x_i\in\mathbb R^{d_z}xi​∈Rdz​:

            Z=softmax([x1⋮xn]⋅WQ⋅([x1⋮xn]⋅WK)Tdz)⋅[x1⋮xn]⋅WV=softmax([x1WQ⋮xnWQ][x1WK⋮xnWK]Tdz)[x1WV⋮xnWV]=softmax(([x1WQ⋮xiWQ⋮xnWQ])[(x1WK)T,…,(xjWK)T…,(xnWK)T]dz)[x1WV⋮xjWV⋮xnWV]=softmax({xiWQ(xjWK)T}i,j∈[1,n]dz)[x1WV⋮xjWV⋮xnWV]={softmax(xiWQ(xjWK)Tdz)}i,j∈[1,n][x1WV⋮xjWV⋮xnWV]=[∑j=1nsoftmax(xiWQ(xjWK)Tdz)xjWV]i∈[1,n]=[z1⋮zn]\begin{aligned} Z &=softmax\left(\frac{\left[\begin{gathered}x_1\\\vdots\\x_n\end{gathered}\right]\cdot W^Q\cdot (\left[\begin{gathered}x_1\\\vdots\\x_n\end{gathered}\right]\cdot W^K)^T}{\sqrt{d_z}}\right)\cdot \left[\begin{gathered}x_1\\\vdots\\x_n\end{gathered}\right]\cdot W^V\\ &=softmax\left(\frac{\left[\begin{gathered}x_1 W^Q\\\vdots\\x_n W^Q\end{gathered}\right]\left[\begin{gathered}x_1 W^K\\\vdots\\x_n W^K\end{gathered}\right]^T}{\sqrt{d_z}}\right)\left[\begin{gathered}x_1 W^V\\\vdots\\x_n W^V\end{gathered}\right]\\ &=softmax\left(\frac{(\left[\begin{gathered}x_1 W^Q\\\vdots\\x_iW^Q\\\vdots\\x_n W^Q\end{gathered}\right])\left[\begin{gathered}(x_1 W^K)^T,\dots,(x_jW^K)^T\dots,(x_n W^K)^T\end{gathered}\right]}{\sqrt{d_z}}\right)\left[\begin{gathered}x_1 W^V\\\vdots\\x_jW^V\\\vdots\\x_n W^V\end{gathered}\right]\\ &=softmax\left(\frac{\left\{x_i W^Q(x_j W^K)^T\right\}_{i,j\in[1,n]}}{\sqrt{d_z}}\right)\left[\begin{gathered}x_1 W^V\\\vdots\\x_jW^V\\\vdots\\x_n W^V\end{gathered}\right]\\ &=\left\{softmax\left(\frac{x_i W^Q(x_j W^K)^T}{\sqrt{d_z}}\right)\right\}_{i,j\in[1,n]}\left[\begin{gathered}x_1 W^V\\\vdots\\x_jW^V\\\vdots\\x_n W^V\end{gathered}\right]\\ &=\left[\sum_{j=1}^nsoftmax\left(\frac{x_i W^Q(x_j W^K)^T}{\sqrt{d_z}}\right)x_j W^V\right]_{i\in[1,n]}\\ &=\left[\begin{gathered}z_1\\\vdots\\z_n\end{gathered}\right] \end{aligned}Z​=softmax⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛​dz​​⎣⎢⎢⎢⎡​x1​⋮xn​​⎦⎥⎥⎥⎤​⋅WQ⋅(⎣⎢⎢⎢⎡​x1​⋮xn​​⎦⎥⎥⎥⎤​⋅WK)T​⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞​⋅⎣⎢⎢⎢⎡​x1​⋮xn​​⎦⎥⎥⎥⎤​⋅WV=softmax⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛​dz​​⎣⎢⎢⎢⎡​x1​WQ⋮xn​WQ​⎦⎥⎥⎥⎤​⎣⎢⎢⎢⎡​x1​WK⋮xn​WK​⎦⎥⎥⎥⎤​T​⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞​⎣⎢⎢⎢⎡​x1​WV⋮xn​WV​⎦⎥⎥⎥⎤​=softmax⎝⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎜⎛​dz​​(⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡​x1​WQ⋮xi​WQ⋮xn​WQ​⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤​)[(x1​WK)T,…,(xj​WK)T…,(xn​WK)T​]​⎠⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎟⎞​⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡​x1​WV⋮xj​WV⋮xn​WV​⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤​=softmax(dz​​{xi​WQ(xj​WK)T}i,j∈[1,n]​​)⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡​x1​WV⋮xj​WV⋮xn​WV​⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤​={softmax(dz​​xi​WQ(xj​WK)T​)}i,j∈[1,n]​⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡​x1​WV⋮xj​WV⋮xn​WV​⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎤​=[j=1∑n​softmax(dz​​xi​WQ(xj​WK)T​)xj​WV]i∈[1,n]​=⎣⎢⎢⎢⎡​z1​⋮zn​​⎦⎥⎥⎥⎤​​

            所以:

            zi=∑j=1nsoftmax(xiWQ(xjWK)Tdz)xjWVz_i=\sum_{j=1}^nsoftmax\left(\frac{x_i W^Q(x_j W^K)^T}{\sqrt{d_z}}\right)x_j W^V zi​=j=1∑n​softmax(dz​​xi​WQ(xj​WK)T​)xj​WV

            # 加上相对位置编码的Attention

            zi=∑j=1nsoftmax(xiWQ(xjWK+aijK)Tdz)(xjWV+aijV)z_i=\sum_{j=1}^nsoftmax\left(\frac{x_i W^Q(x_j W^K+a_{ij}^K)^T}{\sqrt{d_z}}\right)(x_j W^V+a_{ij}^V) zi​=j=1∑n​softmax(dz​​xi​WQ(xj​WK+aijK​)T​)(xj​WV+aijV​)

            其中,aijKa_{ij}^KaijK​和aijVa_{ij}^VaijV​就是序列中的元素iii相对于元素jjj的相对位置编码。进一步,这两项位置编码值都取自长为2k+12k+12k+1的位置编码集:

            wK=(w−kK,…,wkK)wV=(w−kV,…,wkV)\begin{aligned} w^K&=(w_{-k}^K,\dots,w_k^K)\\ w^V&=(w_{-k}^V,\dots,w_k^V) \end{aligned} wKwV​=(w−kK​,…,wkK​)=(w−kV​,…,wkV​)​

            其取法如下:

            aijK=wclipK(j−i,k)aijV=wclipV(j−i,k)clip(x,k)=max(−k,min(k,x))\begin{aligned} a_{ij}^K&=w_{clip}^K(j-i,k)\\ a_{ij}^V&=w_{clip}^V(j-i,k)\\ clip(x,k)&=max(-k,min(k,x)) \end{aligned} aijK​aijV​clip(x,k)​=wclipK​(j−i,k)=wclipV​(j−i,k)=max(−k,min(k,x))​

            其实就是根据ijijij差值去wKw^KwK和wVw^VwV里取值,并且设定差值最大为kkk,超过kkk的位置编码为固定值。

            也挺好理解的。

            # 相对位置编码的矩阵形式

            Z=[z1⋮zn]=[∑j=1nsoftmax(xiWQ(xjWK+aijK)Tdz)(xjWV+aijV)]i∈[1,n]=[∑j=1nsoftmax(xiWQ(xjWK)T+xiWQ(aijK)Tdz)(xjWV+aijV)]i∈[1,n]\begin{aligned} Z &=\left[\begin{gathered}z_1\\\vdots\\z_n\end{gathered}\right]\\ &=\left[\sum_{j=1}^nsoftmax\left(\frac{x_i W^Q(x_j W^K+a_{ij}^K)^T}{\sqrt{d_z}}\right)(x_j W^V+a_{ij}^V)\right]_{i\in[1,n]}\\ &=\left[\sum_{j=1}^nsoftmax\left(\frac{x_i W^Q(x_j W^K)^T+x_i W^Q(a_{ij}^K)^T}{\sqrt{d_z}}\right)(x_j W^V+a_{ij}^V)\right]_{i\in[1,n]} \end{aligned}Z​=⎣⎢⎢⎢⎡​z1​⋮zn​​⎦⎥⎥⎥⎤​=[j=1∑n​softmax(dz​​xi​WQ(xj​WK+aijK​)T​)(xj​WV+aijV​)]i∈[1,n]​=[j=1∑n​softmax(dz​​xi​WQ(xj​WK)T+xi​WQ(aijK​)T​)(xj​WV+aijV​)]i∈[1,n]​​

            这么一看aijKa_{ij}^KaijK​和aijVa_{ij}^VaijV​都是向量,那由aijKa_{ij}^KaijK​和aijVa_{ij}^VaijV​组成的矩阵就是个Rn×n×dz\mathbb R^{n\times n\times d_z}Rn×n×dz​的三维矩阵了,没法表示啊😂,还是算了吧

            # 在代码中的体现

            来一段图像RSTT的Python代码。RSTT是一种能同时进行插帧和超分辨率的Transformer:

            class WindowAttention3D(nn.Module):
                """Window based multi-head self/cross attention (W-MSA/W-MCA) module with relative 
                position bias. 
                It supports both of shifted and non-shifted window.
                """
                def __init__(self, dim, num_frames_q, num_frames_kv, window_size, num_heads,
                             qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
                    """Initialization function.
            
                    Args:
                        dim (int): Number of input channels.
                        num_frames (int): Number of input frames.
                        window_size (tuple[int]): The size of the window.
                        num_heads (int): Number of attention heads.
                        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Defaults to True.
                        qk_scale (float, optional): Override default qk scale of head_dim ** -0.5 if set. Defaults to None.
                        attn_drop (float, optional): Dropout ratio of attention weight. Defaults to 0.0
                        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.0
                    """
                    super().__init__()
                    self.dim = dim
                    self.num_frames_q = num_frames_q # D1
                    self.num_frames_kv = num_frames_kv # D2
                    self.window_size = window_size  # Wh, Ww
                    self.num_heads = num_heads # nH
                    head_dim = dim // num_heads
                    self.scale = qk_scale or head_dim ** -0.5
            
                    # 这个relative_position_bias_table一看就是位置编码
                    # 名字取得叫位置编码不说,还有很明显的2n-1计算,很显然是w
                    self.relative_position_bias_table = nn.Parameter( # TODO
                        torch.zeros((2 * num_frames_q - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
                    )  # 2*D-1 * 2*Wh-1 * 2*Ww-1, nH
                    # 这个Attention层中的相对位置编码是按窗口大小定的
                    # 仔细看看这里的2n-1计算是对于num_frames_q和window_size进行的,并且还把它们乘起来了。
                    # 那么可以推断这个位置编码包含多个维度:特征在不同帧上的相对位置、特征在图片上的相对位置(长宽两个维度)
            
                    # Get pair-wise relative position index for each token inside the window
                    coords_d_q = torch.arange(self.num_frames_q) # 从1数到Q的帧数
                    coords_d_kv = torch.arange(0, self.num_frames_q, int((self.num_frames_q + 1) // self.num_frames_kv)) # 从1数到KV的帧数
                    # 注意:在RSTT中,num_frames_q等于输出帧的数量,即输入样本进行插帧后的帧数;
                    # 而num_frames_kv来自于未插帧的原始数据,所以这里的coords_d_kv从1数到KV的帧数是跳着数的
            
                    coords_h = torch.arange(self.window_size[0]) # 从1数到窗口长度
                    coords_w = torch.arange(self.window_size[1]) # 从1数到窗口宽度
            
                    # 接下来meshgrid把上面的这几个数数的数组组成坐标表
                    coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h, coords_w]))  # 3, D1, Wh, Ww
                    coords_kv = torch.stack(torch.meshgrid([coords_d_kv, coords_h, coords_w]))  # 3, D2, Wh, Ww
            
                    # 然后拉平
                    coords_q_flatten = torch.flatten(coords_q, 1)  # 3, D1*Wh*Ww
                    coords_kv_flatten = torch.flatten(coords_kv, 1)  # 3, D2*Wh*Ww
            
                    # 接下来就是计算相对量,就是在w里取数时用的k值
                    relative_coords = coords_q_flatten[:, :, None] - coords_kv_flatten[:, None, :]  # 3, D1*Wh*Ww, D2*Wh*Ww
                    # 这个None似乎可以让矩阵在运算时扩展,于是让(3, D1*Wh*Ww, None)与(3, None, D2*Wh*Ww)的计算结果变成(3, D1*Wh*Ww, D2*Wh*Ww)
                    # 于是现在relative_coords就是输入中每个值与其他所有值的index之差了
                    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # D1*Wh*Ww, D2*Wh*Ww, 3
                    # 现在relative_coords里面有负值,然后下面这一步就是让所有的index都从0开始
                    relative_coords[:, :, 0] += self.num_frames_q - 1 # Q矩阵中所有元素与KV矩阵中所有元素在“帧”维度上的距离
                    relative_coords[:, :, 1] += self.window_size[0] - 1 # Q矩阵中所有元素与KV矩阵中所有元素在“高”维度上的距离
                    relative_coords[:, :, 2] += self.window_size[1] - 1 # Q矩阵中所有元素与KV矩阵中所有元素在“宽”维度上的距离
                    # relative_coords[x, y, :]即表示KV矩阵上的元素x与Q矩阵上的元素y的“帧,高,宽”距离表示
            
                    # 现在relative_coords里面就是每个维度上输入矩阵元素的距离,然后接下来就是计算累计距离
                    relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) # 看来每一帧是最大的距离单位
                    relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 # window_size[1]是第二大的距离单位
                    # 没有relative_coords[:, :, 2]?因为self.window_size[0]是最小的距离单位啦
            
                    # 求和,得到最终的k值
                    relative_position_index = relative_coords.sum(-1)  # D1*Wh*Ww, D2*Wh*Ww
                    self.register_buffer("relative_position_index", relative_position_index)
                    # 相对位置编码初始化完成
            
                    self.q = nn.Linear(dim, dim, bias=qkv_bias)
                    self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
                    self.attn_drop = nn.Dropout(attn_drop)
                    self.proj = nn.Linear(dim, dim)
                    self.proj_drop = nn.Dropout(proj_drop)
            
                    trunc_normal_(self.relative_position_bias_table, std=.02)
                    self.softmax = nn.Softmax(dim=-1)
            
                def forward(self, q, kv=None, mask=None):
                    """Forward function.
            
                    Args:
                        q (torch.Tensor): (B*nW, D1*Wh*Ww, C)
                        kv (torch.Tensor): (B*nW, D2*Wh*Ww, C). Defaults to None.
                        mask (torch.Tensor, optional): Mask for shifted window attention (nW, D1*Wh*Ww, D2*Wh*Ww). Defaults to None.
            
                    Returns:
                        torch.Tensor: (B*nW, D1*Wh*Ww, C)
                    """
                    kv = q if kv is None else kv
                    B_, N1, C = q.shape # N1 = D1*Wh*Ww, B_ = B*nW
                    B_, N2, C = kv.shape # N2 = D2*Wh*Ww, B_ = B*nW
            
                    q = self.q(q).reshape(B_, N1, 1, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
                    kv = self.kv(kv).reshape(B_, N2, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
                    q, k, v = q[0], kv[0], kv[1] # B_, nH, N1(2), C
                    q = q * self.scale
                    attn = (q @ k.transpose(-2, -1)) # B_, nH, N1, N2
            
                    # 相对位置编码用法就非常简单了,直接按k值进w里取数然后与attantion相加就完了
                    # 这一步是取数
                    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                        N1, N2, -1)  # D1*Wh*Ww, D2*Wh*Ww, nH
                    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, D1*Wh*Ww, D2*Wh*Ww
                    # 这一步是相加
                    attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, D1*Wh*Ww, D2*Wh*Ww
            
                    if mask is not None:
                        nW = mask.shape[0]
                        attn = attn.view(B_ // nW, nW, self.num_heads, N1, N2) + mask.unsqueeze(1).unsqueeze(0) # B, nW, nH, D1*Wh*Ww, D2*Wh*Ww
                        attn = attn.view(-1, self.num_heads, N1, N2)
                        attn = self.softmax(attn)
                    else:
                        attn = self.softmax(attn)
            
                    attn = self.attn_drop(attn)
            
                    x = (attn @ v).transpose(1, 2).reshape(B_, N1, C)
                    x = self.proj(x)
                    x = self.proj_drop(x)
                    return x, attn
            
            1
            2
            3
            4
            5
            6
            7
            8
            9
            10
            11
            12
            13
            14
            15
            16
            17
            18
            19
            20
            21
            22
            23
            24
            25
            26
            27
            28
            29
            30
            31
            32
            33
            34
            35
            36
            37
            38
            39
            40
            41
            42
            43
            44
            45
            46
            47
            48
            49
            50
            51
            52
            53
            54
            55
            56
            57
            58
            59
            60
            61
            62
            63
            64
            65
            66
            67
            68
            69
            70
            71
            72
            73
            74
            75
            76
            77
            78
            79
            80
            81
            82
            83
            84
            85
            86
            87
            88
            89
            90
            91
            92
            93
            94
            95
            96
            97
            98
            99
            100
            101
            102
            103
            104
            105
            106
            107
            108
            109
            110
            111
            112
            113
            114
            115
            116
            117
            118
            119
            120
            121
            122
            123
            124
            125
            126
            127

            # 位置编号矩阵如何生成

            输入:

            Q={qdQ,hQ,wQ},dQ∈[1,DQ],hQ∈[1,HQ],wQ∈[1,WQ]U=KV={udU,hU,wU},dU∈[1,DU],hU∈[1,HU],wU∈[1,WU]\begin{aligned} &&Q&=\left\{q_{d_Q,h_Q,w_Q}\right\},d_Q\in[1,D_Q],h_Q\in[1,H_Q],w_Q\in[1,W_Q]\\ U=&&KV&=\left\{u_{d_U,h_U,w_U}\right\},d_U\in[1,D_U],h_U\in[1,H_U],w_U\in[1,W_U] \end{aligned} U=​​QKV​={qdQ​,hQ​,wQ​​},dQ​∈[1,DQ​],hQ​∈[1,HQ​],wQ​∈[1,WQ​]={udU​,hU​,wU​​},dU​∈[1,DU​],hU​∈[1,HU​],wU​∈[1,WU​]​

            其中DDD为帧数、HHH为图片长度、WWW为图片宽度

            位置编号矩阵:

            {Ii,j}i=dQHQWQ+hQWQ+wQj=dUHUWU+hUWU+wUIi,j=([dUhUwU]−[dQhQwQ]+[DQHQWQ])T[(HU+HQ)(WU+WQ)WU+WQ1]\begin{aligned} &\left\{I_{i,j}\right\}\\ i=&d_QH_QW_Q+h_QW_Q+w_Q\\ j=&d_UH_UW_U+h_UW_U+w_U\\ I_{i,j}=&(\left[\begin{gathered}d_U\\h_U\\w_U\end{gathered}\right]-\left[\begin{gathered}d_Q\\h_Q\\w_Q\end{gathered}\right]+\left[\begin{gathered}D_Q\\H_Q\\W_Q\end{gathered}\right])^T\left[\begin{gathered}(H_U+H_Q)(W_U+W_Q)\\ W_U+W_Q\\1\end{gathered}\right] \end{aligned} i=j=Ii,j​=​{Ii,j​}dQ​HQ​WQ​+hQ​WQ​+wQ​dU​HU​WU​+hU​WU​+wU​(⎣⎢⎢⎡​dU​hU​wU​​⎦⎥⎥⎤​−⎣⎢⎢⎡​dQ​hQ​wQ​​⎦⎥⎥⎤​+⎣⎢⎢⎡​DQ​HQ​WQ​​⎦⎥⎥⎤​)T⎣⎢⎢⎡​(HU​+HQ​)(WU​+WQ​)WU​+WQ​1​⎦⎥⎥⎤​​

            其中,[dUhUwU]−[dQhQwQ]\left[\begin{gathered}d_U\\h_U\\w_U\end{gathered}\right]-\left[\begin{gathered}d_Q\\h_Q\\w_Q\end{gathered}\right]⎣⎢⎢⎡​dU​hU​wU​​⎦⎥⎥⎤​−⎣⎢⎢⎡​dQ​hQ​wQ​​⎦⎥⎥⎤​即表示了相对位置,[DQHQWQ]\left[\begin{gathered}D_Q\\H_Q\\W_Q\end{gathered}\right]⎣⎢⎢⎡​DQ​HQ​WQ​​⎦⎥⎥⎤​是为了让相对位置值从0开始,而[(HU+HQ)(WU+WQ)WU+WQ1]\left[\begin{gathered}(H_U+H_Q)(W_U+W_Q)\\ W_U+W_Q\\1\end{gathered}\right]⎣⎢⎢⎡​(HU​+HQ​)(WU​+WQ​)WU​+WQ​1​⎦⎥⎥⎤​则把三元组形式的相对位置编码计算为一个标量值。

            于是,Ii,jI_{i,j}Ii,j​就把输入矩阵QQQ中的每个元素和UUU中的每个元素计算出了一个相对位置值。很显然,这个位置值最大为(DU+DQ)(HU+HQ)(WU+WQ)(D_U+D_Q)(H_U+H_Q)(W_U+W_Q)(DU​+DQ​)(HU​+HQ​)(WU​+WQ​),可以给矩阵QQQ和UUU中的任意两个元素都配上一个值。之后只需要根据Ii,jI_{i,j}Ii,j​从一个长度为(DU+DQ)(HU+HQ)(WU+WQ)(D_U+D_Q)(H_U+H_Q)(W_U+W_Q)(DU​+DQ​)(HU​+HQ​)(WU​+WQ​)的列表中找相对位置编码值即可。

            帮助我们改善此页面!
            创建于: 2022-07-13 03:46:50

            更新于: 2022-07-26 09:43:11