# MoE (Mixture-of-Experts) algorithm


from torch import nn
import torch.nn.functional as F
import torch
import math

class MoELayer(nn.Module):
    """Mixture-of-Experts feed-forward layer with sigmoid top-k routing.

    Each token is scored against every expert by a linear router; the top-k
    experts per token are applied (SwiGLU-style MLPs stored as stacked
    weight tensors), and their outputs are combined as a weighted average
    using sigmoid gate values.
    """

    def __init__(self, d_model: int, num_experts: int, num_experts_per_tok: int, hidden_size: int):
        """
        Args:
            d_model: token embedding dimension.
            num_experts: total number of experts.
            num_experts_per_tok: number of experts (top-k) each token is routed to.
            hidden_size: hidden dimension of each expert's MLP.
        """
        super().__init__()

        # Router: produces one affinity logit per expert for every token.
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok

        # Expert weights stacked along dim 0 so the selected experts can be
        # gathered and applied with a single batched matmul (no Python loop).
        # gate_up_proj packs the gate and up projections of a SwiGLU MLP.
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, d_model, 2 * hidden_size) * 0.02)
        self.down_proj = nn.Parameter(torch.randn(num_experts, hidden_size, d_model) * 0.02)

        # SiLU nonlinearity for the gated activation.
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MoE layer.

        Args:
            x: input of shape (batch, seq_len, d_model).
        Returns:
            Tensor of the same shape and dtype as ``x``.
        """
        B, T, D = x.size()
        x_flat = x.view(B * T, D)

        # Route: per-token scores for every expert, keep the top-k.
        logits = self.router(x)                                   # (B, T, num_experts)
        topk_logits, topk_indices = torch.topk(logits, self.num_experts_per_tok, dim=-1)

        # Sigmoid gating; flattening keeps the (token, slot) order aligned
        # with token_idx below, since topk preserves the leading dims.
        routing_weights = torch.sigmoid(topk_logits).view(-1)     # (B*T*k,)

        # Each token index repeated once per selected expert slot.
        token_idx = torch.arange(B * T, device=x.device).repeat_interleave(self.num_experts_per_tok)
        selected_experts = topk_indices.view(-1)                  # (B*T*k,)

        # Gather inputs and the chosen experts' weights for batched matmuls.
        expert_inputs = x_flat[token_idx]                         # (N, D), N = B*T*k
        gate_up_w = self.gate_up_proj[selected_experts]           # (N, D, 2H)
        down_w = self.down_proj[selected_experts]                 # (N, H, D)

        gate_up = torch.bmm(expert_inputs.unsqueeze(1), gate_up_w).squeeze(1)
        gate, up = gate_up.chunk(2, dim=-1)

        # SwiGLU: silu(gate) * up, then down-project and scale by the gate weight.
        act = self.activation(gate) * up
        out = torch.bmm(act.unsqueeze(1), down_w).squeeze(1) * routing_weights.unsqueeze(-1)

        # Scatter-add the weighted expert outputs back to their source tokens.
        combined = torch.zeros_like(x_flat)
        # BUGFIX: allocate with x's dtype. The original defaulted to float32,
        # which silently up-cast the layer output under mixed precision.
        combined_weight = torch.zeros(B * T, 1, device=x.device, dtype=x.dtype)

        combined.scatter_add_(0, token_idx.unsqueeze(-1).expand(-1, D), out)
        combined_weight.scatter_add_(0, token_idx.unsqueeze(-1), routing_weights.unsqueeze(-1))

        # Normalize to a weighted average; eps guards against division by ~0.
        combined = combined / (combined_weight + 1e-6)

        return combined.view(B, T, D)


from dataclasses import dataclass

@dataclass
class ModelConfig:
    """Hyper-parameters for the Transformer + MoE demo model."""
    vocab_size: int = 665           # vocabulary size
    d_model: int = 256              # Transformer hidden dimension
    n_layers: int = 6               # number of Transformer layers
    n_heads: int = 4                # number of attention heads
    block_size: int = 16             # maximum sequence length
    rope_theta: float = 10000.0     # RoPE rotary position-encoding base frequency
    num_experts: int = 8            # total number of MoE experts
    experts_per_tok: int = 2        # experts each token is routed to (top-k)
    expert_hidden_size: int = 1024  # hidden dimension of each routed expert
    shared_hidden_size: int = 1024  # hidden dimension of the shared expert


def _demo() -> None:
    """Smoke test: push one random batch through a standalone MoE layer."""
    cfg = ModelConfig()
    moe = MoELayer(cfg.d_model, cfg.num_experts, cfg.experts_per_tok, cfg.expert_hidden_size)

    # Dummy (batch, seq_len, d_model) input. The MoE layer itself imposes no
    # sequence-length limit, so seq_len > cfg.block_size is fine here.
    x = torch.randn(8, 128, cfg.d_model)
    out = moe(x)
    print(out)


# Guarded so importing this module no longer builds the layer and runs a
# forward pass as a side effect (the original executed at import time).
if __name__ == "__main__":
    _demo()