关注
Transformer详解(附代码)

引言

T r a n s f o r m e r \mathrm{Transformer} Transformer模型是 G o o g l e \mathrm{Google} Google团队在 2017 2017 2017 6 6 6月由 A s h i s h   V a s w a n i \mathrm{Ashish\text{ }Vaswani} Ashish Vaswani等人在论文《 A t t e n t i o n   I s   A l l   Y o u   N e e d \mathrm{Attention\text{ }Is\text{ }All \text{ }You \text{ } Need} Attention Is All You Need》所提出,当前它已经成为 N L P \mathrm{NLP} NLP领域中的首选模型。 T r a n s f o r m e r \mathrm{Transformer} Transformer抛弃了 R N N \mathrm{RNN} RNN的顺序结构,采用了 S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention机制,使得模型可以并行化训练,而且能够充分利用训练资料的全局信息,加入 T r a n s f o r m e r \mathrm{Transformer} Transformer S e q 2 s e q \mathrm{Seq2seq} Seq2seq模型在 N L P \mathrm{NLP} NLP的各个任务上都有了显著的提升。本文做了大量的图示目的是能够更加清晰地讲解 T r a n s f o r m e r \mathrm{Transformer} Transformer的运行原理,以及相关组件的操作细节,文末还有完整可运行的代码示例。

注意力机制

T r a n s f o r m e r \mathrm{Transformer} Transformer中的核心机制就是 S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention机制的本质来自于人类视觉注意力机制。当人视觉在感知东西时候往往会更加关注某个场景中显著性的物体,为了合理利用有限的视觉信息处理资源,人需要选择视觉区域中的特定部分,然后集中关注它。注意力机制主要目的就是对输入进行注意力权重的分配,即决定需要关注输入的哪部分,并对其分配有限的信息处理资源给重要的部分。

Self-Attention

S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention工作原理如上图所示,给定输入 w o r d   e m b e d d i n g \mathrm{word\text{ }embedding} word embedding向量 a 1 , a 2 , a 3 ∈ R d l × 1 a^1,a^2,a^3 \in \mathbb{R}^{d_l \times 1} a1,a2,a3Rdl×1,然后对于输入向量 a i , i ∈ { 1 , 2 , 3 } a^i,i\in \{1,2,3\} ai,i{1,2,3}通过矩阵 W q ∈ R d k × d l , W k ∈ R d k × d l , W v ∈ R d l × d l W^q\in \mathbb{R}^{d_k \times d_l},W^k\in \mathbb{R}^{d_k \times d_l},W^v\in \mathbb{R}^{d_l\times d_l} WqRdk×dl,WkRdk×dl,WvRdl×dl进行线性变换得到 Q u e r y \mathrm{Query} Query向量 q i ∈ R d k × 1 q^i\in\mathbb{R}^{d_k \times 1} qiRdk×1 K e y \mathrm{Key} Key向量 k i ∈ R d k × 1 k^i\in \mathbb{R}^{d_k \times 1} kiRdk×1,以及 V a l u e \mathrm{Value} Value向量 v i ∈ R d l × 1 v^i\in \mathbb{R}^{d_l \times 1} viRdl×1,即 { q i = W q ⋅ a i k i = W k ⋅ a i , i ∈ { 1 , 2 , 3 } v i = W v ⋅ a i \left\{\begin{aligned}q^i&=W^q \cdot a^i\\k^i&=W^k \cdot a^i,\quad i\in\{1,2,3\}\\v^i&=W^v \cdot a^i\end{aligned}\right. qikivi=Wqai=Wkai,i{1,2,3}=Wvai如果令矩阵 A = ( a 1 , a 2 , a 3 ) ∈ R d l × 3 A=(a^1,a^2,a^3)\in\mathbb{R}^{d_l \times 3} A=(a1,a2,a3)Rdl×3 Q = ( q 1 , q 2 , q 3 ) ∈ R d k × 3 Q=(q^1,q^2,q^3)\in\mathbb{R}^{d_k \times 3} Q=(q1,q2,q3)Rdk×3 K = ( k 1 , k 2 , k 3 ) ∈ R d k × 3 K=(k^1,k^2,k^3)\in\mathbb{R}^{d_k \times 3} K=(k1,k2,k3)Rdk×3 V = ( v 1 , v 2 , v 3 ) ∈ R d l × 3 V=(v^1,v^2,v^3)\in\mathbb{R}^{d_l \times 3} V=(v1,v2,v3)Rdl×3,则此时则有 { Q = W q ⋅ A K = W k ⋅ A V = W v ⋅ A \left\{\begin{aligned}Q&=W^q \cdot A\\K&=W^k \cdot A\\V&=W^v \cdot A\end{aligned}\right. QKV=WqA=WkA=WvA接着再利用得到的 Q u e r y \mathrm{Query} Query向量和 K e y \mathrm{Key} Key向量计算注意力得分,论文中采用的注意力计算公式为点积缩放公式 α l i = ( q i ) ⊤ ⋅ k l d k = d k d k ∑ n = 1 d k k n l ⋅ q n i , i , l ∈ { 1 , 2 , 3 } \alpha^{i}_l=\frac{(q^i)^{\top}\cdot k^l}{\sqrt{d^k}}=\frac{\sqrt{d^k}}{d^k}\sum\limits_{n=1}^{d^k}k^l_n\cdot q^i_n,\quad i,l \in \{1,2,3\} αli=dk (qi)kl=dkdk n=1dkknlqni,i,l{1,2,3}论文中假定 K e y \mathrm{Key} Key向量 k l = ( k 1 l , k 2 l , k 3 l ) k^l=(k^l_1,k^l_2,k^l_3) kl=(k1l,k2l,k3l)的元素和 Q u e r y \mathrm{Query} Query向量 q i = ( q 1 i , q 2 i , q 3 i ) q^i=(q^i_1,q^i_2,q^i_3) qi=(q1i,q2i,q3i)的元素独立同分布,且令均值为 0 0 0,方差为 1 1 1,则此时注意力向量 a i ∈ R 3 × 1 a^{i}\in \mathbb{R}^{3 \times 1} aiR3×1的第 l l l个分量 α l i \alpha^{i}_l αli的均值为 0 0 0,方差 1 1 1具体的计算公式如下 E [ α l i ] = d k d k ∑ n = 1 d k E [ k n l ] ⋅ E [ q n i ] = 0 , i , l ∈ { 1 , 2 , 3 } V a r [ α l i ] = 1 d k ∑ n = 1 d k V a r [ k n l ] ⋅ V a r [ q n i ] = 1 , i , l ∈ { 1 , 2 , 3 } \begin{aligned}\mathbb{E}\left[\alpha^i_l\right]&=\frac{\sqrt{d^k}}{d^k}\sum\limits_{n=1}^{d^k}\mathbb{E}\left[k^l_n\right]\cdot \mathbb{E}\left[q^i_n\right]=0,\quad i,l \in \{1,2,3\}\\ \mathrm{Var}\left[\alpha^i_l\right]&=\frac{1}{d^k}\sum\limits_{n=1}^{d^k}\mathrm{Var}\left[k^l_n\right]\cdot \mathrm{Var}\left[q^i_n\right]=1,\quad i,l \in \{1,2,3\}\end{aligned} E[αli]Var[αli]=dkdk n=1dkE[knl]E[qni]=0,i,l{1,2,3}=dk1n=1dkVar[knl]Var[qni]=1,i,l{1,2,3}令注意力分数矩阵 Λ = ( α 1 , α 2 , α 3 ) ∈ R 3 × 3 \Lambda=(\alpha^1,\alpha^2,\alpha^3)\in \mathbb{R}^{3 \times 3} Λ=(α1,α2,α3)R3×3,则有 Λ = K ⊤ ⋅ Q d k \Lambda=\frac{K^{\top}\cdot Q}{\sqrt{d^k}} Λ=dk KQ注意分数向量 α i \alpha^i αi经过 s o f t m a x \mathrm{softmax} softmax层得到归一化后的注意力分布 β i \beta^i βi,即为 β j i = e α j i ∑ n = 1 3 e α n i , i , j = { 1 , 2 , 3 } \beta^i_j = \frac{e^{\alpha^{i}_j}}{\sum\limits_{n=1}^3e^{\alpha^{i}_n}},\quad i,j=\{1,2,3\} βji=n=13eαnieαji,i,j={1,2,3}最后利用得到的注意力分布向量 β i \beta^i βi V a l u e \mathrm{Value} Value矩阵 V V V获得最后的输出 b i ∈ R d l × 1 b^i\in \mathbb{R}^{d_l \times 1} biRdl×1,则有 b i = ∑ l = 1 3 β l i ⋅ v l , i ∈ { 1 , 2 , 3 } b^i=\sum\limits^{3}_{l=1}\beta^{i}_l \cdot v^{l},\quad i \in \{1,2,3\} bi=l=13βlivl,i{1,2,3}令输出矩阵 B = ( b 1 , b 2 , b 3 ) ∈ R d l × 3 B=(b^1,b^2,b^3)\in\mathbb{R}^{d_l\times 3} B=(b1,b2,b3)Rdl×3,则有 B = A t t e n t i o n ( Q , K , V ) = V ⋅ s o f t m a x ( K ⊤ ⋅ Q d k ) B=\mathrm{Attention}(Q,K,V)=V\cdot\mathrm{softmax}\left(\frac{K^{\top}\cdot Q}{\sqrt{d^k}}\right) B=Attention(Q,K,V)=Vsoftmax(dk KQ)

Multi-Head Attention

M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention的工作原理与 S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention的工作原理非常类似。为了方便图解可视化将 M u l t i \mathrm{Multi} Multi- H e a d \mathrm{Head} Head设置为 2 2 2- H e a d \mathrm{Head} Head,如果 M u l t i \mathrm{Multi} Multi- H e a d \mathrm{Head} Head设置为 8 8 8- H e a d \mathrm{Head} Head,则上图的 q i , k i , v i , i ∈ { 1 , 2 , 3 } q^i,k^i,v^i,i\in\{1,2,3\} qi,ki,vi,i{1,2,3}的下一步的分支数为 8 8 8。给定输入 w o r d   e m b e d d i n g \mathrm{word\text{ }embedding} word embedding向量 a 1 , a 2 , a 3 ∈ R d l × 1 a^1,a^2,a^3 \in \mathbb{R}^{d_l \times 1} a1,a2,a3Rdl×1,然后对于输入向量 a i , i ∈ { 1 , 2 , 3 } a^i,i\in \{1,2,3\} ai,i{1,2,3}通过矩阵 W q ∈ R d k × d l , W k ∈ R d k × d l , W v ∈ R d l × d l W^q\in \mathbb{R}^{d_k \times d_l},W^k\in \mathbb{R}^{d_k \times d_l},W^v\in \mathbb{R}^{d_l\times d_l} WqRdk×dl,WkRdk×dl,WvRdl×dl进行第一次线性变换得到 Q u e r y \mathrm{Query} Query向量 q i ∈ R d k × 1 q^i\in\mathbb{R}^{d_k \times 1} qiRdk×1 K e y \mathrm{Key} Key向量 k i ∈ R d k × 1 k^i \in\mathbb{R}^{d_k \times 1} kiRdk×1,以及 V a l u e \mathrm{Value} Value向量 v i ∈ R d l × 1 v^i \in\mathbb{R}^{d_l \times 1} viRdl×1。然后再对 Q u e r y \mathrm{Query} Query向量 q i q^i qi通过矩阵 W q 1 ∈ R d m × d k W^{q1}\in \mathbb{R}^{d_m \times d_k} Wq1Rdm×dk W q 2 ∈ R d m × d k W^{q2}\in \mathbb{R}^{d_m\times d_k} Wq2Rdm×dk进行第二次线性变换得到 q i 1 ∈ R d m × 1 q^{i1}\in \mathbb{R}^{d_m \times 1} qi1Rdm×1 q i 2 ∈ R d m × 1 q^{i2}\in \mathbb{R}^{d_m\times 1} qi2Rdm×1,同理对 K e y \mathrm{Key} Key向量 k i k^i ki通过矩阵 W k 1 ∈ R d m × d k W^{k1}\in \mathbb{R}^{d_m \times d_k} Wk1Rdm×dk W k 2 ∈ R d m × d k W^{k2}\in \mathbb{R}^{d_m\times d_k} Wk2Rdm×dk进行第二次线性变换得到 k i 1 ∈ R d m × 1 k^{i1}\in \mathbb{R}^{d_m\times 1} ki1Rdm×1 k i 2 ∈ R d m × 1 k^{i2}\in \mathbb{R}^{d_m\times 1} ki2Rdm×1,对 V a l u e \mathrm{Value} Value向量 v i v^i vi通过矩阵 W v 1 ∈ R d l 2 × d l W^{v1}\in \mathbb{R}^{\frac{d_l}{2}\times d_l} Wv1R2dl×dl W v 2 ∈ R d l 2 × d l W^{v2}\in \mathbb{R}^{\frac{d_l}{2}\times d_l} Wv2R2dl×dl进行第二次线性变换得到 v i 1 ∈ R d l 2 × 1 v^{i1}\in \mathbb{R}^{\frac{d_l}{2}\times 1} vi1R2dl×1 v i 2 ∈ R d l 2 × 1 v^{i2}\in \mathbb{R}^{\frac{d_l}{2}\times 1} vi2R2dl×1,具体的计算公式如下所示: { q i h = W q h ⋅ W q ⋅ a i k i h = W k h ⋅ W k ⋅ a i , i = { 1 , 2 , 3 } , h = { 1 , 2 } v i h = W v h ⋅ W v ⋅ a i \left\{\begin{aligned}q^{ih}&=W^{qh}\cdot W^{q} \cdot a^i\\ k^{ih}&=W^{kh}\cdot W^{k} \cdot a^i,\quad i=\{1,2,3\},\quad h=\{1,2\}\\v^{ih}&=W^{vh}\cdot W^{v} \cdot a^i\end{aligned}\right. qihkihvih=WqhWqai=WkhWkai,i={1,2,3},h={1,2}=WvhWvai令矩阵 Q 1 = ( q 11 , q 21 , q 31 ) ∈ R d m × 3 Q 2 = ( q 12 , q 22 , q 32 ) ∈ R d m × 3 K 1 = ( k 11 , k 21 , k 31 ) ∈ R d m × 3 K 2 = ( k 12 , k 22 , k 32 ) ∈ R d m × 3 V 1 = ( v 11 , v 21 , v 31 ) ∈ R d l 2 × 3 V 2 = ( v 12 , v 22 , v 32 ) ∈ R d l 2 × 3 \begin{array}{ll}Q^{1}=(q^{11},q^{21},q^{31})\in \mathbb{R}^{d_m\times 3}&\quad Q^2=(q^{12},q^{22},q^{32})\in\mathbb{R}^{d_m\times 3}\\K^{1}=(k^{11},k^{21},k^{31})\in \mathbb{R}^{d_m\times 3}&\quad K^2=(k^{12},k^{22},k^{32})\in\mathbb{R}^{d_m\times 3}\\V^{1}=(v^{11},v^{21},v^{31})\in \mathbb{R}^{\frac{d_l}{2}\times 3}&\quad V^2=(v^{12},v^{22},v^{32})\in\mathbb{R}^{\frac{d_l}{2}\times 3}\end{array} Q1=(q11,q21,q31)Rdm×3K1=(k11,k21,k31)Rdm×3V1=(v11,v21,v31)R2dl×3Q2=(q12,q22,q32)Rdm×3K2=(k12,k22,k32)Rdm×3V2=(v12,v22,v32)R2dl×3此时则有 Q 1 = W q 1 ⋅ W q ⋅ A Q 2 = W q 2 ⋅ W q ⋅ A K 1 = W k 1 ⋅ W k ⋅ A K 2 = W k 2 ⋅ W k ⋅ A V 1 = W v 1 ⋅ W v ⋅ A V 2 = W v 2 ⋅ W v ⋅ A \begin{array}{ll}Q^{1}=W^{q1}\cdot W^{q} \cdot A &\quad Q^2=W^{q2}\cdot W^{q} \cdot A\\K^{1}=W^{k1}\cdot W^{k} \cdot A&\quad K^2=W^{k2}\cdot W^{k} \cdot A\\V^{1}=W^{v1}\cdot W^{v} \cdot A&\quad V^2=W^{v2}\cdot W^{v} \cdot A\end{array} Q1=Wq1WqAK1=Wk1WkAV1=Wv1WvAQ2=Wq2WqAK2=Wk2WkAV2=Wv2WvA对于每个 H e a d \mathrm{Head} Head利用得到对于 Q u e r y \mathrm{Query} Query向量和 K e y \mathrm{Key} Key向量计算对应的注意力得分,其中注意力向量 α i h \alpha^{ih} αih的第 l l l个分量的计算公式为 α l i h = ( q i h ) ⊤ ⋅ k l h , i ∈ { 1 , 2 , 3 } , h ∈ { 1 , 2 } , l ∈ { 1 , 2 , 3 } \alpha^{ih}_l=(q^{ih})^{\top}\cdot k^{lh},\quad i\in\{1,2,3\},h\in\{1,2\},l\in\{1,2,3\} αlih=(qih)klh,i{1,2,3},h{1,2},l{1,2,3}令注意力分数矩阵 Λ 1 = ( α 11 , α 21 , α 31 ) \Lambda^1=(\alpha^{11},\alpha^{21},\alpha^{31}) Λ1=(α11,α21,α31) Λ 2 = ( α 12 , α 22 , α 32 ) \Lambda^2=(\alpha^{12},\alpha^{22},\alpha^{32}) Λ2=(α12,α22,α32),则有 Λ 1 = ( K 1 ) ⊤ ⋅ Q 1 d m , Λ 2 = ( K 2 ) ⊤ ⋅ Q 2 d m \Lambda^{1}=\frac{(K^1)^{\top}\cdot Q^1}{\sqrt{d_m}},\quad\Lambda^{2}=\frac{(K^2)^{\top}\cdot Q^2}{\sqrt{d_m}} Λ1=dm (K1)Q1,Λ2=dm (K2)Q2注意分数向量 α i h \alpha^{ih} αih经过 s o f t m a x \mathrm{softmax} softmax层得到归一化后的注意力分布 β i h \beta^{ih} βih,即为 β j i h = e α j i h ∑ n = 1 3 e α n i h , i , j = { 1 , 2 , 3 } , h = { 1 , 2 } \beta^{ih}_j = \frac{e^{\alpha^{ih}_j}}{\sum\limits_{n=1}^3e^{\alpha^{ih}_n}},\quad i,j=\{1,2,3\}, h=\{1,2\} βjih=n=13eαniheαjih,i,j={1,2,3},h={1,2}对于每一个 H e a d \mathrm{Head} Head利用得到的注意力分布向量 β i h \beta^{ih} βih V a l u e \mathrm{Value} Value矩阵 V h V^h Vh获得最后的输出 b i h ∈ R d l 2 × 1 b^{ih}\in \mathbb{R}^{\frac{d_l}{2} \times 1} bihR2dl×1,则有 b i h = ∑ l = 1 3 β l i h ⋅ v l h , i ∈ { 1 , 2 , 3 } , h ∈ { 1 , 2 } b^{ih}=\sum\limits^{3}_{l=1}\beta^{ih}_l \cdot v^{lh},\quad i \in \{1,2,3\}, h\in\{1,2\} bih=l=13βlihvlh,i{1,2,3},h{1,2}两个 H e a d \mathrm{Head} Head b i h b^{ih} bih的向量按照如下方式拼接在一起,则有 B = ( b 11 b 21 b 31 b 12 b 22 b 32 ) ∈ R d l × 3 B=\left(\begin{array}{lll}b^{11}&b^{21}&b^{31}\\b^{12}&b^{22}&b^{32}\end{array}\right)\in \mathbb{R}^{d_l \times 3} B=(b11b12b21b22b31b32)Rdl×3给定参数矩阵 W O ∈ R d l × d l W^{O}\in \mathbb{R}^{d_l\times d_l} WORdl×dl,则输出矩阵为 O = W O ⋅ B ∈ R d l × 3 O=W^{O}\cdot B\in \mathbb{R}^{d_l \times 3} O=WOBRdl×3综上所述则有 O = M u l t i H e a d ( Q , K , V ) = W O ⋅ C o n c a t ( V 1 ⋅ s o f t m a x ( ( K 1 ) ⊤ ⋅ Q 1 d m ) V 2 ⋅ s o f t m a x ( ( K 2 ) ⊤ ⋅ Q 2 d m ) ) O=\mathrm{MultiHead}(Q,K,V)=W^O\cdot\mathrm{Concat}\left(\begin{array}{l}V^1\cdot\mathrm{softmax}\left(\frac{(K^1)^{\top}\cdot Q^1}{\sqrt{d_m}}\right)\\ \\V^2\cdot\mathrm{softmax}\left(\frac{(K^2)^{\top}\cdot Q^2}{\sqrt{d_m}}\right)\end{array}\right) O=MultiHead(Q,K,V)=WOConcatV1softmax(dm (K1)Q1)V2softmax(dm (K2)Q2)

Mask Self-Attention

如下图左半部分所示, S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention的输出向量 b i , i ∈ { 1 , 2 , 3 , 4 } b^i, i \in \{1,2,3,4\} bi,i{1,2,3,4}综合了输入向量 a i , i ∈ { 1 , 2 , 3 , 4 } a^i, i \in \{1,2,3,4\} ai,i{1,2,3,4}的全部信息,由此可见, S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention在实际编程中支持并行运算。如下图右半部分所示, M a s k   S e l f \mathrm{Mask \text{ } Self} Mask Self- A t t e n t i o n \mathrm{Attention} Attention的输出向量 b i b^i bi只利用了已知部分输入的向量 a i a^i ai的信息。例如, b 1 b1 b1只是与 a 1 a^1 a1有关; b 2 b^2 b2 a 1 a^1 a1 a 2 a^2 a2有关; b 3 b^3 b3 a 1 a^1 a1 a 2 a^2 a2 a 3 a^3 a3有关; b 4 b^4 b4 a 1 a^1 a1 a 2 a^2 a2 a 3 a^3 a3 a 4 a^4 a4有关。 M a s k   S e l f \mathrm{Mask \text{ } Self} Mask Self- A t t e n t i o n \mathrm{Attention} Attention T r a n s f o r m e r \mathrm{Transformer} Transformer中被用到过两次。

  • T r a n s f o r m e r \mathrm{Transformer} Transformer E n c o d e r \mathrm{Encoder} Encoder中如果输入一句话的 w o r d \mathrm{word} word长度小于指定的长度,为了能够让长度一致往往会用 0 0 0进行填充,此时则需要用 M a s k   S e l f \mathrm{Mask \text{ } Self} Mask Self- A t t e n t i o n \mathrm{Attention} Attention来计算注意力分布。
  • T r a n s f o r m e r \mathrm{Transformer} Transformer D e c o d e r \mathrm{Decoder} Decoder的输出是有时序关系的,当前的输出只与之前的输入有关,所以此时算注意力分布时需要用到 M a s k   S e l f \mathrm{Mask \text{ } Self} Mask Self- A t t e n t i o n \mathrm{Attention} Attention

Transformer模型

以上对 T r a n s f o r m e r \mathrm{Transformer} Transformer中的核心内容即自注意力机制进行了详细解剖,接下来会对 T r a n s f o r m e r \mathrm{Transformer} Transformer模型架构进行介绍。 T r a n s f o r m e r \mathrm{Transformer} Transformer模型是由 E n c o d e r \mathrm{Encoder} Encoder D e c o d e r \mathrm{Decoder} Decoder两个模块组成,具体的示意图如下所示,为了能够对 T r a n s f o r m e r \mathrm{Transformer} Transformer内部的操作细节进行更清晰的展示,下图以矩阵运算的视角对 T r a n s f o r m e r \mathrm{Transformer} Transformer的原理进行讲解。
E n c o d e r \mathrm{Encoder} Encoder模块操作的具体流程如下所示:

  • E n c o d e r \mathrm{Encoder} Encoder的输入由两部分组成分别是词编码矩阵 I ∈ R n × l × d I \in \mathbb{R}^{n \times l \times d} IRn×l×d和位置编码矩阵 P ∈ R n × l × d P \in \mathbb{R}^{n \times l \times d} PRn×l×d,其中 n n n表示句子数目, l l l表示一句话单词的最大数目, d d d表示的是词向量的维度。位置编码矩阵 P P P表示的是每个单词在一句里的所有位置信息,因为 S e l f \mathrm{Self} Self- A t t e n t i o n \mathrm{Attention} Attention计算注意力分布的时候只能给出输出向量和输入向量之间的权重关系,但是不能给出词在一句话里的位置信息,所以需要在输入里引入位置编码矩阵 P P P。位置编码向量生成方法有很多。一种比较简单粗暴的方式就是根据单词在句子中的位置生成一个 o n e \mathrm{one} one- h o t \mathrm{hot} hot的位置编码;还有的方法是将位置编码当成参数进行训练学习;在该论文里是利用三角函数对位置进行编码,具体的公式如下所示 P E ( p o s , 2 i ) = sin ⁡ ( p o s 100 0 2 i / d ) , P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 100 0 2 i / d ) \mathrm{PE}(pos,2i)=\sin(\frac{pos}{1000^{2i/d}}),\quad \mathrm{PE}(pos,2i+1)=\cos(\frac{pos}{1000^{2i/d}}) PE(pos,2i)=sin(10002i/dpos),PE(pos,2i+1)=cos(10002i/dpos)其中 P E \mathrm{PE} PE表示的是位置编码向量, p o s pos pos表示词在句子中的位置, i i i表示编码向量的位置索引。
  • 输入矩阵 I + P I+P I+P通过线性变换生成矩阵 Q Q Q K K K V V V。在实际编程中是将输入 I + P I+P I+P直接赋值给 Q Q Q K K K V V V。如果输入单词长度小于最大长度并 0 0 0来填充的时候,还要相应引入 M a s k \mathrm{Mask} Mask矩阵。
  • 将矩阵 Q Q Q K K K V V V输入到 M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention模块中进行注意分布的计算得到矩阵 I ′ ∈ R n × l × d I^{\prime}\in \mathbb{R}^{n \times l \times d} IRn×l×d,计算公式为 I ′ = M u l t i H e a d ( Q , K , V ) I^{\prime}=\mathrm{MultiHead}(Q,K,V) I=MultiHead(Q,K,V)具体的计算细节参考上文关于 M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention原理的讲解不在这里赘述。然后将原始输入 I + P I+P I+P与注意力分布 I ′ I^{\prime} I进行残差计算得到输出矩阵 I + P + I ′ ∈ R n × l × d I+P+I^{\prime}\in \mathbb{R}^{n \times l \times d} I+P+IRn×l×d
  • 对矩阵 I + P + I ′ = { x i j k } n l d I+P+I^{\prime}=\{x_{ijk}\}^{nld} I+P+I={xijk}nld进行层归一化操作得到 I ′ ′ ∈ R n × l × d I^{\prime\prime}\in\mathbb{R}^{n \times l \times d} IRn×l×d,具体的计算公式为 { μ i j = ∑ k = 1 d x i j k σ i j = ∑ k = 1 d ( x i j k − μ i j ) 2 ⟹ x ^ i j k = x i j k − u i j σ i j , i ∈ { 1 , ⋯   , n } , j ∈ { 1 , ⋯   , l } , k ∈ { 1 , ⋯   , d } \left\{\begin{aligned}\mu^{ij}&=\sum\limits_{k=1}^d x_{ijk}\\\sigma^{ij}&=\sqrt{\sum\limits_{k=1}^d\left(x_{ijk}-\mu^{ij}\right)^2}\end{aligned}\right. \Longrightarrow \hat{x}_{ijk}=\frac{x_{ijk}-u^{ij}}{\sigma^{ij}},\quad i\in\{1,\cdots,n\},j\in\{1,\cdots,l\},k\in\{1,\cdots,d\} μijσij=k=1dxijk=k=1d(xijkμij)2 x^ijk=σijxijkuij,i{1,,n},j{1,,l},k{1,,d}
  • I ′ ′ I^{\prime\prime} I输入到全连接神经网络中得到 I ′ ′ ′ ∈ R n × l × d I^{\prime\prime\prime}\in \mathbb{R}^{n \times l \times d} IRn×l×d ,然后再让全连接神经网络的输入 I ′ ′ I^{\prime\prime} I与输出 I ′ ′ ′ I^{\prime\prime\prime} I进行残差计算得到 I ′ ′ + I ′ ′ ′ I^{\prime\prime}+I^{\prime\prime\prime} I+I,接着对 I ′ ′ + I ′ ′ ′ I^{\prime\prime}+I^{\prime\prime\prime} I+I进行层归一化操作。
  • 以上是一个 B l o c k \mathrm{Block} Block的操作原理,将 N N N B l o c k \mathrm{Block} Block进行堆叠就组成了 E n c o d e r \mathrm{Encoder} Encoder的模块,得到的最后输出为 I N ∈ R n × l × d I^N \in \mathbb{R}^{n \times l \times d} INRn×l×d。这里需要注意的是 E n c o d e r \mathrm{Encoder} Encoder模块中的各个组件的操作顺序并不是固定的,也可以先进行归一化操作,然后再计算注意力分布,再归一化,再预测等。

D e c o d e r \mathrm{Decoder} Decoder模块操作的具体流程如下所示:

  • D e c o d e r \mathrm{Decoder} Decoder的输入也由两部分组成分别是词编码矩阵 O ∈ R n 1 × l 1 × d O \in \mathbb{R}^{n_1 \times l_1 \times d} ORn1×l1×d和位置编码矩阵 P O ∈ R n 1 × l 1 × d P^O \in \mathbb{R}^{n_1 \times l_1 \times d} PORn1×l1×d。因为 D e c o d e r \mathrm{Decoder} Decoder的输入是具有时顺序关系的(即上一步的输出为当前步输入)所以还需要输入 M a s k \mathrm{Mask} Mask矩阵 M M M以便计算注意力分布。
  • 输入矩阵 O + P O O+P^O O+PO通过线性变换生成矩阵 Q ^ \hat{Q} Q^ K ^ \hat{K} K^ V ^ \hat{V} V^。在实际编程中是将输入 O + P O O+P^O O+PO直接赋值给 Q ^ \hat{Q} Q^ K ^ \hat{K} K^ V ^ \hat{V} V^。如果输入单词长度小于最大长度并 0 0 0来填充的时候,还要相应引入 M a s k \mathrm{Mask} Mask矩阵。
  • 将矩阵 Q ^ \hat{Q} Q^ K ^ \hat{K} K^ V ^ \hat{V} V^以及 M a s k \mathrm{Mask} Mask矩阵 M M M输入到 M a s k   M u l t i \mathrm{Mask\text{ }Multi} Mask Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention模块中进行注意分布的计算得到矩阵 O ′ ∈ R n 1 × l 1 × d O^{\prime}\in \mathbb{R}^{n_1 \times l_1 \times d} ORn1×l1×d,计算公式为 O ′ = M a s k M u l t i H e a d ( Q ^ , K ^ , V ^ , M ) O^{\prime}=\mathrm{MaskMultiHead}(\hat{Q},\hat{K},\hat{V},M) O=MaskMultiHead(Q^,K^,V^,M)具体的计算细节参考上文关于 M a s k   S e l f \mathrm{Mask \text{ }Self} Mask Self- A t t e n t i o n \mathrm{Attention} Attention的讲解不在这里赘述。然后将原始输入 O + P O O+P^O O+PO与注意力分布 O ′ O^{\prime} O进行残差计算得到输出矩阵 O + P O + O ′ ∈ R n 1 × l 1 × d O+P^O+O^{\prime}\in \mathbb{R}^{n_1 \times l_1 \times d} O+PO+ORn1×l1×d。接着再对矩阵 O + P O + O ′ O+P^O+O^{\prime} O+PO+O进行层归一化操作得到 O ′ ′ ∈ R n 1 × l 1 × d O^{\prime\prime}\in\mathbb{R}^{n_1 \times l_1 \times d} ORn1×l1×d
  • E n c o d e r \mathrm{Encoder} Encoder的输出 I N I^N IN通过线性变换得到 Q N Q^N QN K N K^N KN O ′ O^{\prime} O进行线性变换得到 V ^ ′ \hat{V}^{\prime} V^,利用矩阵 Q N Q^N QN K N K^N KN V ^ ′ \hat{V}^{\prime} V^进行交叉注意力分布的计算得到 O ′ ′ ′ O^{\prime\prime\prime} O,计算公式为 O ′ ′ ′ = M u l t i H e a d ( Q N , K N , V ^ ′ ) O^{\prime\prime\prime}=\mathrm{MultiHead}(Q^N,K^N,\hat{V}^{\prime}) O=MultiHead(QN,KN,V^)这里的交叉注意力分布综合 E n c o d e r \mathrm{Encoder} Encoder输出结果和 D e c o d e r \mathrm{Decoder} Decoder中间结果的信息。实际编程编程中将 I N I^N IN直接赋值给 Q ^ \hat{Q} Q^ K ^ \hat{K} K^ O ′ O^{\prime} O直接赋值给 V ^ ′ \hat{V}^{\prime} V^。然后将 O ′ ′ O^{\prime\prime} O与注意力分布 O ′ ′ ′ O^{\prime\prime\prime} O进行残差计算得到输出矩阵 O ′ ′ + O ′ ′ ′ O^{\prime\prime}+O^{\prime\prime\prime} O+O
  • 接着对 O ′ ′ + O ′ ′ ′ O^{\prime\prime}+O^{\prime\prime\prime} O+O进行层归一操作得到 O ′ ′ ′ ′ O^{\prime\prime\prime\prime} O,再将 O ′ ′ ′ ′ O^{\prime\prime\prime\prime} O输入到全连接神经网络中得到 O ′ ′ ′ ′ ′ O^{\prime\prime\prime\prime\prime} O,接着再做一步残差操作得到 O ′ ′ ′ ′ + O ′ ′ ′ ′ ′ O^{\prime\prime\prime\prime}+O^{\prime\prime\prime\prime\prime} O+O,最后再进行一层归一化操作。
  • 以上是一个 B l o c k \mathrm{Block} Block的操作原理,将 N N N B l o c k \mathrm{Block} Block进行堆叠就组成了 D e c o d e r \mathrm{Decoder} Decoder的模块,得到的输出为 O N ∈ R n 1 × l 1 × d O^N \in \mathbb{R}^{n_1 \times l_1 \times d} ONRn1×l1×d。然后在词汇字典中找到当前预测最大概率的单词,并将该单词词向量作为下一阶段的输入,重复以上步骤,直到输出“ e n d \mathrm{end} end”字符为止。

代码示例

T r a n s f o r m e r \mathrm{Transformer} Transformer具体的代码示例如下所示为一个国外博主视频里的代码,并根据上文对代码的一些细节进行了探讨。根据上文中 M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention原理示例图可知,严格来看 M u l t i \mathrm{Multi} Multi- H e a d   A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention在求注意分布的时候中间其实是有两步线性变换。给定输入向量 x ∈ R 256 × 1 x\in \mathbb{R}^{256\times 1} xR256×1 第一步线性变换直接让向量 x x x赋值给 q q q k k k v v v,这一过程以下程序中有所体现,在这里并不会产生歧义。第二步线性变换产生多 H e a d \mathrm{Head} Head,假设 H e a d = 8 \mathrm{Head}=8 Head=8的时候,按理说 q q q要与 8 8 8个矩阵 W q 1 , ⋯   , W q 8 W^{q1},\cdots,W^{q8} Wq1,,Wq8进行线性变换得到 8 8 8 q 1 , ⋯   , q 8 q^{1},\cdots,q^{8} q1,,q8,同理 k k k要与 8 8 8个矩阵 W k 1 , ⋯   , W k 8 W^{k1},\cdots,W^{k8} Wk1,,Wk8进行线性变换得到 8 8 8 k 1 , ⋯   , k 8 k^{1},\cdots,k^{8} k1,,k8 v v v要与 8 8 8个矩阵 W v 1 , ⋯   , W v 8 W^{v1},\cdots,W^{v8} Wv1,,Wv8进行线性变换得到 8 8 8 v 1 , ⋯   , v 8 v^{1},\cdots,v^{8} v1,,v8,如果按照这个方式在程序实现则需要定义24个权重矩阵,非常的麻烦。以下程序中有一个简单的权重定义方法,通过该方法也可以实现以上多 H e a d \mathrm{Head} Head的线性变换,以向量 q = ( q 1 , ⋯   , q 256 ) ⊤ ∈ R 256 × 1 q = (q_1,\cdots, q_{256})^{\top}\in \mathbb{R}^{256 \times 1} q=(q1,,q256)R256×1为例:

  • 首先将向量 q q q进行截断分成 H e a d = 8 \mathrm{Head}=8 Head=8个向量,即为 { q ( 1 ) = ( E , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) ⋅ q q ( 2 ) = ( 0 , E , 0 , 0 , 0 , 0 , 0 , 0 ) ⋅ q q ( 3 ) = ( 0 , 0 , E , 0 , 0 , 0 , 0 , 0 ) ⋅ q q ( 4 ) = ( 0 , 0 , 0 , E , 0 , 0 , 0 , 0 ) ⋅ q q ( 5 ) = ( 0 , 0 , 0 , 0 , E , 0 , 0 , 0 ) ⋅ q q ( 6 ) = ( 0 , 0 , 0 , 0 , 0 , E , 0 , 0 ) ⋅ q q ( 7 ) = ( 0 , 0 , 0 , 0 , 0 , 0 , E , 0 ) ⋅ q q ( 8 ) = ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , E ) ⋅ q \left\{\begin{aligned}q^{(1)}&=({\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(2)}&=({\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(3)}&=({\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(4)}&=({\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(5)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0}})\cdot q\\q^{(6)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0}})\cdot q\\q^{(7)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0}})\cdot q\\q^{(8)}&=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E}})\cdot q \end{aligned}\right. q(1)q(2)q(3)q(4)q(5)q(6)q(7)q(8)=(E,0,0,0,0,0,0,0)q=(0,E,0,0,0,0,0,0)q=(0,0,E,0,0,0,0,0)q=(0,0,0,E,0,0,0,0)q=(0,0,0,0,E,0,0,0)q=(0,0,0,0,0,E,0,0)q=(0,0,0,0,0,0,E,0)q=(0,0,0,0,0,0,0,E)q其中 q ( i ) ∈ R 32 × 1 q^{(i)}\in \mathbb{R}^{32\times 1} q(i)R32×1 q q q的第 i i i个截断向量, E ∈ R 32 × 32 {\bf{E}}\in \mathbb{R}^{32 \times 32} ER32×32是单位矩阵, 0 ∈ R 32 × 32 {\bf{0}}\in \mathbb{R}^{32 \times 32} 0R32×32是零矩阵。
  • 然后对 q ( i ) , i ∈ { 1 , ⋯   , 8 } q^{(i)},i\in \{1,\cdots,8\} q(i),i{1,,8}用相同的权重矩阵 W ∈ R 32 × 32 W \in \mathbb{R}^{32 \times 32} WR32×32进行线性变换,此时可以发现,训练过程的时候只需要更新权重矩阵 W W W即可,而且可以进行多 H e a d \mathrm{Head} Head线性变换, 8 8 8个权重矩阵可以表示为: { W q 1 = W ⋅ ( E , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) = ( W , 0 , 0 , 0 , 0 , 0 , 0 , 0 ) W q 2 = W ⋅ ( 0 , E , 0 , 0 , 0 , 0 , 0 , 0 ) = ( 0 , W , 0 , 0 , 0 , 0 , 0 , 0 ) W q 3 = W ⋅ ( 0 , 0 , E , 0 , 0 , 0 , 0 , 0 ) = ( 0 , 0 , W , 0 , 0 , 0 , 0 , 0 ) W q 4 = W ⋅ ( 0 , 0 , 0 , E , 0 , 0 , 0 , 0 ) = ( 0 , 0 , 0 , W , 0 , 0 , 0 , 0 ) W q 5 = W ⋅ ( 0 , 0 , 0 , 0 , E , 0 , 0 , 0 ) = ( 0 , 0 , 0 , 0 , W , 0 , 0 , 0 ) W q 6 = W ⋅ ( 0 , 0 , 0 , 0 , 0 , E , 0 , 0 ) = ( 0 , 0 , 0 , 0 , 0 , W , 0 , 0 ) W q 7 = W ⋅ ( 0 , 0 , 0 , 0 , 0 , 0 , E , 0 ) = ( 0 , 0 , 0 , 0 , 0 , 0 , W , 0 ) W q 8 = W ⋅ ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , E ) = ( 0 , 0 , 0 , 0 , 0 , 0 , 0 , W ) \left\{\begin{aligned}W^{q1}&=W\cdot ({\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})=(W,{\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q2}&=W\cdot ({\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})=({\bf{0},}W{,\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q3}&=W\cdot ({\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})=({\bf{0},\bf{0},}W{,\bf{0},\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q4}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0},\bf{0}})=({\bf{0},\bf{0},\bf{0},}W{,\bf{0},\bf{0},\bf{0},\bf{0}})\\W^{q5}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0},\bf{0}})=({\bf{0},\bf{0},\bf{0},\bf{0},}W{,\bf{0},\bf{0},\bf{0}})\\W^{q6}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0},\bf{0}})=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},}W{,\bf{0},\bf{0}})\\W^{q7}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E},\bf{0}})=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},}W{,\bf{0}})\\W^{q8}&=W\cdot ({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{E}})=({\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},\bf{0},}W{})\end{aligned}\right. Wq1Wq2Wq3Wq4Wq5Wq6Wq7Wq8=W(E,0,0,0,0,0,0,0)=(W,0,0,0,0,0,0,0)=W(0,E,0,0,0,0,0,0)=(0,W,0,0,0,0,0,0)=W(0,0,E,0,0,0,0,0)=(0,0,W,0,0,0,0,0)=W(0,0,0,E,0,0,0,0)=(0,0,0,W,0,0,0,0)=W(0,0,0,0,E,0,0,0)=(0,0,0,0,W,0,0,0)=W(0,0,0,0,0,E,0,0)=(0,0,0,0,0,W,0,0)=W(0,0,0,0,0,0,E,0)=(0,0,0,0,0,0,W,0)=W(0,0,0,0,0,0,0,E)=(0,0,0,0,0,0,0,W)其中权重矩阵 W q i ∈ R 32 × 256 , i ∈ { 1 , ⋯   , 8 } W^{qi}\in\mathbb{R}^{32 \times 256},i\in\{1,\cdots,8\} WqiR32×256,i{1,,8}
import torch
import torch.nn as nn
import os

class SelfAttention(nn.Module):
	def __init__(self, embed_size, heads):
		super(SelfAttention, self).__init__()
		self.embed_size = embed_size
		self.heads = heads
		self.head_dim = embed_size // heads

		assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

		self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
		self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

	def forward(self, values, keys, query, mask):
		N =query.shape[0]
		value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]

		# split embedding into self.heads pieces
		values = values.reshape(N, value_len, self.heads, self.head_dim)
		keys = keys.reshape(N, key_len, self.heads, self.head_dim)
		queries = query.reshape(N, query_len, self.heads, self.head_dim)
		
		values = self.values(values)
		keys = self.keys(keys)
		queries = self.queries(queries)

		energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
		# queries shape: (N, query_len, heads, heads_dim)
		# keys shape : (N, key_len, heads, heads_dim)
		# energy shape: (N, heads, query_len, key_len)

		if mask is not None:
			energy = energy.masked_fill(mask == 0, float("-1e20"))

		attention = torch.softmax(energy/ (self.embed_size ** (1/2)), dim=3)

		out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
		# attention shape: (N, heads, query_len, key_len)
		# values shape: (N, value_len, heads, heads_dim)
		# (N, query_len, heads, head_dim)

		out = self.fc_out(out)
		return out


class TransformerBlock(nn.Module):
	def __init__(self, embed_size, heads, dropout, forward_expansion):
		super(TransformerBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm1 = nn.LayerNorm(embed_size)
		self.norm2 = nn.LayerNorm(embed_size)

		self.feed_forward = nn.Sequential(
			nn.Linear(embed_size, forward_expansion*embed_size),
			nn.ReLU(),
			nn.Linear(forward_expansion*embed_size, embed_size)
		)
		self.dropout = nn.Dropout(dropout)

	def forward(self, value, key, query, mask):
		attention = self.attention(value, key, query, mask)

		x = self.dropout(self.norm1(attention + query))
		forward = self.feed_forward(x)
		out = self.dropout(self.norm2(forward + x))
		return out


class Encoder(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length,
		):
		super(Encoder, self).__init__()
		self.embed_size = embed_size
		self.device = device
		self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)

		self.layers = nn.ModuleList(
			[
				TransformerBlock(
					embed_size,
					heads,
					dropout=dropout,
					forward_expansion=forward_expansion,
					)
				for _ in range(num_layers)]
		)
		self.dropout = nn.Dropout(dropout)


	def forward(self, x, mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
		for layer in self.layers:
			out = layer(out, out, out, mask)

		return out


class DecoderBlock(nn.Module):
	def __init__(self, embed_size, heads, forward_expansion, dropout, device):
		super(DecoderBlock, self).__init__()
		self.attention = SelfAttention(embed_size, heads)
		self.norm = nn.LayerNorm(embed_size)
		self.transformer_block = TransformerBlock(
			embed_size, heads, dropout, forward_expansion
		)

		self.dropout = nn.Dropout(dropout)

	def forward(self, x, value, key, src_mask, trg_mask):
		attention = self.attention(x, x, x, trg_mask)
		query = self.dropout(self.norm(attention + x))
		out = self.transformer_block(value, key, query, src_mask)
		return out

class Decoder(nn.Module):
	def __init__(
			self,
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length,
	):
		super(Decoder, self).__init__()
		self.device = device
		self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
		self.position_embedding = nn.Embedding(max_length, embed_size)
		self.layers = nn.ModuleList(
			[DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
			for _ in range(num_layers)]
			)
		self.fc_out = nn.Linear(embed_size, trg_vocab_size)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x ,enc_out , src_mask, trg_mask):
		N, seq_length = x.shape
		positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
		x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

		for layer in self.layers:
			x = layer(x, enc_out, enc_out, src_mask, trg_mask)

		out =self.fc_out(x)
		return out


class Transformer(nn.Module):
	def __init__(
			self,
			src_vocab_size,
			trg_vocab_size,
			src_pad_idx,
			trg_pad_idx,
			embed_size = 256,
			num_layers = 6,
			forward_expansion = 4,
			heads = 8,
			dropout = 0,
			device="cuda",
			max_length=100
		):
		super(Transformer, self).__init__()
		self.encoder = Encoder(
			src_vocab_size,
			embed_size,
			num_layers,
			heads,
			device,
			forward_expansion,
			dropout,
			max_length
			)
		self.decoder = Decoder(
			trg_vocab_size,
			embed_size,
			num_layers,
			heads,
			forward_expansion,
			dropout,
			device,
			max_length
			)


		self.src_pad_idx = src_pad_idx
		self.trg_pad_idx = trg_pad_idx
		self.device = device


	def make_src_mask(self, src):
		src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
		# (N, 1, 1, src_len)
		return src_mask.to(self.device)

	def make_trg_mask(self, trg):
		N, trg_len = trg.shape
		trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
			N, 1, trg_len, trg_len
		)
		return trg_mask.to(self.device)

	def forward(self, src, trg):
		src_mask = self.make_src_mask(src)
		trg_mask = self.make_trg_mask(trg)
		enc_src = self.encoder(src, src_mask)
		out = self.decoder(trg, enc_src, src_mask, trg_mask)
		return out


if __name__ == '__main__':
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	print(device)
	x = torch.tensor([[1,5,6,4,3,9,5,2,0],[1,8,7,3,4,5,6,7,2]]).to(device)
	trg = torch.tensor([[1,7,4,3,5,9,2,0],[1,5,6,2,4,7,6,2]]).to(device)

	src_pad_idx = 0
	trg_pad_idx = 0
	src_vocab_size = 10
	trg_vocab_size = 10
	model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
	out = model(x, trg[:, : -1])
	print(out.shape)

转载自CSDN-专业IT技术社区

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

原文链接:https://blog.csdn.net/qq_38406029/article/details/122050257

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--