- 核心改进
计算顺序优化:从 (Q·K^T)·V 改为 Q·(K^T·V),避免显式计算注意力矩阵
复杂度降低:从 O(n²d) 降到 O(nd²),当序列长度 n > 特征维度 d 时更高效
核函数:使用 elu(x) + 1 保证非负性,替代 softmax
- 关键组件
核函数:将 Q、K 映射到非负空间
矩阵运算:先计算 K^T·V,再与 Q 相乘
归一化:除以注意力权重和,保持稳定性
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class LinearAttention(nn.Module):
"""
线性注意力实现
时间复杂度: O(n*d^2) 而非标准注意力的 O(n^2*d)
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, eps=1e-6):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.eps = eps
# Q, K, V 线性变换
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def kernel_function(self, x):
"""
核函数:使用 elu + 1 来保证非负性
也可以使用其他核函数如 softmax, relu 等
"""
return F.elu(x) + 1
def forward(self, x):
"""
Args:
x: 输入张量 [batch_size, seq_len, dim]
Returns:
输出张量 [batch_size, seq_len, dim]
"""
B, N, C = x.shape
# 计算 Q, K, V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, num_heads, N, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2]
# 应用核函数
q = self.kernel_function(q) # [B, num_heads, N, head_dim]
k = self.kernel_function(k) # [B, num_heads, N, head_dim]
# 线性注意力计算
# 关键:改变计算顺序 (Q * (K^T * V)) 而非 ((Q * K^T) * V)
# 先计算 K^T * V: [B, num_heads, head_dim, head_dim]
kv = torch.einsum('bhnd,bhnm->bhdm', k, v)
# 归一化项: K 的列和 [B, num_heads, head_dim]
k_sum = k.sum(dim=2) # [B, num_heads, head_dim]
# 计算输出: Q * (K^T * V)
out = torch.einsum('bhnd,bhdm->bhnm', q, kv)
# 归一化
normalizer = torch.einsum('bhnd,bhd->bhn', q, k_sum)
normalizer = normalizer.unsqueeze(-1).clamp(min=self.eps)
out = out / normalizer
# 合并多头
out = out.transpose(1, 2).reshape(B, N, C)
# 输出投影
out = self.proj(out)
return out
# 示例使用
if __name__ == "__main__":
# 设置参数
batch_size = 2
seq_len = 100
dim = 256
num_heads = 8
# 创建模型
model = LinearAttention(dim=dim, num_heads=num_heads)
# 创建随机输入
x = torch.randn(batch_size, seq_len, dim)
print(f"输入形状: {x.shape}")
# 前向传播
output = model(x)
print(f"输出形状: {output.shape}")
print(f"输出统计:")
print(f" 均值: {output.mean().item():.4f}")
print(f" 标准差: {output.std().item():.4f}")
print(f" 最小值: {output.min().item():.4f}")
print(f" 最大值: {output.max().item():.4f}")
# 对比标准注意力的计算复杂度
print(f"\n计算复杂度对比:")
print(f" 标准注意力: O(n²d) = O({seq_len}² × {dim}) = {seq_len**2 * dim:,}")
print(f" 线性注意力: O(nd²) = O({seq_len} × {dim}²) = {seq_len * dim**2:,}")
# 当序列长度较长时,线性注意力更高效
if seq_len > dim:
print(f" 线性注意力节省: {(seq_len**2 * dim) / (seq_len * dim**2):.2f}x")
# 测试梯度
loss = output.sum()
loss.backward()
print(f"\n梯度检查通过 ✓")
print(f"参数梯度存在: {model.qkv.weight.grad is not None}")