はじめに
最近リリースされる大規模言語モデル(LLM)では、Mixture of Experts(MoE)と呼ばれる構造がよく採用されています。
MoEは、モデル内部に複数のネットワーク、すなわちExpertを持ち、入力トークンごとにその一部のExpertだけを選択して計算する仕組みです。これにより、モデル全体のパラメータ数を大きく保ちながら、推論時に実際に使う計算量を抑えることができます。
本記事では、Qwen3.5のMoEモデルを題材に、MoEがソースコード上でどのような構造として実装されているのか、また入力に応じてどのようにExpertが選択されるのかを解説します。
想定読者
本記事は、MoEの概念はある程度知っているものの、実際のモデル実装ではどのように動いているのかがまだイメージしづらい方を主な読者として想定しています。
そのため、MoEの基本概念そのものについては詳しく扱いません。MoEについて初めて学ぶ方は、まず以下の記事などを読んでおくと、本記事の内容を追いやすくなると思います。
また、ソースコードを読みながら解説するため、PythonおよびPyTorchの基本的な読み書きに慣れている方を前提としています。
なお、Qwen3.5 MoEには、Gated DeltaNetによるlinear attentionや、full attentionとのハイブリッド構成など、MoE以外にも多くの工夫が導入されています。ただし、本記事の主題はあくまでMoE部分の実装理解であるため、それらの仕組みについては必要な範囲で触れるにとどめ、詳細な解説は行いません。
前提
まず、Qwen3.5には、MoEモデルである
Qwen3.5-35B-A3B/
Qwen3.5-122B-A10B/
Qwen3.5-397B-A17Bと、denseモデルである
Qwen3.5-9B/
Qwen3.5-27Bが基盤モデルとして公開されています。
MoEモデル名に含まれる
○○B-AxxBは、
○○Bがモデル全体のパラメータ数、
AxxBが推論時にアクティブになるパラメータ数の目安を表しています。たとえば
Qwen3.5-35B-A3Bであれば、モデル全体では35B規模のパラメータを持ちますが、各トークンの計算で主に使用されるのは約3B分のパラメータです。
本記事では、このうち
Qwen3.5-35B-A3Bを前提に解説します。また、ソースコードは以下のHugging Face Transformersの実装をもとにします。
この実装には、テキスト生成を扱う
Qwen3_5MoeForCausalLMと、画像・動画入力も扱えるマルチモーダル用の
Qwen3_5MoeForConditionalGenerationがあります。本記事では、MoEの構造を追いやすくするため、テキスト生成用の
Qwen3_5MoeForCausalLMに絞って解説します。
Qwen3.5 MoEの全体像
上図は、
Qwen3.5-35B-A3Bの全体構造を簡略化したものです。
モデルの中心となるのは
DecoderLayerであり、
Qwen3.5-35B-A3Bでは合計40層のDecoderLayerが積み重ねられています。ただし、すべての層が同じ構造になっているわけではありません。内部では、
linear_attention layerが3層続いた後に
full_attention layerが1層配置される構成を1グループとしており、このグループが10回繰り返されます。
つまり、全体としては次のような構造になります。
1 group = 3 × linear_attention layer + 1 × full_attention layer 10 groups × 4 layers = 40 layers
各DecoderLayerは、大きく見ると「Attention系のToken Mixer」と「Sparse MoE Block」から構成されます。
linear_attention layerではGated DeltaNetが使われ、
full_attention layerでは通常のSelf-Attentionに近い処理が使われます。一方で、後段にはどちらの層でもSparse MoE Blockが配置されます。
Configから構造を確認する
次に、
Qwen3_5MoeTextConfigからモデル構造に関わる主要な設定を確認します。
class Qwen3_5MoeTextConfig(PreTrainedConfig):
# 一部のパラメータは省略
vocab_size: int = 248320 # 語彙数。lm_headの出力次元で、次トークン候補の総数
hidden_size: int = 2048 # 隠れ状態ベクトルの次元数。各トークンを2048次元で表現する
num_hidden_layers: int = 40 # DecoderLayerの層数。ここでは40層
num_attention_heads: int = 16 # Full Attentionで使うQueryヘッド数
num_key_value_heads: int = 2 # Full Attentionで使うKey/Valueヘッド数。GQA用の設定
hidden_act: str = "silu" # MLPやMoE内部で使う活性化関数
head_dim: int = 256 # Full Attentionの各ヘッドの次元数
linear_conv_kernel_dim: int = 4 # Gated DeltaNet内の短い畳み込みに使うカーネルサイズ
linear_key_head_dim: int = 128 # Linear Attention側のKeyヘッド1つあたりの次元数
linear_value_head_dim: int = 128 # Linear Attention側のValueヘッド1つあたりの次元数
linear_num_key_heads: int = 16 # Linear Attention側のKeyヘッド数
linear_num_value_heads: int = 32 # Linear Attention側のValueヘッド数
moe_intermediate_size: int = 512 # 各MoE expert内部の中間層次元数
shared_expert_intermediate_size: int = 512 # 全トークンで共有されるshared expertの中間層次元数
num_experts_per_tok: int = 8 # 1トークンあたり選択されるexpert数。Top-k routingのk
num_experts: int = 256 # MoEに用意されているexpertの総数
このConfigを見ると、モデル全体の大まかな構造を把握できます。
まず、
vocab_sizeは
248320に設定されています。これは、モデルが出力候補として扱うトークンIDの総数です。最終的な
lm_headは、各トークン位置の隠れ状態をこの語彙数ぶんのlogitに変換します。
また、
hidden_sizeは
2048です。これは、各トークンがモデル内部で2048次元のベクトルとして表現されることを意味します。DecoderLayerの数は
num_hidden_layers = 40であり、先ほどの図で示したように、40層のDecoderLayerが積み重ねられます。
Attentionに関する設定としては、full attention用の
num_attention_heads、
num_key_value_heads、
head_dimに加えて、linear attention用の
linear_key_head_dim、
linear_value_head_dim、
linear_num_key_heads、
linear_num_value_headsなどが定義されています。ここから、Qwen3.5 MoEではfull attentionだけでなく、Gated DeltaNetによるlinear attentionも併用されていることが分かります。
さらに、本記事で主に扱うMoEに関する設定として、
num_experts = 256と
num_experts_per_tok = 8が定義されています。これは、モデル内に256個のExpertが用意されており、各トークンに対してそのうち8個のExpertが選択されることを意味します。
つまり、このConfigからは、
Qwen3.5-35B-A3Bが「40層のDecoderLayer」「full attentionとlinear attentionのハイブリッド構成」「256個のExpertからトークンごとに8個を選択するSparse MoE構造」を持つモデルであることが確認できます。
実装を追う
ここからは、実際のソースコードを上位のクラスから順に見ていきます。
流れとしては、まずユーザーが直接呼び出す
Qwen3_5MoeForCausalLMを確認し、その内部で使われる
Qwen3_5MoeTextModel、さらにその中に積み重ねられている
Qwen3_5MoeDecoderLayerへと進みます。
Qwen3_5MoeForCausalLM ↓ Qwen3_5MoeTextModel ↓ Qwen3_5MoeDecoderLayer ↓ Qwen3_5MoeSparseMoeBlock
最終的には、本記事の主題である
Qwen3_5MoeSparseMoeBlockの内部を詳しく見ていきます。
テキスト生成モデルの入口:Qwen3_5MoeForCausalLM
まずは、ユーザーが直接呼び出すモデルである
Qwen3_5MoeForCausalLMから見ていきます。
class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin):
def __init__(self, config):
super().__init__(config)
self.model = Qwen3_5MoeTextModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.router_aux_loss_coef = config.router_aux_loss_coef
self.num_experts = config.num_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.post_init()
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_router_logits=None,
logits_to_keep=0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
**kwargs,
)
hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
Qwen3_5MoeForCausalLMは、テキスト生成用モデルの入口にあたるクラスです。
大きな流れは次のとおりです。
input_ids ↓ Qwen3_5MoeTextModel ↓ hidden_states ↓ lm_head ↓ logits
まず、ユーザーから渡された
input_idsは
Qwen3_5MoeTextModelに入力されます。
Qwen3_5MoeTextModelは、Embedding、DecoderLayer、RMSNormなどを通して、各トークンの特徴量である
hidden_statesを出力します。
その後、
lm_headによって
hidden_statesが語彙数ぶんのスコアに変換されます。
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
今回のConfigでは、
hidden_size = 2048、
vocab_size = 248320です。そのため、各トークン位置の2048次元ベクトルは、248,320個の語彙トークンそれぞれに対するスコアへ変換されます。
このスコアが
logitsです。
hidden_states: [batch_size, seq_len, hidden_size] logits: [batch_size, seq_len, vocab_size]
また、
labelsが与えられている場合は言語モデリング用の
lossが計算されます。さらに、
output_router_logits=Trueの場合は、MoEのExpert選択が偏りすぎないようにするための補助損失
aux_lossも計算されます。
最後に、
loss、
aux_loss、
logits、
router_logitsなどが
MoeCausalLMOutputWithPastにまとめられて返されます。
MoeCausalLMOutputWithPastは、通常のテンソルそのものではなく、モデル出力を名前付きでまとめるためのデータクラス系オブジェクトです。
言語モデル本体:Qwen3_5MoeTextModel
次に、
Qwen3_5MoeTextModelを見ていきます。ここが、EmbeddingやDecoderLayerを含む言語モデル本体です。
class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel):
def __init__(self, config: Qwen3_5MoeTextConfig):
super().__init__(config)
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
config.pad_token_id,
)
self.layers = nn.ModuleList(
[
Qwen3_5MoeDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config)
self.post_init()
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
**kwargs,
) -> BaseModelOutputWithPast:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
causal_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
linear_attn_mask = self._update_linear_attn_mask(
attention_mask,
past_key_values,
)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
layer_mask = (
linear_attn_mask
if self.config.layer_types[i] == "linear_attention"
else causal_mask
)
hidden_states = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
**kwargs,
)
hidden_states = self.norm(hidden_states)
return Qwen3_5MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
Qwen3_5MoeTextModelでは、まず
input_idsがEmbedding層に入力されます。
inputs_embeds = self.embed_tokens(input_ids)
これにより、各トークンIDは2048次元のベクトルに変換されます。
input_ids: [batch_size, seq_len] inputs_embeds: [batch_size, seq_len, hidden_size]
ここでの
hidden_sizeはConfigで指定されている
2048です。つまり、各トークンはモデル内部では2048次元の特徴ベクトルとして扱われます。
続いて、Attentionに使うマスクが作成されます。
causal_mask = create_causal_mask(...) linear_attn_mask = self._update_linear_attn_mask(...)
Qwen3.5 MoEでは、層によって
full_attentionと
linear_attentionが切り替わるため、それぞれに応じたマスクが使われます。
また、位置情報としてRotary Position Embedding、いわゆるRoPEが計算されます。
position_embeddings = self.rotary_emb(hidden_states, position_ids)
RoPEは、トークンの順序や相対的な位置関係をAttention計算に反映するための位置埋め込みです。
その後、
DecoderLayerを順番に通していきます。
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
layer_mask = (
linear_attn_mask
if self.config.layer_types[i] == "linear_attention"
else causal_mask
)
hidden_states = decoder_layer(...)
Decoderでは、
self.config.layer_types[i]によって、各層が
linear_attentionなのか
full_attentionなのかを判定しています。
linear_attentionの層では
linear_attn_maskが使われ、
full_attentionの層では
causal_maskが使われます。
すべてのDecoderLayerを通った後、最後にRMSNormが適用されます。
hidden_states = self.norm(hidden_states)
そして、最終的な
hidden_statesが
last_hidden_stateとして返されます。
DecoderLayerの構造
続いて、Qwen3.5 MoEのメインブロックである
DecoderLayerを見ていきます。
ここまでで、Qwen3.5 MoEでは
linear_attentionと
full_attentionが層ごとに使い分けられていることを確認しました。
DecoderLayerでは、前半でトークン間の情報を混ぜる
Token Mixerが使われ、後半で
Sparse MoE BlockによるFFN処理が行われます。Qwen3.5 MoEでは、この
Token Mixerとして、層によって
linear_attentionまたは
full_attentionが使い分けられています。
class Qwen3_5MoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_type = config.layer_types[layer_idx]
if self.layer_type == "linear_attention":
self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx)
elif self.layer_type == "full_attention":
self.self_attn = Qwen3_5MoeAttention(config, layer_idx)
self.mlp = Qwen3_5MoeSparseMoeBlock(config)
self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Token Mixer
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
attention_mask=attention_mask,
**kwargs,
)
elif self.layer_type == "full_attention":
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Sparse MoE Block
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
このコードを見ると、
Qwen3_5MoeDecoderLayerは大きく次の2つの部分に分けられます。
1. Token Mixer - linear_attention: Qwen3_5MoeGatedDeltaNet - full_attention: Qwen3_5MoeAttention 2. Sparse MoE Block - Qwen3_5MoeSparseMoeBlock
まず、入力された
hidden_statesは
input_layernormによって正規化されます。
hidden_states = self.input_layernorm(hidden_states)
その後、
self.layer_typeに応じて、
linear_attentionまたは
full_attentionのどちらかに入力されます。
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(...)
elif self.layer_type == "full_attention":
hidden_states, _ = self.self_attn(...)
各層でどちらのAttentionを使うかは、Config内の
layer_typesによって決まります。
def __post_init__(self, **kwargs):
kwargs.setdefault("partial_rotary_factor", 0.25)
if self.layer_types is None:
interval_pattern = kwargs.pop("full_attention_interval", 4)
self.layer_types = [
"linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
for i in range(self.num_hidden_layers)
]
super().__post_init__(**kwargs)
このコードでは、
full_attention_intervalのデフォルト値が
4になっています。そのため、4層に1回だけ
full_attentionが使われ、それ以外の層では
linear_attentionが使われます。
linear_attention linear_attention linear_attention full_attention
つまり、
Qwen3.5-35B-A3Bでは、この4層のパターンが10回繰り返され、合計40層のDecoderLayerを構成しています。
full_attentionと
linear_attentionの違いについて
本記事はMoEの解説を主題としているため、それぞれのAttention機構については深く扱いません。ただし、Qwen3.5 MoEでは
full_attentionと
linear_attentionが併用されているため、ここでは両者の違いを簡単に整理します。
full_attention:
通常のSelf-Attentionに近く、各トークンが他のトークンとのAttentionスコアを計算します。そのため、文脈中のトークン間の関係を直接的に捉えやすい一方で、系列長が長くなるほど計算量やメモリ使用量が大きくなります。
Qwen3.5 MoEの
full_attentionでは、通常の
q,
k,
vに加えてAttention出力を調整するgateが導入されています。また、Key/Valueヘッド数をQueryヘッド数より少なくするGrouped Query Attention(GQA)も使われており、KV cacheの増大を抑えたり、推論時のK/V読み出しコストを下げたりする工夫が含まれています。
linear_attention:
full_attentionでは全トークン間のAttentionスコアを計算するため、系列長を
Nとすると計算量が概ね
O(N^2)に増えます。これに対して、
linear_attentionは全トークン間のAttention行列を明示的に作らず、系列情報をより効率的に扱うための仕組みです。
Qwen3.5 MoEでは、
linear_attentionとして
Gated DeltaNetが使われています。Gated DeltaNetは、通常のAttentionとは異なり、系列情報を状態として扱いながらトークン間の情報を混ぜるLinear Attention系のToken Mixerです。これにより、長い系列に対する計算量やメモリ使用量を抑えやすくなります。
このように、Qwen3.5 MoEでは、効率性を重視した
linear_attentionと、通常のSelf-Attentionに近い
full_attentionを組み合わせることで、計算量を抑えつつ文脈情報を扱う構成になっています。
どちらのToken Mixerを通った場合でも、その後段には共通して
Qwen3_5MoeSparseMoeBlockが配置されています。
self.mlp = Qwen3_5MoeSparseMoeBlock(config)
そのため、DecoderLayerの流れは次のように整理できます。
ここまでで、入力が
Qwen3_5MoeForCausalLMから
Qwen3_5MoeTextModel、さらに
Qwen3_5MoeDecoderLayerを通り、その中で
Qwen3_5MoeSparseMoeBlockに到達するまでの流れを確認しました。
以降では、本記事の主題である
Qwen3_5MoeSparseMoeBlockの内部、特にRouterによるExpert選択、Expert計算、出力の集約について詳しく見ていきます。
Sparse MoE Blockの内部
Sparse MoE Block全体の流れ
ここからは、本記事のメインであるMoEブロックの内部を見ていきます。
Attentionブロックを通った
hidden_statesは、
Qwen3_5MoeSparseMoeBlockに入力されます。このブロックでは、入力トークンごとに使用するExpertを選択し、その出力を集約することで、最終的なMoEブロックの出力を作ります。
まずは、全体の流れを確認します。
class Qwen3_5MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.gate = Qwen3_5MoeTopKRouter(config)
self.experts = Qwen3_5MoeExperts(config)
self.shared_expert = Qwen3_5MoeMLP(
config,
intermediate_size=config.shared_expert_intermediate_size,
)
self.shared_expert_gate = torch.nn.Linear(
config.hidden_size,
1,
bias=False,
)
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
_, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
expert_output = self.experts(
hidden_states_reshaped,
selected_experts,
routing_weights,
)
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
* shared_expert_output
)
expert_output = expert_output + shared_expert_output
expert_output = expert_output.reshape(
batch_size,
sequence_length,
hidden_dim,
)
return expert_output
このコードの流れを大きく見ると、次のようになります。
ここで
Bは
batch_size、
Sは
sequence_length、
Hは
hidden_dimを表します。
最初に、入力された
hidden_statesは
[B, S, H]から
[B*S, H]に変形されます。これは、MoEでは各トークンごとにExpertを選択するため、バッチ方向と系列長方向をまとめて「全トークンを1行ずつ並べたテンソル」として扱うためです。
その後、処理は大きく2つの経路に分かれます。1つ目は、Routerによって選択されたExpertを通る経路です。2つ目は、すべてのトークンに共通して適用されるshared expertの経路です。
最後に、この2つの経路の出力を足し合わせ、形を
[B, S, H]に戻すことで、MoEブロックの出力が得られます。
以降では、まずRouterによるExpert選択、次に選択されたExpertによる計算、最後にshared expertを含む出力の集約という順番で詳しく見ていきます。
選択されたExpertによる計算:Qwen3_5MoeExperts
Routerによって各トークンに対して使用するExpertが決まると、次に
Qwen3_5MoeExpertsで実際のExpert計算が行われます。
class Qwen3_5MoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(
torch.empty(
self.num_experts,
2 * self.intermediate_dim,
self.hidden_dim,
)
)
self.down_proj = nn.Parameter(
torch.empty(
self.num_experts,
self.hidden_dim,
self.intermediate_dim,
)
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(
top_k_index,
num_classes=self.num_experts,
)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(
expert_mask.sum(dim=(-1, -2)),
0,
).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(
current_state,
self.gate_up_proj[expert_idx],
).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(
current_hidden_states,
self.down_proj[expert_idx],
)
current_hidden_states = (
current_hidden_states
* top_k_weights[token_idx, top_k_pos, None]
)
final_hidden_states.index_add_(
0,
token_idx,
current_hidden_states.to(final_hidden_states.dtype),
)
return final_hidden_states
ここでは、まず
top_k_indexをもとに、今回の入力で実際に使われるExpertだけを抽出しています。
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(
top_k_index,
num_classes=self.num_experts,
)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(
expert_mask.sum(dim=(-1, -2)),
0,
).nonzero()
top_k_indexは、各トークンに対して選ばれたExpert IDを表します。これをone-hot化し、Expert方向を先頭に並べ替えることで、各Expertがどのトークンに選ばれたのかを扱いやすくしています。
その後、
expert_hitによって、今回の入力で1回以上選ばれたExpertだけを取り出します。つまり、256個すべてのExpertを毎回計算するのではなく、実際に使われるExpertだけを処理対象にしています。
次に、Expertごとにループし、そのExpertに割り当てられたトークンを取り出します。
top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx]
token_idxは、現在のExpertに割り当てられたトークンの位置を表します。
そのため、
hidden_states[token_idx]によって、現在のExpertが処理すべきトークンの特徴量だけを取り出しています。
Expert内部のMLP計算では、まず
gate_up_projによって線形変換を行います。
gate, up = nn.functional.linear(
current_state,
self.gate_up_proj[expert_idx],
).chunk(2, dim=-1)
gate_up_projはExpertごとに異なる重みを持っています。
Qwen3.5-35B-A3Bでは、
num_experts = 256、
hidden_dim = 2048、
intermediate_dim = 512なので、
gate_up_projの形は次のようになります。
[num_experts, 2 * intermediate_dim, hidden_dim] = [256, 1024, 2048]
実際には、
self.gate_up_proj[expert_idx]によって、現在のExpertに対応する重みだけを取り出して使います。
self.gate_up_proj[expert_idx]: [1024, 2048]
current_stateの形を
[n_tokens_for_expert, hidden_dim]とすると、線形変換後の出力は次の形になります。
current_state: [n_tokens_for_expert, 2048] linear output: [n_tokens_for_expert, 1024]
この出力を
.chunk(2, dim=-1)で2分割し、
gateと
upに分けています。
gate: [n_tokens_for_expert, 512] up: [n_tokens_for_expert, 512]
この分岐と再合流の流れを図にすると、次のようになります。
ここで行われている中心的な処理は、次の部分です。
current_hidden_states = self.act_fn(gate) * up
これは、通常の単純なMLPではなく、SwiGLUに近いゲート付きMLP構造です。
gate側に活性化関数をかけたものが、
up側の特徴量をどの程度通すかを制御します。
gate: どの成分をどの程度通すかを制御する up: 変換後の特徴量本体 act_fn(gate) * up: gateによって制御された中間表現
その後、
down_projによって中間表現を再び
hidden_dimに戻します。
current_hidden_states = nn.functional.linear(
current_hidden_states,
self.down_proj[expert_idx],
)
down_projもExpertごとに異なる重みを持っています。
形は次のようになります。
[num_experts, hidden_dim, intermediate_dim]
Qwen3.5-35B-A3Bでは、具体的には次の形です。
[256, 2048, 512]
現在のExpertに対応する重みだけを見ると、次の形になります。
self.down_proj[expert_idx]: [2048, 512]
current_hidden_statesは
[n_tokens_for_expert, 512]なので、
down_projを通すことで次のように戻ります。
[n_tokens_for_expert, 512] ↓ [n_tokens_for_expert, 2048]
ここまでで、現在のExpertによる出力が得られます。
次に、この出力にRouterで計算された重みを掛けます。
current_hidden_states = (
current_hidden_states
* top_k_weights[token_idx, top_k_pos, None]
)
top_k_weightsは、選ばれたExpertの出力をどの程度反映するかを表す重みです。各トークンはTop-k個のExpertに通されるため、それぞれのExpert出力に対応する重みを掛けてから集約します。
最後に、
index_add_によって、対応するトークン位置へ出力を加算します。
final_hidden_states.index_add_(
0,
token_idx,
current_hidden_states.to(final_hidden_states.dtype),
)
同じトークンは複数のExpertに割り当てられるため、各Expertから得られた出力を同じトークン位置に足し合わせる必要があります。
index_add_は、そのための加算処理です。
つまり、
Qwen3_5MoeExpertsでは、Routerで選ばれたExpertごとに対象トークンを集め、そのExpert専用のゲート付きMLPで変換し、Routerの重みを掛けたうえで、最終的な出力テンソルに加算しています。
Shared Expertによる共通経路と出力の集約
Qwen3.5 MoEでは、256個のExpertの中から各トークンごとにTop-k個のExpertを選択します。一方で、それとは別に、すべてのトークンが必ず通る共通経路として
shared_expertも用意されています。
該当する処理は、
Qwen3_5MoeSparseMoeBlockの
forward内の次の部分です。
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
shared_expert_output = self.shared_expert(hidden_states_reshaped)
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
* shared_expert_output
)
expert_output = expert_output + shared_expert_output
まず、Attentionブロックを通った
hidden_statesは、MoEブロック内で
[batch_size × sequence_length, hidden_dim]の形に変形されます。
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
その後、この
hidden_states_reshapedは
shared_expertに入力されます。
shared_expert_output = self.shared_expert(hidden_states_reshaped)
shared_expertは、Routerによって選択される通常のExpertとは異なり、すべてのトークンに共通して適用されるMLPです。つまり、Top-k RouterでどのExpertが選ばれたかに関係なく、各トークンは必ずこの共通経路を通ります。
さらに、shared expertの出力には
shared_expert_gateによるゲートが掛けられます。
shared_expert_output = (
F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
* shared_expert_output
)
ここでは、
shared_expert_gateによって各トークンごとにスカラー値を計算し、
sigmoidを通してから
shared_expert_outputに掛けています。これにより、shared expertの出力をどの程度反映するかをトークンごとに調整しています。
最後に、Routerで選ばれたExpertによる出力
expert_outputと、共通経路である
shared_expert_outputを足し合わせます。
expert_output = expert_output + shared_expert_output
つまり、
Qwen3_5MoeSparseMoeBlockの最終出力は、次の2つを合成したものです。
1. Routerで選ばれたTop-k Expertの出力 2. すべてのトークンが通るshared expertの出力
このように、Qwen3.5 MoEでは、トークンごとに選択されるSparseなExpert経路に加えて、全トークン共通のshared expert経路も組み合わせています。これにより、Expertごとの専門的な変換だけでなく、全トークンに共通する変換も同時に利用できる構造になっています。
最終的に得られたMoEブロックの出力は、DecoderLayer内で残差接続された後、次のDecoderLayer、または最終層であれば
Final RMSNormと
LM Headに渡されます。そして、
LM Headによって語彙数ぶんのlogitsに変換され、次トークン予測に使われます。
【補足】補助損失 aux_loss
MoEでは、各トークンごとに使用するExpertをRouterが選択します。しかし、学習中にRouterの出力が偏ると、一部のExpertばかりが使われ、他のExpertがほとんど使われない状態になる可能性があります。
このような偏りを抑えるために、Qwen3.5 MoEでは補助損失として
aux_lossが計算されます。これは、通常の言語モデリングlossとは別に、Expertの使用が特定のExpertに集中しすぎないようにするためのlossです。
該当する処理は、
Qwen3_5MoeForCausalLMの
forward内にあります。
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
output_router_logits=Trueの場合、各MoE層から出力された
router_logitsを使って
load_balancing_loss_funcが呼び出されます。
そして、
labelsが与えられている学習時には、通常の言語モデリングlossに対して、係数
router_aux_loss_coefを掛けた
aux_lossが加えられます。
最終的なloss = 言語モデリングloss + router_aux_loss_coef × aux_loss
load_balancing_loss_funcの大まかな処理は次のとおりです。
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
まず、各層の
router_logitsを結合し、softmaxによってExpert方向の確率分布
routing_weightsを計算します。
その後、Top-kによって実際に選ばれるExpertを取得し、one-hot化します。
ここで重要なのは、次の2つの値です。
tokens_per_expert = torch.mean(expert_mask.float(), dim=0) router_prob_per_expert = torch.mean(routing_weights, dim=0)
tokens_per_expertは、Top-kによって実際に各Expertがどれくらい選ばれたかを表します。
つまり、Routerの確率分布から実際に採用されたExpertの使用率です。
一方、
router_prob_per_expertは、Top-kでExpertを選ぶ前のRouter確率を平均したものです。
これは、Routerが各Expertに平均的にどれくらいの確率を割り当てていたかを表します。
整理すると、次のようになります。
tokens_per_expert = Top-k後の実際のExpert使用率 router_prob_per_expert = Top-k前のRouter確率の平均
最後に、この2つを掛け合わせて合計します。
overall_loss = torch.sum(
tokens_per_expert * router_prob_per_expert.unsqueeze(0)
)
return overall_loss * num_experts
この式では、実際によく選ばれているExpertであり、かつRouterが高い確率を割り当てているExpertがあると、そのExpertに対応する積が大きくなります。
つまり、特定のExpertにルーティングが集中しているほど
aux_lossが大きくなります。
直感的には、次のようなExpertの偏りを罰しています。
Expert 0ばかりが選ばれる RouterもExpert 0に高い確率を出し続ける → aux_lossが大きくなる
逆に、Expertが比較的均等に使われていれば、特定のExpertだけに大きな値が集中しにくくなるため、
aux_lossは小さくなります。
なお、
attention_maskが与えられている場合は、padding tokenを除外して
tokens_per_expertや
router_prob_per_expertを計算します。padding部分のRouter結果までExpert使用率に含めてしまうと、実際の有効トークンに対するExpert使用状況を正しく評価できないためです。
このように、
aux_lossはMoEモデルにおいて、Expertが偏って使われることを防ぐための補助的な正則化項として機能します。
まとめ
本記事では、Qwen3.5 MoEの実装を、
Qwen3_5MoeForCausalLMから
Qwen3_5MoeSparseMoeBlockまで順に追いながら、MoEがソースコード上でどのように動いているのかを確認しました。
Qwen3.5 MoEでは、DecoderLayerの後半に
Qwen3_5MoeSparseMoeBlockが配置されており、通常のDenseなFFNの代わりに、Routerで選択されたExpertによる変換が行われます。
Sparse MoE Blockの中心となる処理は、次の3つです。
Qwen3_5MoeTopKRouter → 各トークンに対して使用するExpertを選択する Qwen3_5MoeExperts → 選択されたExpertで実際に特徴量を変換する shared_expert → すべてのトークンが通る共通経路を提供する
つまり、Qwen3.5 MoEのSparse MoE Blockは、単に複数のMLPを並べた構造ではなく、RouterによるExpert選択、Expertごとの計算、shared expertによる共通経路を組み合わせた構造になっています。
また、学習時にはExpertの使用が偏りすぎないように、補助的に
aux_lossも使われます。
この流れを押さえることで、「入力トークンごとにExpertを選び、必要なExpertだけを使って計算する」というMoEの仕組みを、Qwen3.5の実装と対応づけて理解しやすくなります。