support training
#15 · opened by zhouzaida
modeling_kimi_vl.py  +69 -3

modeling_kimi_vl.py  CHANGED
@@ -904,6 +904,7 @@ class MoEGate(nn.Module):
         self.n_routed_experts = config.n_routed_experts
         self.routed_scaling_factor = config.routed_scaling_factor
         self.scoring_func = config.scoring_func
+        self.alpha = config.aux_loss_alpha
         self.seq_aux = config.seq_aux
         self.topk_method = config.topk_method
         self.n_group = config.n_group
@@ -970,6 +971,10 @@ class MoEGate(nn.Module):
             ) # [n, e]
             _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
             topk_weight = scores.gather(1, topk_idx)
+        elif self.topk_method == "greedy":
+            topk_weight, topk_idx = torch.topk(
+                scores, k=self.top_k, dim=-1, sorted=False
+            )
         else:
             raise NotImplementedError(
                 f"insupportable TopK function for MoE gating: {self.topk_method}"
@@ -983,7 +988,57 @@ class MoEGate(nn.Module):
             topk_weight * self.routed_scaling_factor
         ) # must multiply the scaling factor
 
-        return topk_idx, topk_weight
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(
+                    bsz, self.n_routed_experts, device=hidden_states.device
+                )
+                ce.scatter_add_(
+                    1,
+                    topk_idx_for_aux_loss,
+                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
+                    dim=1
+                ).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(
+                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
+                )
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+
+        return topk_idx, topk_weight, aux_loss
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    """
+    The trick function of adding auxiliary (aux) loss,
+    which includes the gradient of the aux loss during backpropagation.
+    """
+
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
 
 
 class DeepseekV3MoE(nn.Module):
@@ -1036,9 +1091,20 @@ class DeepseekV3MoE(nn.Module):
     def forward(self, hidden_states):
         identity = hidden_states
         orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if not self.training:
+        if self.training:
+            flat_topk_idx = topk_idx.view(-1)
+            hidden_states = hidden_states.repeat_interleave(
+                self.num_experts_per_tok, dim=0
+            )
+            y = torch.empty_like(hidden_states)
+            for i, expert in enumerate(self.experts):
+                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+            y = y.to(hidden_states.dtype).view(*orig_shape)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
+        else:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)
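A note on what the new gate branch computes, as far as I can tell from the added code (not wording from the PR): with T = bsz * seq_len routed tokens, N = n_routed_experts, K = top_k, and s_{t,i} the softmax gate score of expert i for token t, the token-level branch (seq_aux = False) is the standard load-balancing objective

\[
\mathcal{L}_{\mathrm{aux}} = \alpha \sum_{i=1}^{N} P_i \, f_i,
\qquad
f_i = \frac{N}{K\,T} \sum_{t=1}^{T} \mathbb{1}\!\left[\, i \in \mathrm{TopK}(t) \,\right],
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} s_{t,i},
\]

where f_i corresponds to `fi = ce * self.n_routed_experts`, P_i to `Pi = scores_for_aux.mean(0)`, and alpha to `config.aux_loss_alpha`. The seq_aux branch computes the same product per sequence (the scatter_add_ builds per-sequence expert counts, normalized by seq_len * top_k / n_routed_experts) and then averages it over the batch.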
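AddAuxiliaryLoss keeps the layer output numerically unchanged while still letting aux_loss contribute gradients to the gate parameters. Below is a minimal sketch of that behaviour, assuming the AddAuxiliaryLoss class from this diff is in scope; the tensors are toy stand-ins, not model code:

import torch

# Toy stand-ins: router logits that require grad, and a scalar "aux loss"
# derived from them. In the model, aux_loss comes out of MoEGate.forward.
gate_logits = torch.randn(4, 8, requires_grad=True)
aux_loss = gate_logits.softmax(dim=-1).mean() * 0.001

# Stand-in for the combined expert output y.
y = torch.randn(4, 16, requires_grad=True)

out = AddAuxiliaryLoss.apply(y, aux_loss)
print(torch.equal(out, y))            # True: the forward value is untouched

out.sum().backward()                  # pretend this is the main LM loss
print(y.grad.abs().sum() > 0)         # tensor(True): the usual gradient still flows
print(gate_logits.grad is not None)   # True: the aux-loss gradient reaches the gate

Because forward returns x as-is and backward fabricates a unit gradient for loss, the auxiliary term never changes the logits, only the gradients.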
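For completeness, the training branch of DeepseekV3MoE.forward dispatches every token to each of its top-k experts densely (one duplicated row per (token, expert) slot) instead of going through moe_infer. Here is a standalone sketch of that dispatch pattern; the variable names mirror the diff, but the shapes, the nn.Linear experts, and the random routing are illustrative stand-ins:

import torch
import torch.nn as nn

num_tokens, hidden, n_experts, top_k = 6, 8, 4, 2
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))

hidden_states = torch.randn(num_tokens, hidden)      # flattened tokens
topk_weight = torch.rand(num_tokens, top_k)          # gate weights per token
topk_idx = torch.stack(                              # chosen experts per token
    [torch.randperm(n_experts)[:top_k] for _ in range(num_tokens)]
)

# One duplicated row per (token, expert-slot) pair, ordered like topk_idx.view(-1).
flat_topk_idx = topk_idx.view(-1)
x = hidden_states.repeat_interleave(top_k, dim=0)

# Run each expert on the rows routed to it, then recombine with the gate weights.
y = torch.empty_like(x)
for i, expert in enumerate(experts):
    y[flat_topk_idx == i] = expert(x[flat_topk_idx == i])
y = (y.view(num_tokens, top_k, hidden) * topk_weight.unsqueeze(-1)).sum(dim=1)

print(y.shape)  # torch.Size([6, 8])

This mirrors the `if self.training:` path above; the dense duplication trades memory for a simple, fully differentiable routing loop.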