ALiBi 位置编码


1. 背景:为什么需要位置编码?

原始的 Transformer 模型本身是 置换不变 的。也就是说,它对输入序列中单词的顺序不敏感。如果不提供位置信息,模型会将句子 “猫追老鼠” 和 “老鼠追猫” 视为相同的。

为了解决这个问题,我们需要一种方法将单词的顺序或位置信息注入到模型中。经典的 Transformer 使用了 正弦余弦位置编码,将其与词向量相加。后来,像 T5 这样的模型则使用了相对位置编码,在注意力机制中通过偏置项来体现词与词之间的相对距离。

2. ALiBi 的核心思想

ALiBi 的全称是 Attention with Linear Biases(带线性偏置的注意力)。

它的核心思想非常简单且直观:

“惩罚”注意力分数,惩罚量与两个词之间的距离成线性关系。离得越远的词,受到的惩罚越大。

在输入层面添加任何位置向量,而是直接修改注意力机制的计算过程

3. ALiBi 的工作原理

我们来看一下它是如何实现的。

1. 标准注意力分数计算: 在标准的自注意力中,查询向量(Query)和键向量(Key)之间的分数计算为: $$ \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V $$ 其中 $$ S = \frac{QK^T}{\sqrt{d_k}} $$ 被称为注意力分数矩阵。

2. ALiBi 的修改: ALiBi 在这个分数矩阵 ( S ) 上直接加上一个静态的、非学习的偏置矩阵 ( M )。 $$ \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} + M \right) V $$

3. 偏置矩阵 ( M ) 是如何构造的?

  • ( M ) 不是一个通过训练学习到的参数。
  • 它是一个预先定义好的矩阵,其值只与两个 token 之间的相对距离有关。
  • 假设我们有一个长度为 4 的序列 [A, B, C, D]
  • 在计算 token A 对所有 token 的注意力时,我们需要为 AAABACAD 的注意力分数分别加上一个偏置。
  • 这个偏置是基于 A 与其他 token 的相对位置来确定的。

    • AA 的距离是 0,偏置 = 0
    • AB 的距离是 1,偏置 = -m
    • AC 的距离是 2,偏置 = -2m
    • AD 的距离是 3,偏置 = -3m
  • 这里的 m 是一个与注意力头相关的斜率。对于有 ( n ) 个头的模型,每个头 ( i ) 的 ( m ) 值定义为: $ m_i = \frac{1}{2^{8i/n}} $ 例如,一个 8 头的模型,其头的斜率分别是:$ \frac{1}{2^1}, \frac{1}{2^2}, ..., \frac{1}{2^8} $。

  • 最终,矩阵 ( M ) 看起来像一个“斜坡”,左上角为 0,向右下角方向线性递减(为负值)。

一个简化的 ( M ) 矩阵示例(对于4个token的序列,假设 m=1): $$ M = \begin{matrix} 0 & -\infty & -\infty & -\infty \ -1 & 0 & -\infty & -\infty \ -2 & -1 & 0 & -\infty \ -3 & -2 & -1 & 0 \ \end{matrix} $$

(在实际应用中,-inf 的部分是用于掩盖未来信息的因果掩码,ALiBi 的偏置是加在非 -inf 的位置上。)

4. ALiBi 的优势

  1. 强大的外推能力
  2. 这是 ALiBi 最突出的优点。在训练时使用较短的序列(如 1024 个 token),模型无需任何微调就能在推理时处理长得多的序列(如 16K+ 个 token)。
  3. 因为它的偏置是线性的。模型在训练时学会了“距离越远,注意力越少”的线性关系。当序列变长时,这个线性规则可以自然地推广到更远的距离。而正弦编码等绝对位置编码在遇到训练时未见过的位置时,性能会急剧下降。

  4. 无需位置嵌入参数

  5. ALiBi 不向模型中添加任何可学习的参数(偏置矩阵 ( M ) 是静态计算的)。这使得模型更小,训练更稳定。

  6. 训练速度更快

  7. 由于不需要计算位置向量的内积,并且在推理时可以缓存 ( M ) 矩阵,ALiBi 通常能带来轻微的训练和推理速度提升。

  8. 在多种任务上表现优异

  9. 论文中的实验表明,ALiBi 在语言建模、机器翻译等任务上,性能不逊于甚至优于其他位置编码方法,同时具备了无与伦比的外推能力。

5. 总结与比喻

你可以把 ALiBi 想象成一个内置的、基于距离的注意力衰减机制

  • 标准注意力:模型需要自己去学习“距离近的词可能更相关”这个模式。
  • ALiBi:我们直接“告诉”模型这个模式——“对于每个查询,你更应该关注它附近的词,并且关注度随着距离线性下降”。我们通过一个固定的、线性的惩罚项来强制模型遵循这个模式。

这种直接的“硬编码”先验知识,反而让模型具备了强大的泛化能力,尤其是在处理长序列时。正因为如此,ALiBi 被广泛应用于需要超长上下文的大型语言模型中,例如 MosaicML 的 MPT 系列模型 就采用了 ALiBi。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

class ALiBiPositionalBias(nn.Module):
    """
    ALiBi (Attention with Linear Biases) 位置编码实现
    """
    def __init__(self, num_heads, max_positions=512):
        super().__init__()
        self.num_heads = num_heads
        self.max_positions = max_positions

        # 为每个头生成斜率 (slope)
        slopes = torch.Tensor(self._get_slopes(num_heads))
        self.register_buffer('slopes', slopes)

        # 预计算偏置矩阵
        self.register_buffer('bias_matrix', self._precompute_bias_matrix())

    def _get_slopes(self, n):
        """
        为 n 个头生成斜率,按照论文中的公式
        """
        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            # 最接近的2的幂
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (get_slopes_power_of_2(closest_power_of_2) + 
                    self._get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])

    def _precompute_bias_matrix(self):
        """
        预计算偏置矩阵
        """
        # 创建相对位置矩阵
        context_position = torch.arange(self.max_positions).unsqueeze(0)  # [1, seq_len]
        memory_position = torch.arange(self.max_positions).unsqueeze(1)  # [seq_len, 1]

        # 相对位置 = 目标位置 - 源位置
        relative_position = memory_position - context_position  # [seq_len, seq_len]
        relative_position = torch.abs(relative_position)

        # 为每个头创建偏置矩阵 [num_heads, seq_len, seq_len]
        bias_matrix = torch.zeros(self.num_heads, self.max_positions, self.max_positions)

        for h in range(self.num_heads):
            # 每个头使用自己的斜率
            bias_matrix[h] = -relative_position * self.slopes[h]

        return bias_matrix

    def forward(self, seq_len):
        """
        返回对应序列长度的偏置矩阵
        Args:
            seq_len: 序列长度
        Returns:
            bias: [num_heads, seq_len, seq_len] 的偏置矩阵
        """
        if seq_len <= self.max_positions:
            return self.bias_matrix[:, :seq_len, :seq_len]
        else:
            # 如果序列长度超过预计算的最大长度,动态计算
            context_position = torch.arange(seq_len).unsqueeze(0)
            memory_position = torch.arange(seq_len).unsqueeze(1)
            relative_position = memory_position - context_position
            relative_position = torch.abs(relative_position)

            bias_matrix = torch.zeros(self.num_heads, seq_len, seq_len)
            for h in range(self.num_heads):
                bias_matrix[h] = -relative_position * self.slopes[h]

            return bias_matrix.to(self.bias_matrix.device)

class ALiBiAttention(nn.Module):
    """
    集成 ALiBi 的自注意力层
    """
    def __init__(self, d_model, num_heads, max_positions=512):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"

        # 线性变换层
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)

        # ALiBi 偏置
        self.alibi = ALiBiPositionalBias(num_heads, max_positions)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: 可选的注意力掩码 [batch_size, seq_len, seq_len]
        Returns:
            output: [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.shape

        # 线性变换
        Q = self.wq(x)  # [batch_size, seq_len, d_model]
        K = self.wk(x)  # [batch_size, seq_len, d_model]
        V = self.wv(x)  # [batch_size, seq_len, d_model]

        # 重塑为多头
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # scores: [batch_size, num_heads, seq_len, seq_len]

        # 添加 ALiBi 偏置
        alibi_bias = self.alibi(seq_len)  # [num_heads, seq_len, seq_len]
        scores = scores + alibi_bias.unsqueeze(0)  # 添加batch维度

        # 应用可选的注意力掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # 计算注意力权重
        attention_weights = F.softmax(scores, dim=-1)

        # 应用注意力权重到值向量
        output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, seq_len, head_dim]

        # 重塑回原始形状
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        # 最终线性变换
        output = self.wo(output)

        return output, attention_weights

def demo_alibi_attention():
    """演示 ALiBi 注意力的使用"""
    print("=== ALiBi 位置编码 Demo ===\n")

    # 设置参数
    d_model = 64
    num_heads = 4
    batch_size = 2
    seq_len = 8

    # 创建模型和随机输入
    attention_layer = ALiBiAttention(d_model, num_heads)
    x = torch.randn(batch_size, seq_len, d_model)

    print(f"输入形状: {x.shape}")
    print(f"模型参数: d_model={d_model}, num_heads={num_heads}")

    # 前向传播
    output, attn_weights = attention_layer(x)
    print(f"输出形状: {output.shape}")
    print(f"注意力权重形状: {attn_weights.shape}")

    return attention_layer, x, output, attn_weights

def visualize_alibi_bias():
    """可视化 ALiBi 偏置矩阵"""
    print("\n=== 可视化 ALiBi 偏置矩阵 ===")

    num_heads = 4
    seq_len = 10
    alibi = ALiBiPositionalBias(num_heads)

    # 获取偏置矩阵
    bias_matrix = alibi(seq_len)

    # 打印斜率
    print(f"各头的斜率: {alibi.slopes.tolist()}")

    # 可视化每个头的偏置矩阵
    fig, axes = plt.subplots(1, num_heads, figsize=(16, 4))

    for i in range(num_heads):
        im = axes[i].imshow(bias_matrix[i].detach().numpy(), cmap='coolwarm', aspect='auto')
        axes[i].set_title(f'头 {i}\n斜率: {alibi.slopes[i]:.4f}')
        axes[i].set_xlabel('目标位置')
        axes[i].set_ylabel('源位置')
        plt.colorbar(im, ax=axes[i])

    plt.tight_layout()
    plt.savefig('alibi_bias_matrix.png', dpi=300, bbox_inches='tight')
    print("偏置矩阵可视化已保存为 'alibi_bias_matrix.png'")

    return bias_matrix

def test_extrapolation():
    """测试 ALiBi 的外推能力"""
    print("\n=== 测试外推能力 ===")

    # 训练时的序列长度
    train_seq_len = 8
    # 推理时的序列长度(比训练时长)
    eval_seq_len = 16

    num_heads = 4
    alibi = ALiBiPositionalBias(num_heads, max_positions=eval_seq_len)

    # 获取训练和推理时的偏置矩阵
    train_bias = alibi(train_seq_len)
    eval_bias = alibi(eval_seq_len)

    print(f"训练序列长度: {train_seq_len}")
    print(f"推理序列长度: {eval_seq_len}")
    print(f"偏置矩阵可以无缝扩展到更长的序列")

    # 检查训练部分的偏置是否一致
    train_part_from_eval = eval_bias[:, :train_seq_len, :train_seq_len]
    is_consistent = torch.allclose(train_bias, train_part_from_eval, atol=1e-6)

    print(f"训练部分在扩展后保持一致: {is_consistent}")

    return train_bias, eval_bias

if __name__ == "__main__":
    # 设置随机种子以便复现
    torch.manual_seed(42)

    # 运行演示
    attention_layer, x, output, attn_weights = demo_alibi_attention()

    # 可视化偏置矩阵
    bias_matrix = visualize_alibi_bias()

    # 测试外推能力
    train_bias, eval_bias = test_extrapolation()

    print("\n=== Demo 完成 ===")
    print("这个 Demo 展示了:")
    print("1. ALiBi 位置编码的实现")
    print("2. 如何集成到自注意力机制中")
    print("3. 偏置矩阵的可视化")
    print("4. ALiBi 强大的外推能力")