Transformer架构详解

深入理解自注意力机制与位置编码

Posted by Feng Yu on September 1, 2024

Transformer架构的诞生标志着深度学习进入新纪元。2017年,Google团队在论文《Attention is All You Need》中提出Transformer,彻底颠覆了序列建模的传统范式。它不仅催生了BERT、GPT等里程碑模型,更成为现代大语言模型的基石。


一、核心架构

Encoder块结构:

Input Embedding + Positional Encoding
    ↓
Multi-Head Self-Attention
    ↓
Add & Norm
    ↓
Feed-Forward Network
    ↓
Add & Norm

完整实现

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each sub-layer wrapped in
    dropout, a residual connection and post-layer normalization
    (the original post-LN design from "Attention is All You Need").
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Sub-layers (project-defined modules).
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)

        # One LayerNorm + Dropout pair per sub-layer.
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block to x; `mask` hides positions (e.g. padding)
        from the self-attention."""
        # Sub-layer 1: self-attention, then residual add and normalize.
        attended, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attended))

        # Sub-layer 2: position-wise FFN under the same residual scheme.
        x = self.norm2(x + self.dropout2(self.ffn(x)))

        return x

Decoder块增加交叉注意力:

class TransformerDecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention over the
    encoder output, and a position-wise feed-forward network — each
    sub-layer followed by dropout, a residual connection and
    post-layer normalization.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Attention over the target sequence itself and over the encoder memory,
        # plus the position-wise feed-forward network (project-defined modules).
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)

        # One LayerNorm + Dropout pair per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Run the three sub-layers; `tgt_mask` enforces causality on the
        target, `src_mask` hides source padding in cross-attention."""
        # 1) Masked self-attention: tgt_mask blocks attention to future positions.
        sa_out, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout1(sa_out))

        # 2) Cross-attention: queries from the decoder, keys/values from the encoder.
        ca_out, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout2(ca_out))

        # 3) Position-wise feed-forward network.
        x = self.norm3(x + self.dropout3(self.ffn(x)))

        return x

二、完整Transformer模型

class Transformer(nn.Module):
    """Full encoder-decoder Transformer ("Attention is All You Need").

    Assumes token id 0 is the padding index (see generate_mask).
    Relies on project-defined PositionalEncoding plus the encoder/decoder
    layer classes defined in this file.
    """

    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        d_model=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ff=2048,
        dropout=0.1,
        max_len=5000
    ):
        super().__init__()
        
        # Token embedding tables for source and target vocabularies
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        
        # Positional encoding (injects order information; project-defined module)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        
        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        
        # Final projection from model dimension to target vocabulary logits
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model
    
    def generate_mask(self, src, tgt):
        """Build padding masks (token id 0 = padding) and the causal mask.

        Returns:
            src_mask: [batch, 1, 1, src_len] — True where src is not padding.
            tgt_mask: broadcasts to [batch, 1, tgt_len, tgt_len] — True where
                tgt is not padding AND the key position is not in the future.
        """
        # Encoder mask: hide padding positions
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        # [batch, 1, 1, src_len]
        
        # Decoder mask: padding + causal ("no peeking at future tokens")
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        # [batch, 1, tgt_len, 1]
        
        seq_len = tgt.size(1)
        # Lower-triangular matrix: position i may only attend to positions <= i
        nopeak_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
        nopeak_mask = nopeak_mask.to(tgt.device)
        
        # Broadcasting combines padding and causal constraints
        tgt_mask = tgt_mask & nopeak_mask
        
        return src_mask, tgt_mask
    
    def encode(self, src, src_mask):
        """Embed `src` and run it through the encoder stack."""
        # Embedding scaled by sqrt(d_model) as in the original paper,
        # then positional encoding
        x = self.src_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # Pass through each encoder layer in sequence
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        """Embed `tgt` and run it through the decoder stack, cross-attending
        to `encoder_output`."""
        # Embedding scaled by sqrt(d_model), then positional encoding
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # Pass through each decoder layer in sequence
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        
        return x
    
    def forward(self, src, tgt):
        """
        Args:
            src: [batch, src_len]
            tgt: [batch, tgt_len]
        Returns:
            output: [batch, tgt_len, tgt_vocab_size] raw (unnormalized) logits
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        
        output = self.fc_out(decoder_output)
        return output

三、训练技巧

Label Smoothing

class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The true class receives probability (1 - smoothing); the remaining
    smoothing mass is spread uniformly over the other vocab_size - 1
    classes, so each target row sums to 1.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """
        Args:
            pred: [batch * seq_len, vocab_size] raw logits.
            target: [batch * seq_len] class indices.
        Returns:
            Scalar KL divergence, averaged over the batch dimension.
        """
        num_classes = pred.size(-1)
        off_value = self.smoothing / (num_classes - 1)

        # Smoothed target distribution: uniform off-target mass, with the
        # gold class set to the confidence value.
        smoothed = torch.full_like(pred, off_value)
        smoothed.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)

        log_probs = F.log_softmax(pred, dim=-1)
        return F.kl_div(log_probs, smoothed, reduction='batchmean')

Warmup Learning Rate

class NoamOpt:
    """Noam learning-rate schedule wrapped around an optimizer.

    lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5):
    linear warmup for warmup_steps updates, then inverse-sqrt decay.
    """

    def __init__(self, d_model, warmup_steps, optimizer):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self._step = 0

    def step(self):
        """Advance the schedule, set the new lr on every param group,
        then delegate to the wrapped optimizer's step()."""
        self._step += 1
        scale = self.d_model ** (-0.5)
        warmup_term = self._step * self.warmup_steps ** (-1.5)
        decay_term = self._step ** (-0.5)
        lr = scale * min(decay_term, warmup_term)

        for group in self.optimizer.param_groups:
            group['lr'] = lr

        self.optimizer.step()

四、应用场景

机器翻译

# Training loop for machine translation (teacher forcing).
model = Transformer(src_vocab_size, tgt_vocab_size)
criterion = LabelSmoothingLoss()

for src, tgt in dataloader:
    # Reset gradients from the previous batch — without this, .backward()
    # accumulates gradients across iterations and corrupts every update.
    optimizer.zero_grad()
    output = model(src, tgt[:, :-1])  # decoder input: drop the last token
    loss = criterion(
        output.reshape(-1, output.size(-1)),
        tgt[:, 1:].reshape(-1)  # target shifted left: drop the first (start) token
    )
    loss.backward()
    optimizer.step()

文本生成(Greedy Decoding):

@torch.no_grad()
def generate(model, src, max_len=50):
    """Greedy decoding for a single source sequence (batch size 1 — the
    loop uses .item() on the predicted token, which requires one element).

    Args:
        model: trained Transformer (uses generate_mask/encode/decode/fc_out).
        src: [1, src_len] source token ids.
        max_len: maximum number of decoding steps.
    Returns:
        ys: [1, out_len] generated token ids, beginning with start_token.
    """
    model.eval()

    # Encode the source once; reuse the memory for every decoding step.
    src_mask, _ = model.generate_mask(src, src[:, :1])
    encoder_output = model.encode(src, src_mask)

    # Keep the growing target sequence on the same device as the source.
    # (The original `torch.ones(1, 1).fill_(...)` always lived on CPU and
    # crashed with a device mismatch when the model ran on GPU.)
    ys = torch.full((1, 1), start_token, dtype=torch.long, device=src.device)

    for _ in range(max_len):
        _, tgt_mask = model.generate_mask(src, ys)
        decoder_output = model.decode(ys, encoder_output, src_mask, tgt_mask)

        # Project only the last position's hidden state to vocabulary logits.
        prob = model.fc_out(decoder_output[:, -1])
        next_word = prob.argmax(dim=-1).unsqueeze(0)

        ys = torch.cat([ys, next_word], dim=1)

        if next_word.item() == end_token:
            break

    return ys

五、总结

Transformer的关键创新

  1. 自注意力机制:O(n²)复杂度换取全局视野
  2. 位置编码:巧妙注入顺序信息
  3. 残差连接 + LayerNorm:稳定深层网络训练
  4. 并行化:GPU友好,训练效率高

局限性

  • 序列长度的平方复杂度
  • 对位置编码的依赖
  • 缺乏归纳偏置(如CNN的平移不变性)

后续改进方向

  • Sparse Attention(Longformer)
  • Linear Attention(Performer)
  • 相对位置编码(T5, DeBERTa)

Transformer彻底改变了NLP,并逐渐渗透到CV、多模态、强化学习等领域,堪称深度学习的“万能架构”!🚀


💬 交流与讨论

⚠️ 尚未完成 Giscus 配置。请在 _config.yml 中设置 repo_id 和 category_id 后重新部署,即可启用升级后的评论系统。

配置完成后,评论区将自动支持 Markdown 代码高亮与 LaTeX 数学公式渲染,访客回复会同步到 GitHub Discussions,并具备通知功能。