Files changed (1)
  1. modeling_kimi_vl.py +69 -3
modeling_kimi_vl.py CHANGED
@@ -904,6 +904,7 @@ class MoEGate(nn.Module):
         self.n_routed_experts = config.n_routed_experts
         self.routed_scaling_factor = config.routed_scaling_factor
         self.scoring_func = config.scoring_func
+        self.alpha = config.aux_loss_alpha
         self.seq_aux = config.seq_aux
         self.topk_method = config.topk_method
         self.n_group = config.n_group
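Note: `aux_loss_alpha` is only read here; as the later hunks show, the auxiliary loss is computed only when the module is in training mode and `self.alpha > 0.0`, otherwise the gate simply returns `aux_loss = None`. A tiny, purely illustrative sketch of the knobs involved (the values are hypothetical; only the attribute names come from this file):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the model config; attribute names match the ones
# read in MoEGate.__init__, the values are arbitrary illustrations.
config = SimpleNamespace(aux_loss_alpha=0.001, seq_aux=True)
print(config.aux_loss_alpha > 0.0)  # True -> the gate will produce an aux loss while training
```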
@@ -970,6 +971,10 @@ class MoEGate(nn.Module):
             )  # [n, e]
             _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
             topk_weight = scores.gather(1, topk_idx)
+        elif self.topk_method == "greedy":
+            topk_weight, topk_idx = torch.topk(
+                scores, k=self.top_k, dim=-1, sorted=False
+            )
         else:
             raise NotImplementedError(
                 f"insupportable TopK function for MoE gating: {self.topk_method}"
@@ -983,7 +988,57 @@ class MoEGate(nn.Module):
                 topk_weight * self.routed_scaling_factor
             )  # must multiply the scaling factor
 
-        return topk_idx, topk_weight
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(
+                    bsz, self.n_routed_experts, device=hidden_states.device
+                )
+                ce.scatter_add_(
+                    1,
+                    topk_idx_for_aux_loss,
+                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
+                    dim=1
+                ).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(
+                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
+                )
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+
+        return topk_idx, topk_weight, aux_loss
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    """
+    The trick function of adding auxiliary (aux) loss,
+    which includes the gradient of the aux loss during backpropagation.
+    """
+
+    @staticmethod
+    def forward(ctx, x, loss):
+        assert loss.numel() == 1
+        ctx.dtype = loss.dtype
+        ctx.required_aux_loss = loss.requires_grad
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_loss = None
+        if ctx.required_aux_loss:
+            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
+        return grad_output, grad_loss
 
 
 class DeepseekV3MoE(nn.Module):
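Two things land in this hunk: the gate now also returns a load-balancing auxiliary loss (a per-sequence variant when `seq_aux` is set, otherwise a global one), and `AddAuxiliaryLoss` splices that loss into backpropagation by returning `x` unchanged in the forward pass while handing the loss a gradient of 1, as if it had been added to the training objective. A self-contained sketch of the non-`seq_aux` branch plus the gradient trick, with toy shapes and a hypothetical `alpha` (the autograd class is copied from the hunk above so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F


class AddAuxiliaryLoss(torch.autograd.Function):
    # Copied from the hunk above: forward returns x unchanged; backward hands the
    # aux loss a gradient of 1, as if it had been added to the final training loss.
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss


# Toy routing problem: 6 tokens, 8 experts, top-2; alpha is a hypothetical value.
n_tokens, n_experts, top_k, alpha = 6, 8, 2, 0.001
logits = torch.randn(n_tokens, n_experts, requires_grad=True)
scores = logits.softmax(dim=-1)
_, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

# Non-seq_aux branch from the hunk above: fi is the (scaled) fraction of routed
# slots each expert received, Pi the mean gate probability per expert; their dot
# product grows when routing mass concentrates on a few experts.
mask_ce = F.one_hot(topk_idx.view(-1), num_classes=n_experts)
ce = mask_ce.float().mean(0)
Pi = scores.mean(0)
fi = ce * n_experts
aux_loss = (Pi * fi).sum() * alpha

# Attach the loss to the expert output: y's value is untouched, but backward now
# also pushes the balancing gradient into the gate logits.
y = torch.randn(n_tokens, 16)
y = AddAuxiliaryLoss.apply(y, aux_loss)
y.sum().backward()
print(logits.grad.abs().sum() > 0)  # tensor(True): the balancing gradient reached the gate
```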
@@ -1036,9 +1091,20 @@ class DeepseekV3MoE(nn.Module):
     def forward(self, hidden_states):
         identity = hidden_states
         orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if not self.training:
+        if self.training:
+            flat_topk_idx = topk_idx.view(-1)
+            hidden_states = hidden_states.repeat_interleave(
+                self.num_experts_per_tok, dim=0
+            )
+            y = torch.empty_like(hidden_states)
+            for i, expert in enumerate(self.experts):
+                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+            y = y.to(hidden_states.dtype).view(*orig_shape)
+            y = AddAuxiliaryLoss.apply(y, aux_loss)
+        else:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)
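The reworked forward keeps `moe_infer` for inference but, in training mode, routes tokens through the experts with a dense dispatch so every selected expert sees its tokens inside the autograd graph, and finally attaches the gate's aux loss via `AddAuxiliaryLoss`. A toy re-enactment of just the dispatch/combine part (linear layers stand in for the real experts; all names and sizes are illustrative):

```python
import torch
import torch.nn as nn

n_tokens, hidden, n_experts, top_k = 5, 16, 4, 2

# Stand-ins for self.experts and the gate outputs of the real module.
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))
hidden_states = torch.randn(n_tokens, hidden)
scores = torch.softmax(torch.randn(n_tokens, n_experts), dim=-1)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

# Training-path dispatch from the hunk above: duplicate each token once per
# selected expert, run each expert on its own slice, then combine by gate weight.
flat_topk_idx = topk_idx.view(-1)                  # [n_tokens * top_k]
x = hidden_states.repeat_interleave(top_k, dim=0)  # [n_tokens * top_k, hidden]
y = torch.empty_like(x)
for i, expert in enumerate(experts):
    y[flat_topk_idx == i] = expert(x[flat_topk_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(y.shape)  # torch.Size([5, 16]): one combined output per token
```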
 