手撕-线性注意力


  • 核心改进

计算顺序优化:从 (Q·K^T)·V 改为 Q·(K^T·V),避免显式计算注意力矩阵

复杂度降低:从 O(n²d) 降到 O(nd²),当序列长度 n > 特征维度 d 时更高效

核函数:使用 elu(x) + 1 保证非负性,替代 softmax

  • 关键组件

核函数:将 Q、K 映射到非负空间

矩阵运算:先计算 K^T·V,再与 Q 相乘

归一化:除以注意力权重和,保持稳定性

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class LinearAttention(nn.Module):
    """
    线性注意力实现
    时间复杂度: O(n*d^2) 而非标准注意力的 O(n^2*d)
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.eps = eps

        # Q, K, V 线性变换
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def kernel_function(self, x):
        """
        核函数:使用 elu + 1 来保证非负性
        也可以使用其他核函数如 softmax, relu 等
        """
        return F.elu(x) + 1

    def forward(self, x):
        """
        Args:
            x: 输入张量 [batch_size, seq_len, dim]
        Returns:
            输出张量 [batch_size, seq_len, dim]
        """
        B, N, C = x.shape

        # 计算 Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 应用核函数
        q = self.kernel_function(q)  # [B, num_heads, N, head_dim]
        k = self.kernel_function(k)  # [B, num_heads, N, head_dim]

        # 线性注意力计算
        # 关键:改变计算顺序 (Q * (K^T * V)) 而非 ((Q * K^T) * V)
        # 先计算 K^T * V: [B, num_heads, head_dim, head_dim]
        kv = torch.einsum('bhnd,bhnm->bhdm', k, v)

        # 归一化项: K 的列和 [B, num_heads, head_dim]
        k_sum = k.sum(dim=2)  # [B, num_heads, head_dim]

        # 计算输出: Q * (K^T * V)
        out = torch.einsum('bhnd,bhdm->bhnm', q, kv)

        # 归一化
        normalizer = torch.einsum('bhnd,bhd->bhn', q, k_sum)
        normalizer = normalizer.unsqueeze(-1).clamp(min=self.eps)
        out = out / normalizer

        # 合并多头
        out = out.transpose(1, 2).reshape(B, N, C)

        # 输出投影
        out = self.proj(out)

        return out


# 示例使用
if __name__ == "__main__":
    # 设置参数
    batch_size = 2
    seq_len = 100
    dim = 256
    num_heads = 8

    # 创建模型
    model = LinearAttention(dim=dim, num_heads=num_heads)

    # 创建随机输入
    x = torch.randn(batch_size, seq_len, dim)

    print(f"输入形状: {x.shape}")

    # 前向传播
    output = model(x)

    print(f"输出形状: {output.shape}")
    print(f"输出统计:")
    print(f"  均值: {output.mean().item():.4f}")
    print(f"  标准差: {output.std().item():.4f}")
    print(f"  最小值: {output.min().item():.4f}")
    print(f"  最大值: {output.max().item():.4f}")

    # 对比标准注意力的计算复杂度
    print(f"\n计算复杂度对比:")
    print(f"  标准注意力: O(n²d) = O({seq_len}² × {dim}) = {seq_len**2 * dim:,}")
    print(f"  线性注意力: O(nd²) = O({seq_len} × {dim}²) = {seq_len * dim**2:,}")

    # 当序列长度较长时,线性注意力更高效
    if seq_len > dim:
        print(f"  线性注意力节省: {(seq_len**2 * dim) / (seq_len * dim**2):.2f}x")

    # 测试梯度
    loss = output.sum()
    loss.backward()
    print(f"\n梯度检查通过 ✓")
    print(f"参数梯度存在: {model.qkv.weight.grad is not None}")