from torch import nn
import torch.nn.functional as F
import torch
import math
class MoELayer(nn.Module):
    """Mixture-of-Experts feed-forward layer with sigmoid top-k routing.

    Each token is scored against ``num_experts`` experts by a linear router,
    routed to its top ``num_experts_per_tok`` experts, and the expert outputs
    are combined as a weighted average of the sigmoid routing scores.
    Experts use a fused gate/up projection with a SwiGLU-style activation.
    """

    def __init__(self, d_model: int, num_experts: int, num_experts_per_tok: int, hidden_size: int):
        super().__init__()
        # Router: scores each token against every expert.
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        # Per-expert projection weights; gate and up projections are fused
        # into one (num_experts, d_model, 2*hidden_size) matrix.
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, d_model, 2 * hidden_size) * 0.02)
        self.down_proj = nn.Parameter(torch.randn(num_experts, hidden_size, d_model) * 0.02)
        # Activation applied to the gate half.
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` of shape (B, T, D) through top-k experts.

        Returns a tensor of the same shape (B, T, D).
        """
        B, T, D = x.size()
        x_flat = x.view(B * T, D)
        # Routing scores per token: (B*T, num_experts).
        # (Scoring x_flat instead of x is equivalent after the flatten below.)
        logits = self.router(x_flat)
        topk_logits, topk_indices = torch.topk(logits, self.num_experts_per_tok, dim=-1)
        # Sigmoid gate per selected expert; normalized later by their sum.
        routing_weights = torch.sigmoid(topk_logits).view(-1)
        # Repeat each token index once per selected expert so flattened
        # (token, expert) pairs line up with routing_weights.
        token_idx = torch.arange(B * T, device=x.device).repeat_interleave(self.num_experts_per_tok)
        selected_experts = topk_indices.view(-1)
        # Gather per-(token, expert) inputs and expert weight matrices.
        expert_inputs = x_flat[token_idx]
        gate_up_w = self.gate_up_proj[selected_experts]
        down_w = self.down_proj[selected_experts]
        # Fused gate/up projection, SwiGLU-style activation, down projection,
        # then scale by the routing weight of each (token, expert) pair.
        gate_up = torch.bmm(expert_inputs.unsqueeze(1), gate_up_w).squeeze(1)
        gate, up = gate_up.chunk(2, dim=-1)
        act = self.activation(gate) * up
        out = torch.bmm(act.unsqueeze(1), down_w).squeeze(1) * routing_weights.unsqueeze(-1)
        # Scatter-add expert outputs back per token, then normalize by the
        # summed routing weights (weighted average over the k experts).
        combined = torch.zeros_like(x_flat)
        # BUGFIX: accumulator must match x's dtype, otherwise scatter_add_
        # raises a dtype mismatch under fp16/bf16 (weights inherit x's dtype).
        combined_weight = torch.zeros(B * T, 1, device=x.device, dtype=x.dtype)
        combined.scatter_add_(0, token_idx.unsqueeze(-1).expand(-1, D), out)
        combined_weight.scatter_add_(0, token_idx.unsqueeze(-1), routing_weights.unsqueeze(-1))
        combined = combined / (combined_weight + 1e-6)
        return combined.view(B, T, D)
from dataclasses import dataclass
@dataclass
class ModelConfig:
    """Hyper-parameters for the MoE transformer demo."""

    vocab_size: int = 665            # vocabulary size
    d_model: int = 256               # transformer hidden dimension
    n_layers: int = 6                # number of transformer layers
    n_heads: int = 4                 # number of attention heads
    block_size: int = 16             # maximum sequence length
    rope_theta: float = 10000.0      # RoPE rotary-embedding frequency base
    num_experts: int = 8             # total number of MoE experts
    experts_per_tok: int = 2         # experts routed per token (top-k)
    expert_hidden_size: int = 1024   # hidden size of each expert
    shared_hidden_size: int = 1024   # hidden size of the shared expert
# Smoke test: push one random batch through a standalone MoE layer and dump it.
model_cfg = ModelConfig()
moe = MoELayer(
    model_cfg.d_model,
    model_cfg.num_experts,
    model_cfg.experts_per_tok,
    model_cfg.expert_hidden_size,
)
x = torch.randn(8, 128, 256)
x = moe(x)
print(x)
# Source article: "MOE-算法" (MoE algorithm) — 33 views (page metadata)