← 기사 목록
日本語https://qiita.com/tags/llm/feed

Qwen3.5のモデル構造からMoEを理解する

추출된 키워드

37
Mixture of Experts·5MoE·5Qwen3.5·5Qwen3_5MoeTextModel·4Qwen3_5MoeDecoderLayer·4Qwen3_5MoeSparseMoeBlock·4Sparse MoE Block·4DecoderLayer·4Qwen3_5MoeForCausalLM·4full attention·4linear attention·4Gated DeltaNet·4Qwen3.5-35B-A3B·4Qwen3_5MoeRMSNorm·3Rotary Position Embedding·3RMSNorm·3RoPE·3lm_head·3Qwen3_5MoeTextRotaryEmbedding·3Qwen3_5MoeGatedDeltaNet·3Qwen3_5MoeAttention·3Self-Attention·3Token Mixer·3Top-k routing·3Qwen3_5MoeTextConfig·3Qwen3_5MoeForConditionalGeneration·3Hugging Face Transformers·3LLM·3Qwen3.5-27B·3Qwen3.5-9B·3Qwen3.5-397B-A17B·3Qwen3.5-122B-A10B·3MoeCausalLMOutputWithPast·2PyTorch·2Python·2GQA·2silu·2

원문

34,867
Qwen3.5のモデル構造からMoEを理解する

はじめに

最近リリースされる大規模言語モデル(LLM)では、Mixture of Experts(MoE)と呼ばれる構造がよく採用されています。

MoEは、モデル内部に複数のネットワーク、すなわちExpertを持ち、入力トークンごとにその一部のExpertだけを選択して計算する仕組みです。これにより、モデル全体のパラメータ数を大きく保ちながら、推論時に実際に使う計算量を抑えることができます。

本記事では、Qwen3.5のMoEモデルを題材に、MoEがソースコード上でどのような構造として実装されているのか、また入力に応じてどのようにExpertが選択されるのかを解説します。

想定読者

本記事は、MoEの概念はある程度知っているものの、実際のモデル実装ではどのように動いているのかがまだイメージしづらい方を主な読者として想定しています。

そのため、MoEの基本概念そのものについては詳しく扱いません。MoEについて初めて学ぶ方は、まず以下の記事などを読んでおくと、本記事の内容を追いやすくなると思います。

また、ソースコードを読みながら解説するため、PythonおよびPyTorchの基本的な読み書きに慣れている方を前提としています。

なお、Qwen3.5 MoEには、Gated DeltaNetによるlinear attentionや、full attentionとのハイブリッド構成など、MoE以外にも多くの工夫が導入されています。ただし、本記事の主題はあくまでMoE部分の実装理解であるため、それらの仕組みについては必要な範囲で触れるにとどめ、詳細な解説は行いません。

前提

まず、Qwen3.5には、MoEモデルである

Qwen3.5-35B-A3B
/
Qwen3.5-122B-A10B
/
Qwen3.5-397B-A17B
と、denseモデルである
Qwen3.5-9B
/
Qwen3.5-27B
が基盤モデルとして公開されています。

MoEモデル名に含まれる

○○B-AxxB
は、
○○B
がモデル全体のパラメータ数、
AxxB
が推論時にアクティブになるパラメータ数の目安を表しています。たとえば
Qwen3.5-35B-A3B
であれば、モデル全体では35B規模のパラメータを持ちますが、各トークンの計算で主に使用されるのは約3B分のパラメータです。

本記事では、このうち

Qwen3.5-35B-A3B
を前提に解説します。また、ソースコードは以下のHugging Face Transformersの実装をもとにします。

この実装には、テキスト生成を扱う

Qwen3_5MoeForCausalLM
と、画像・動画入力も扱えるマルチモーダル用の
Qwen3_5MoeForConditionalGeneration
があります。本記事では、MoEの構造を追いやすくするため、テキスト生成用の
Qwen3_5MoeForCausalLM
に絞って解説します。

Qwen3.5 MoEの全体像

上図は、

Qwen3.5-35B-A3B
の全体構造を簡略化したものです。

モデルの中心となるのは

DecoderLayer
であり、
Qwen3.5-35B-A3B
では合計40層のDecoderLayerが積み重ねられています。ただし、すべての層が同じ構造になっているわけではありません。内部では、
linear_attention layer
が3層続いた後に
full_attention layer
が1層配置される構成を1グループとしており、このグループが10回繰り返されます。

つまり、全体としては次のような構造になります。

1 group = 3 × linear_attention layer + 1 × full_attention layer
10 groups × 4 layers = 40 layers

各DecoderLayerは、大きく見ると「Attention系のToken Mixer」と「Sparse MoE Block」から構成されます。

linear_attention layer
ではGated DeltaNetが使われ、
full_attention layer
では通常のSelf-Attentionに近い処理が使われます。一方で、後段にはどちらの層でもSparse MoE Blockが配置されます。

Configから構造を確認する

次に、

Qwen3_5MoeTextConfig
からモデル構造に関わる主要な設定を確認します。
class Qwen3_5MoeTextConfig(PreTrainedConfig):

    # 一部のパラメータは省略

    vocab_size: int = 248320                   # 語彙数。lm_headの出力次元で、次トークン候補の総数
    hidden_size: int = 2048                    # 隠れ状態ベクトルの次元数。各トークンを2048次元で表現する
    num_hidden_layers: int = 40                # DecoderLayerの層数。ここでは40層
    num_attention_heads: int = 16              # Full Attentionで使うQueryヘッド数
    num_key_value_heads: int = 2               # Full Attentionで使うKey/Valueヘッド数。GQA用の設定
    hidden_act: str = "silu"                   # MLPやMoE内部で使う活性化関数
    head_dim: int = 256                        # Full Attentionの各ヘッドの次元数
    linear_conv_kernel_dim: int = 4            # Gated DeltaNet内の短い畳み込みに使うカーネルサイズ
    linear_key_head_dim: int = 128             # Linear Attention側のKeyヘッド1つあたりの次元数
    linear_value_head_dim: int = 128           # Linear Attention側のValueヘッド1つあたりの次元数
    linear_num_key_heads: int = 16             # Linear Attention側のKeyヘッド数
    linear_num_value_heads: int = 32           # Linear Attention側のValueヘッド数
    moe_intermediate_size: int = 512           # 各MoE expert内部の中間層次元数
    shared_expert_intermediate_size: int = 512 # 全トークンで共有されるshared expertの中間層次元数
    num_experts_per_tok: int = 8               # 1トークンあたり選択されるexpert数。Top-k routingのk
    num_experts: int = 256                     # MoEに用意されているexpertの総数

このConfigを見ると、モデル全体の大まかな構造を把握できます。

まず、

vocab_size
248320
に設定されています。これは、モデルが出力候補として扱うトークンIDの総数です。最終的な
lm_head
は、各トークン位置の隠れ状態をこの語彙数ぶんのlogitに変換します。

また、

hidden_size
2048
です。これは、各トークンがモデル内部で2048次元のベクトルとして表現されることを意味します。DecoderLayerの数は
num_hidden_layers = 40
であり、先ほどの図で示したように、40層のDecoderLayerが積み重ねられます。

Attentionに関する設定としては、full attention用の

num_attention_heads
num_key_value_heads
head_dim
に加えて、linear attention用の
linear_key_head_dim
linear_value_head_dim
linear_num_key_heads
linear_num_value_heads
などが定義されています。ここから、Qwen3.5 MoEではfull attentionだけでなく、Gated DeltaNetによるlinear attentionも併用されていることが分かります。

さらに、本記事で主に扱うMoEに関する設定として、

num_experts = 256
num_experts_per_tok = 8
が定義されています。これは、モデル内に256個のExpertが用意されており、各トークンに対してそのうち8個のExpertが選択されることを意味します。

つまり、このConfigからは、

Qwen3.5-35B-A3B
が「40層のDecoderLayer」「full attentionとlinear attentionのハイブリッド構成」「256個のExpertからトークンごとに8個を選択するSparse MoE構造」を持つモデルであることが確認できます。

実装を追う

ここからは、実際のソースコードを上位のクラスから順に見ていきます。

流れとしては、まずユーザーが直接呼び出す

Qwen3_5MoeForCausalLM
を確認し、その内部で使われる
Qwen3_5MoeTextModel
、さらにその中に積み重ねられている
Qwen3_5MoeDecoderLayer
へと進みます。
Qwen3_5MoeForCausalLM
  ↓
Qwen3_5MoeTextModel
  ↓
Qwen3_5MoeDecoderLayer
  ↓
Qwen3_5MoeSparseMoeBlock

最終的には、本記事の主題である

Qwen3_5MoeSparseMoeBlock
の内部を詳しく見ていきます。

テキスト生成モデルの入口:Qwen3_5MoeForCausalLM

まずは、ユーザーが直接呼び出すモデルである

Qwen3_5MoeForCausalLM
から見ていきます。
class Qwen3_5MoeForCausalLM(Qwen3_5MoePreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)

        self.model = Qwen3_5MoeTextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_router_logits=None,
        logits_to_keep=0,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

Qwen3_5MoeForCausalLM
は、テキスト生成用モデルの入口にあたるクラスです。

大きな流れは次のとおりです。

input_ids
  ↓
Qwen3_5MoeTextModel
  ↓
hidden_states
  ↓
lm_head
  ↓
logits

まず、ユーザーから渡された

input_ids
Qwen3_5MoeTextModel
に入力されます。
Qwen3_5MoeTextModel
は、Embedding、DecoderLayer、RMSNormなどを通して、各トークンの特徴量である
hidden_states
を出力します。

その後、

lm_head
によって
hidden_states
が語彙数ぶんのスコアに変換されます。
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

今回のConfigでは、

hidden_size = 2048
vocab_size = 248320
です。そのため、各トークン位置の2048次元ベクトルは、248,320個の語彙トークンそれぞれに対するスコアへ変換されます。

このスコアが

logits
です。
hidden_states: [batch_size, seq_len, hidden_size]
logits:        [batch_size, seq_len, vocab_size]

また、

labels
が与えられている場合は言語モデリング用の
loss
が計算されます。さらに、
output_router_logits=True
の場合は、MoEのExpert選択が偏りすぎないようにするための補助損失
aux_loss
も計算されます。

最後に、

loss
aux_loss
logits
router_logits
などが
MoeCausalLMOutputWithPast
にまとめられて返されます。

MoeCausalLMOutputWithPast
は、通常のテンソルそのものではなく、モデル出力を名前付きでまとめるためのデータクラス系オブジェクトです。

言語モデル本体:Qwen3_5MoeTextModel

次に、

Qwen3_5MoeTextModel
を見ていきます。ここが、EmbeddingやDecoderLayerを含む言語モデル本体です。
class Qwen3_5MoeTextModel(Qwen3_5MoePreTrainedModel):
    def __init__(self, config: Qwen3_5MoeTextConfig):
        super().__init__(config)

        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id,
        )

        self.layers = nn.ModuleList(
            [
                Qwen3_5MoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        self.norm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3_5MoeTextRotaryEmbedding(config=config)

        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        **kwargs,
    ) -> BaseModelOutputWithPast:

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        linear_attn_mask = self._update_linear_attn_mask(
            attention_mask,
            past_key_values,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            layer_mask = (
                linear_attn_mask
                if self.config.layer_types[i] == "linear_attention"
                else causal_mask
            )

            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return Qwen3_5MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

Qwen3_5MoeTextModel
では、まず
input_ids
がEmbedding層に入力されます。
inputs_embeds = self.embed_tokens(input_ids)

これにより、各トークンIDは2048次元のベクトルに変換されます。

input_ids:     [batch_size, seq_len]
inputs_embeds: [batch_size, seq_len, hidden_size]

ここでの

hidden_size
はConfigで指定されている
2048
です。つまり、各トークンはモデル内部では2048次元の特徴ベクトルとして扱われます。

続いて、Attentionに使うマスクが作成されます。

causal_mask = create_causal_mask(...)
linear_attn_mask = self._update_linear_attn_mask(...)

Qwen3.5 MoEでは、層によって

full_attention
linear_attention
が切り替わるため、それぞれに応じたマスクが使われます。

また、位置情報としてRotary Position Embedding、いわゆるRoPEが計算されます。

position_embeddings = self.rotary_emb(hidden_states, position_ids)

RoPEは、トークンの順序や相対的な位置関係をAttention計算に反映するための位置埋め込みです。
その後、

DecoderLayer
を順番に通していきます。
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
    layer_mask = (
        linear_attn_mask
        if self.config.layer_types[i] == "linear_attention"
        else causal_mask
    )

    hidden_states = decoder_layer(...)

Decoderでは、

self.config.layer_types[i]
によって、各層が
linear_attention
なのか
full_attention
なのかを判定しています。
linear_attention
の層では
linear_attn_mask
が使われ、
full_attention
の層では
causal_mask
が使われます。

すべてのDecoderLayerを通った後、最後にRMSNormが適用されます。

hidden_states = self.norm(hidden_states)

そして、最終的な

hidden_states
last_hidden_state
として返されます。

DecoderLayerの構造

続いて、Qwen3.5 MoEのメインブロックである

DecoderLayer
を見ていきます。

ここまでで、Qwen3.5 MoEでは

linear_attention
full_attention
が層ごとに使い分けられていることを確認しました。

DecoderLayer
では、前半でトークン間の情報を混ぜる
Token Mixer
が使われ、後半で
Sparse MoE Block
によるFFN処理が行われます。Qwen3.5 MoEでは、この
Token Mixer
として、層によって
linear_attention
または
full_attention
が使い分けられています。
class Qwen3_5MoeDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_type = config.layer_types[layer_idx]

        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5MoeGatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5MoeAttention(config, layer_idx)

        self.mlp = Qwen3_5MoeSparseMoeBlock(config)
        self.input_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Token Mixer
        if self.layer_type == "linear_attention":
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                attention_mask=attention_mask,
                **kwargs,
            )
        elif self.layer_type == "full_attention":
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = residual + hidden_states

        # Sparse MoE Block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        if isinstance(hidden_states, tuple):
            hidden_states, _ = hidden_states

        hidden_states = residual + hidden_states

        return hidden_states

このコードを見ると、

Qwen3_5MoeDecoderLayer
は大きく次の2つの部分に分けられます。
1. Token Mixer
   - linear_attention: Qwen3_5MoeGatedDeltaNet
   - full_attention: Qwen3_5MoeAttention

2. Sparse MoE Block
   - Qwen3_5MoeSparseMoeBlock

まず、入力された

hidden_states
input_layernorm
によって正規化されます。
hidden_states = self.input_layernorm(hidden_states)

その後、

self.layer_type
に応じて、
linear_attention
または
full_attention
のどちらかに入力されます。
if self.layer_type == "linear_attention":
    hidden_states = self.linear_attn(...)
elif self.layer_type == "full_attention":
    hidden_states, _ = self.self_attn(...)

各層でどちらのAttentionを使うかは、Config内の

layer_types
によって決まります。
def __post_init__(self, **kwargs):
    kwargs.setdefault("partial_rotary_factor", 0.25)

    if self.layer_types is None:
        interval_pattern = kwargs.pop("full_attention_interval", 4)
        self.layer_types = [
            "linear_attention" if bool((i + 1) % interval_pattern) else "full_attention"
            for i in range(self.num_hidden_layers)
        ]

    super().__post_init__(**kwargs)

このコードでは、

full_attention_interval
のデフォルト値が
4
になっています。そのため、4層に1回だけ
full_attention
が使われ、それ以外の層では
linear_attention
が使われます。
linear_attention
linear_attention
linear_attention
full_attention

つまり、

Qwen3.5-35B-A3B
では、この4層のパターンが10回繰り返され、合計40層のDecoderLayerを構成しています。

full_attention
linear_attention
の違いについて

本記事はMoEの解説を主題としているため、それぞれのAttention機構については深く扱いません。ただし、Qwen3.5 MoEでは

full_attention
linear_attention
が併用されているため、ここでは両者の違いを簡単に整理します。

full_attention
:
通常のSelf-Attentionに近く、各トークンが他のトークンとのAttentionスコアを計算します。そのため、文脈中のトークン間の関係を直接的に捉えやすい一方で、系列長が長くなるほど計算量やメモリ使用量が大きくなります。

Qwen3.5 MoEの

full_attention
では、通常の
q
,
k
,
v
に加えてAttention出力を調整するgateが導入されています。また、Key/Valueヘッド数をQueryヘッド数より少なくするGrouped Query Attention(GQA)も使われており、KV cacheの増大を抑えたり、推論時のK/V読み出しコストを下げたりする工夫が含まれています。

linear_attention
:
full_attention
では全トークン間のAttentionスコアを計算するため、系列長を
N
とすると計算量が概ね
O(N^2)
に増えます。これに対して、
linear_attention
は全トークン間のAttention行列を明示的に作らず、系列情報をより効率的に扱うための仕組みです。

Qwen3.5 MoEでは、

linear_attention
として
Gated DeltaNet
が使われています。Gated DeltaNetは、通常のAttentionとは異なり、系列情報を状態として扱いながらトークン間の情報を混ぜるLinear Attention系のToken Mixerです。これにより、長い系列に対する計算量やメモリ使用量を抑えやすくなります。

このように、Qwen3.5 MoEでは、効率性を重視した

linear_attention
と、通常のSelf-Attentionに近い
full_attention
を組み合わせることで、計算量を抑えつつ文脈情報を扱う構成になっています。

どちらのToken Mixerを通った場合でも、その後段には共通して

Qwen3_5MoeSparseMoeBlock
が配置されています。
self.mlp = Qwen3_5MoeSparseMoeBlock(config)

そのため、DecoderLayerの流れは次のように整理できます。

ここまでで、入力が

Qwen3_5MoeForCausalLM
から
Qwen3_5MoeTextModel
、さらに
Qwen3_5MoeDecoderLayer
を通り、その中で
Qwen3_5MoeSparseMoeBlock
に到達するまでの流れを確認しました。

以降では、本記事の主題である

Qwen3_5MoeSparseMoeBlock
の内部、特にRouterによるExpert選択、Expert計算、出力の集約について詳しく見ていきます。

Sparse MoE Blockの内部

Sparse MoE Block全体の流れ

ここからは、本記事のメインであるMoEブロックの内部を見ていきます。

Attentionブロックを通った

hidden_states
は、
Qwen3_5MoeSparseMoeBlock
に入力されます。このブロックでは、入力トークンごとに使用するExpertを選択し、その出力を集約することで、最終的なMoEブロックの出力を作ります。

まずは、全体の流れを確認します。

class Qwen3_5MoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = Qwen3_5MoeTopKRouter(config)
        self.experts = Qwen3_5MoeExperts(config)
        self.shared_expert = Qwen3_5MoeMLP(
            config,
            intermediate_size=config.shared_expert_intermediate_size,
        )
        self.shared_expert_gate = torch.nn.Linear(
            config.hidden_size,
            1,
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

        shared_expert_output = self.shared_expert(hidden_states_reshaped)

        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)

        expert_output = self.experts(
            hidden_states_reshaped,
            selected_experts,
            routing_weights,
        )

        shared_expert_output = (
            F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
            * shared_expert_output
        )

        expert_output = expert_output + shared_expert_output
        expert_output = expert_output.reshape(
            batch_size,
            sequence_length,
            hidden_dim,
        )

        return expert_output

このコードの流れを大きく見ると、次のようになります。

ここで

B
batch_size
S
sequence_length
H
hidden_dim
を表します。

最初に、入力された

hidden_states
[B, S, H]
から
[B*S, H]
に変形されます。これは、MoEでは各トークンごとにExpertを選択するため、バッチ方向と系列長方向をまとめて「全トークンを1行ずつ並べたテンソル」として扱うためです。

その後、処理は大きく2つの経路に分かれます。1つ目は、Routerによって選択されたExpertを通る経路です。2つ目は、すべてのトークンに共通して適用されるshared expertの経路です。

最後に、この2つの経路の出力を足し合わせ、形を

[B, S, H]
に戻すことで、MoEブロックの出力が得られます。

以降では、まずRouterによるExpert選択、次に選択されたExpertによる計算、最後にshared expertを含む出力の集約という順番で詳しく見ていきます。

選択されたExpertによる計算:Qwen3_5MoeExperts

Routerによって各トークンに対して使用するExpertが決まると、次に

Qwen3_5MoeExperts
で実際のExpert計算が行われます。
class Qwen3_5MoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size

        self.gate_up_proj = nn.Parameter(
            torch.empty(
                self.num_experts,
                2 * self.intermediate_dim,
                self.hidden_dim,
            )
        )
        self.down_proj = nn.Parameter(
            torch.empty(
                self.num_experts,
                self.hidden_dim,
                self.intermediate_dim,
            )
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        final_hidden_states = torch.zeros_like(hidden_states)

        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(
                top_k_index,
                num_classes=self.num_experts,
            )
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(
                expert_mask.sum(dim=(-1, -2)),
                0,
            ).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]

            if expert_idx == self.num_experts:
                continue

            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

            current_state = hidden_states[token_idx]

            gate, up = nn.functional.linear(
                current_state,
                self.gate_up_proj[expert_idx],
            ).chunk(2, dim=-1)

            current_hidden_states = self.act_fn(gate) * up

            current_hidden_states = nn.functional.linear(
                current_hidden_states,
                self.down_proj[expert_idx],
            )

            current_hidden_states = (
                current_hidden_states
                * top_k_weights[token_idx, top_k_pos, None]
            )

            final_hidden_states.index_add_(
                0,
                token_idx,
                current_hidden_states.to(final_hidden_states.dtype),
            )

        return final_hidden_states

ここでは、まず

top_k_index
をもとに、今回の入力で実際に使われるExpertだけを抽出しています。
with torch.no_grad():
    expert_mask = torch.nn.functional.one_hot(
        top_k_index,
        num_classes=self.num_experts,
    )
    expert_mask = expert_mask.permute(2, 1, 0)
    expert_hit = torch.greater(
        expert_mask.sum(dim=(-1, -2)),
        0,
    ).nonzero()

top_k_index
は、各トークンに対して選ばれたExpert IDを表します。これをone-hot化し、Expert方向を先頭に並べ替えることで、各Expertがどのトークンに選ばれたのかを扱いやすくしています。

その後、

expert_hit
によって、今回の入力で1回以上選ばれたExpertだけを取り出します。つまり、256個すべてのExpertを毎回計算するのではなく、実際に使われるExpertだけを処理対象にしています。

次に、Expertごとにループし、そのExpertに割り当てられたトークンを取り出します。

top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

current_state = hidden_states[token_idx]

token_idx
は、現在のExpertに割り当てられたトークンの位置を表します。
そのため、
hidden_states[token_idx]
によって、現在のExpertが処理すべきトークンの特徴量だけを取り出しています。

Expert内部のMLP計算では、まず

gate_up_proj
によって線形変換を行います。
gate, up = nn.functional.linear(
    current_state,
    self.gate_up_proj[expert_idx],
).chunk(2, dim=-1)

gate_up_proj
はExpertごとに異なる重みを持っています。
Qwen3.5-35B-A3Bでは、
num_experts = 256
hidden_dim = 2048
intermediate_dim = 512
なので、
gate_up_proj
の形は次のようになります。
[num_experts, 2 * intermediate_dim, hidden_dim]
= [256, 1024, 2048]

実際には、

self.gate_up_proj[expert_idx]
によって、現在のExpertに対応する重みだけを取り出して使います。
self.gate_up_proj[expert_idx]: [1024, 2048]

current_state
の形を
[n_tokens_for_expert, hidden_dim]
とすると、線形変換後の出力は次の形になります。
current_state:
  [n_tokens_for_expert, 2048]

linear output:
  [n_tokens_for_expert, 1024]

この出力を

.chunk(2, dim=-1)
で2分割し、
gate
up
に分けています。
gate: [n_tokens_for_expert, 512]
up:   [n_tokens_for_expert, 512]

この分岐と再合流の流れを図にすると、次のようになります。

ここで行われている中心的な処理は、次の部分です。

current_hidden_states = self.act_fn(gate) * up

これは、通常の単純なMLPではなく、SwiGLUに近いゲート付きMLP構造です。

gate
側に活性化関数をかけたものが、
up
側の特徴量をどの程度通すかを制御します。
gate:
  どの成分をどの程度通すかを制御する

up:
  変換後の特徴量本体

act_fn(gate) * up:
  gateによって制御された中間表現

その後、

down_proj
によって中間表現を再び
hidden_dim
に戻します。
current_hidden_states = nn.functional.linear(
    current_hidden_states,
    self.down_proj[expert_idx],
)

down_proj
もExpertごとに異なる重みを持っています。
形は次のようになります。
[num_experts, hidden_dim, intermediate_dim]

Qwen3.5-35B-A3Bでは、具体的には次の形です。

[256, 2048, 512]

現在のExpertに対応する重みだけを見ると、次の形になります。

self.down_proj[expert_idx]: [2048, 512]

current_hidden_states
[n_tokens_for_expert, 512]
なので、
down_proj
を通すことで次のように戻ります。
[n_tokens_for_expert, 512]
  ↓
[n_tokens_for_expert, 2048]

ここまでで、現在のExpertによる出力が得られます。

次に、この出力にRouterで計算された重みを掛けます。

current_hidden_states = (
    current_hidden_states
    * top_k_weights[token_idx, top_k_pos, None]
)

top_k_weights
は、選ばれたExpertの出力をどの程度反映するかを表す重みです。各トークンはTop-k個のExpertに通されるため、それぞれのExpert出力に対応する重みを掛けてから集約します。

最後に、

index_add_
によって、対応するトークン位置へ出力を加算します。
final_hidden_states.index_add_(
    0,
    token_idx,
    current_hidden_states.to(final_hidden_states.dtype),
)

同じトークンは複数のExpertに割り当てられるため、各Expertから得られた出力を同じトークン位置に足し合わせる必要があります。

index_add_
は、そのための加算処理です。

つまり、

Qwen3_5MoeExperts
では、Routerで選ばれたExpertごとに対象トークンを集め、そのExpert専用のゲート付きMLPで変換し、Routerの重みを掛けたうえで、最終的な出力テンソルに加算しています。

Shared Expertによる共通経路と出力の集約

Qwen3.5 MoEでは、256個のExpertの中から各トークンごとにTop-k個のExpertを選択します。一方で、それとは別に、すべてのトークンが必ず通る共通経路として

shared_expert
も用意されています。

該当する処理は、

Qwen3_5MoeSparseMoeBlock
forward
内の次の部分です。
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

shared_expert_output = self.shared_expert(hidden_states_reshaped)

shared_expert_output = (
    F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
    * shared_expert_output
)

expert_output = expert_output + shared_expert_output

まず、Attentionブロックを通った

hidden_states
は、MoEブロック内で
[batch_size × sequence_length, hidden_dim]
の形に変形されます。
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)

その後、この

hidden_states_reshaped
shared_expert
に入力されます。
shared_expert_output = self.shared_expert(hidden_states_reshaped)

shared_expert
は、Routerによって選択される通常のExpertとは異なり、すべてのトークンに共通して適用されるMLPです。つまり、Top-k RouterでどのExpertが選ばれたかに関係なく、各トークンは必ずこの共通経路を通ります。

さらに、shared expertの出力には

shared_expert_gate
によるゲートが掛けられます。
shared_expert_output = (
    F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
    * shared_expert_output
)

ここでは、

shared_expert_gate
によって各トークンごとにスカラー値を計算し、
sigmoid
を通してから
shared_expert_output
に掛けています。これにより、shared expertの出力をどの程度反映するかをトークンごとに調整しています。

最後に、Routerで選ばれたExpertによる出力

expert_output
と、共通経路である
shared_expert_output
を足し合わせます。
expert_output = expert_output + shared_expert_output

つまり、

Qwen3_5MoeSparseMoeBlock
の最終出力は、次の2つを合成したものです。
1. Routerで選ばれたTop-k Expertの出力
2. すべてのトークンが通るshared expertの出力

このように、Qwen3.5 MoEでは、トークンごとに選択されるSparseなExpert経路に加えて、全トークン共通のshared expert経路も組み合わせています。これにより、Expertごとの専門的な変換だけでなく、全トークンに共通する変換も同時に利用できる構造になっています。

最終的に得られたMoEブロックの出力は、DecoderLayer内で残差接続された後、次のDecoderLayer、または最終層であれば

Final RMSNorm
LM Head
に渡されます。そして、
LM Head
によって語彙数ぶんのlogitsに変換され、次トークン予測に使われます。

【補足】補助損失 aux_loss

MoEでは、各トークンごとに使用するExpertをRouterが選択します。しかし、学習中にRouterの出力が偏ると、一部のExpertばかりが使われ、他のExpertがほとんど使われない状態になる可能性があります。

このような偏りを抑えるために、Qwen3.5 MoEでは補助損失として

aux_loss
が計算されます。これは、通常の言語モデリングlossとは別に、Expertの使用が特定のExpertに集中しすぎないようにするためのlossです。

該当する処理は、

Qwen3_5MoeForCausalLM
forward
内にあります。
aux_loss = None
if output_router_logits:
    aux_loss = load_balancing_loss_func(
        outputs.router_logits,
        self.num_experts,
        self.num_experts_per_tok,
        attention_mask,
    )
    if labels is not None:
        loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

output_router_logits=True
の場合、各MoE層から出力された
router_logits
を使って
load_balancing_loss_func
が呼び出されます。
そして、
labels
が与えられている学習時には、通常の言語モデリングlossに対して、係数
router_aux_loss_coef
を掛けた
aux_loss
が加えられます。
最終的なloss
= 言語モデリングloss + router_aux_loss_coef × aux_loss

load_balancing_loss_func
の大まかな処理は次のとおりです。
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

まず、各層の

router_logits
を結合し、softmaxによってExpert方向の確率分布
routing_weights
を計算します。
その後、Top-kによって実際に選ばれるExpertを取得し、one-hot化します。

ここで重要なのは、次の2つの値です。

tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

router_prob_per_expert = torch.mean(routing_weights, dim=0)

tokens_per_expert
は、Top-kによって実際に各Expertがどれくらい選ばれたかを表します。
つまり、Routerの確率分布から実際に採用されたExpertの使用率です。

一方、

router_prob_per_expert
は、Top-kでExpertを選ぶ前のRouter確率を平均したものです。
これは、Routerが各Expertに平均的にどれくらいの確率を割り当てていたかを表します。

整理すると、次のようになります。

tokens_per_expert
  = Top-k後の実際のExpert使用率

router_prob_per_expert
  = Top-k前のRouter確率の平均

最後に、この2つを掛け合わせて合計します。

overall_loss = torch.sum(
    tokens_per_expert * router_prob_per_expert.unsqueeze(0)
)

return overall_loss * num_experts

この式では、実際によく選ばれているExpertであり、かつRouterが高い確率を割り当てているExpertがあると、そのExpertに対応する積が大きくなります。
つまり、特定のExpertにルーティングが集中しているほど

aux_loss
が大きくなります。

直感的には、次のようなExpertの偏りを罰しています。

Expert 0ばかりが選ばれる
RouterもExpert 0に高い確率を出し続ける
→ aux_lossが大きくなる

逆に、Expertが比較的均等に使われていれば、特定のExpertだけに大きな値が集中しにくくなるため、

aux_loss
は小さくなります。

なお、

attention_mask
が与えられている場合は、padding tokenを除外して
tokens_per_expert
router_prob_per_expert
を計算します。padding部分のRouter結果までExpert使用率に含めてしまうと、実際の有効トークンに対するExpert使用状況を正しく評価できないためです。

このように、

aux_loss
はMoEモデルにおいて、Expertが偏って使われることを防ぐための補助的な正則化項として機能します。

まとめ

本記事では、Qwen3.5 MoEの実装を、

Qwen3_5MoeForCausalLM
から
Qwen3_5MoeSparseMoeBlock
まで順に追いながら、MoEがソースコード上でどのように動いているのかを確認しました。

Qwen3.5 MoEでは、DecoderLayerの後半に

Qwen3_5MoeSparseMoeBlock
が配置されており、通常のDenseなFFNの代わりに、Routerで選択されたExpertによる変換が行われます。

Sparse MoE Blockの中心となる処理は、次の3つです。

Qwen3_5MoeTopKRouter
  → 各トークンに対して使用するExpertを選択する

Qwen3_5MoeExperts
  → 選択されたExpertで実際に特徴量を変換する

shared_expert
  → すべてのトークンが通る共通経路を提供する

つまり、Qwen3.5 MoEのSparse MoE Blockは、単に複数のMLPを並べた構造ではなく、RouterによるExpert選択、Expertごとの計算、shared expertによる共通経路を組み合わせた構造になっています。

また、学習時にはExpertの使用が偏りすぎないように、補助的に

aux_loss
も使われます。

この流れを押さえることで、「入力トークンごとにExpertを選び、必要なExpertだけを使って計算する」というMoEの仕組みを、Qwen3.5の実装と対応づけて理解しやすくなります。