← 기사 목록
日本語https://zenn.dev/topics/llm/feed

Key-Value Means (KVM):RNNとTransformerの境界線を消す異端の設計

추출된 키워드

49
KVM·5Key-Value Means·5JIT 正規化·4winner-take-all·4オンライン K-means·4delta rule·4RWKV-7·4Linear RNN·4Transformer·4RNN·4scaled_dot_product_attention·3scatter·3KV キャッシュ·3learnable temperature·3attention sink·3Partial RoPE·3LayerNorm·3CUDA kernel·3OVQ·3Recursal AI·3Eleuther AI·3GDN·3Kimi Delta Attention·3分離性 (separability·3Online Vector Quantization·3Full Attention·3温度ゼロ極限のアテンション·3Mamba·2StreamingLM·2SDPA·2FlexAttention·2Gated DeltaNet·2TokenFormer·2Compressive Transformer·2TransformerFAM·2Block-Recurrent Transformer·2LaCT·2Titans·2Kimi Delta·2remove_rope·2MI325X·2MI300·2AMD GPU·2Daniel·2Eugene·2W_merge_gate·2PyTorch operation·2DPLR·2IPLR·2

원문

12,527
Key-Value Means (KVM):RNNとTransformerの境界線を消す異端の設計

Key-Value Means (KVM):RNNとTransformerの境界線を消す異端の設計

はじめに

こんにちはOpenMOSEです。お元気ですか?

2026年5月、Recursal AI / Eleuther AI の Daniel氏 と Eugene氏から Key-Value Means (KVM) が公開されました。

このアーキテクチャの面白さは、ひとことで言うと 「Transformer か RNN か」の二者択一を、連続的なスペクトラムに溶かしてしまった ところにあります。

Linear RNN(RWKV-7, GDN, Kimi Delta Attention)が固定サイズの状態で KVM はその間に「目盛り」を引き直す試み です。

本記事では、過去に擬似コードベースで掘り下げた内容も交えながら、この手法の 優位性異端性 を整理します。

1. KVM の立ち位置:RNN と Transformer の間

論文 Table 1 を要約すると、計算量プロファイルは以下のようになります。

性質Linear RNNKVM (fixed)KVM ( Full Attention
State サイズ
Prefill 時間
Decode 時間/token
Recall 性能限定的限定的強い完全
Prefill 並列化チャンク単位チャンク単位チャンク単位完全並列

注目すべきは KVM ( という中間点が存在することです。固定状態サイズの KVM (fixed) は実質 chunked RNN として振る舞いますが、State を「もっとも新規性の高いオーバーフロートークン」で

拡張させるモードを使うと、

つまりユーザは推論時のコンテキスト要求に応じて、

  • 「短文中心、メモリ最小」→ KVM (fixed)
  • 「長文 recall を妥協なく」→ KVM ()\sqrt{N}
  • 「とにかく完璧」→ Full Attention

同じアーキテクチャの中で 選べる。これがハイブリッドモデル設計の自由度を一気に広げます。

2. 核となる発想:勝者総取りのクラスタリング

KVM の State 更新ルールは、本質的には オンライン K-means です。擬似コードで核心部分だけ抜き出すと:

# Phase 1: BSWA からあふれたトークンに data-dependent gating を適用
g = 1 + elu(x @ self.W_merge_gate)[:, :, bswa_begin - chunk_len:bswa_begin]
o_k = self.layernorm_s_k(remove_rope(o_k)) * g
o_v = o_v * g

# Phase 2: state 内で最も似た 1 スロットを見つけて加算
s_k_norm = self.layernorm_s_k(s_k)
sim = o_k @ s_k_norm.mT
sim[..., 0:sink_len] = float('-inf')      # sink トークンは保護
best_s_idx = sim.max(dim=-1, keepdim=True).indices
sim_max = scatter(zeros_like(sim), -1, best_s_idx, ones_like(sim))  # one-hot

s_k = s_k + (sim_max.mT @ o_k)   # winner-take-all 加算
s_v = s_v + (sim_max.mT @ o_v)

この

winner-take-all
の構造は、K-means の中身そのものです。
  • State slot = cluster centroid
  • Overflow token = 新しいデータ点
  • 最近傍 centroid を見つけて、その centroid に加算(= centroid 更新)

論文中では、ソフトマックスの代わりに「最大値だけ 1.0 で残し、他を 0 にする」という 温度ゼロ極限のアテンション を使う設計に至った理由が、State key の 分離性 (separability) を保つためだと議論されています。「似ているものを 1 つに集約することで、State 空間全体としてのキーの分散を維持する」という発想で、ここに OVQ (Online Vector Quantization) との並行進化が見えます。

ちなみに OVQ との差分は論文中で明確に切り分けられていて、要点だけ書くと:

  • KVM は 単一の softmax passで State + BSWA を統合 attention(OVQ は別レイヤー)
  • JIT 正規化によって centroid のカウントを別途追跡しなくてよい
  • RoPE 部分のゼロ化で位置エンコーディングと State の整合性を取る
  • State サイズが 無制限に拡張可能
  • Sink トークン保護+ value のノルム保存
  • State 領域と BSWA 領域に 別々の learnable temperatureを割り当てる

これらの「細かい工夫の積み重ね」が、論文の本当の貢献だと感じます。

3. 4つの技術的な工夫

3.1 JIT (Just-In-Time) 正規化

複数のベクトルを単純に足し合わせると、

  • 直交するベクトル同士は norm が縮む
  • 反対方向のベクトルは破壊的干渉でさらに縮む

という現象が起きます。これは累積マージで State key のノルムが時間とともに減衰していく原因です。

KVM はこれを「State key を 保存時には正規化しない」「Attention に投入する直前に毎回 LayerNorm をかける」という JIT 設計で回避しています。

s_k_attn = self.layernorm_s_k(s_k) * self.state_temperature

State 内部はあくまで「累積された加算和」として保持され、attention 時にだけ正規化空間に投影される。これにより:

  • 累積数を別途追跡する必要がない(OVQ との大きな違い)
  • Query/Key normalization の理論的根拠(test-time regression 文脈)とも整合する

Value の方は別の問題があります。Sink トークンは他のトークンと norm が桁違いに違うことが知られている (Guo et al., 2024) ので、単純な正規化はかえって有害です。KVM は「初期 state value の norm を記録しておき、その値を その slot 専用の JIT norm 半径 として一生使い回す」という形で、slot ごとに固有のスケールを保ちます。

s_v_attn = (normalize(s_v.float(), dim=-1) * s_vlen).to(s_v.dtype)

s_vlen
がその slot 固有の長さ。これも学習可能パラメータです。

3.2 Partial RoPE のゼロ化

State 内のキーは、いろんな絶対位置から来た overflow token がマージされた結果なので、もはや単一の RoPE 角度を割り当てることができません。

KVM の解決策はシンプルで、RoPE が適用される次元の最初の

o_k = remove_rope(o_k)  # rotary 部分を 0 に

これによって:

  • State key 側は purely semantic な空間で動く
  • BSWA window 側は通常通り RoPE で位置を扱う
  • 同じ Q がこの両方に attend するとき、State 側は RoPE 部分が無視され、BSWA 側は通常の位置依存 attention が動く

論文中でも「partial RoPE をゼロにすると一部の表現力が落ちる」と認めていて、別解(State 側だけ unrotated query で別途 attention して logsumexp で merge する)も提案されていますが、現状の実装はシンプルさを優先して partial RoPE ゼロ化を採用しています。

3.3 State 拡張:「最も意外な」トークンを残す

ここが KVM の中で 一番異端な部分 だと思います。

固定サイズ State の限界は、結局のところ「容量に対してリコール対象が増えると、情報が圧縮されきって失われる」点にあります。KVM は state サイズを 時間とともに拡張する ことでこれを回避します。

ルールは明快で:

  • Overflow ブロックが到来したら、その中で 現在の State との類似度が最も低いトークン(= 最も surprising / novel なトークン)を一定数選んで append
  • 残りの overflow トークンは、通常通り winner-take-all で既存スロットにマージ

「driver の交差点を埋めずに、本当に新しい landmark だけを地図に追加していく」イメージです。スケジュールは固定(power-law / saturating など複数バリアントが Figure 2 で議論)ですが、「最近傍類似度の閾値」を学習可能にする方向は future work として明示されています。

3.4 Sink トークン保護と learnable temperature

StreamingLM (Xiao et al., 2024) で重要性が知られた attention sink は、State 先頭の

sink_len
個のスロットとして恒久的に保持され、winner-take-all のマージ対象から除外されます。
sim[..., 0:sink_len] = float('-inf')

さらに State 領域と BSWA 領域には別々の温度パラメータが学習で割り当てられます。

s_k_attn = self.layernorm_s_k(s_k) * self.state_temperature
bswa_k   = k[:, :, bswa_begin:bswa_end] * self.bswa_temperature

これは「State key は累積加算でスケールが変動する」「BSWA は通常のキー」という、性質の異なる2種類のキーを同じ softmax に同居させる ための鍵です。

4. RWKV-7 との対比で見る「異端性」

ここからは個人的な視点も入りますが、自分が普段 RWKV-7 のステート更新を扱っている身として、KVM を見るとアーキテクチャ哲学の違いが鮮明です。

4.1 状態更新の対比

RWKV-7 (delta rule):

  • 全スロットに対して 連続的に減衰 + 加算
  • 「忘れながら全体に書き込む」
  • IPLR / DPLR の matrix-valued state で表現力が高い
  • ただし custom CUDA kernel がほぼ必須

KVM (winner-take-all):

  • 離散的に 1 スロットだけ選んで加算
  • 「最も似たスロットにだけ情報を積む」
  • 表現力は限定的(最近傍 1 つへの加算)
  • でも 標準的な PyTorch operationだけで実装可能

RWKV-7 の delta rule は「全体に対する連続更新」、KVM は「離散的に選んで局所更新」。これは Linear RNN 族の中でも極めて異質な選択です。

4.2 追加パラメータの少なさ

これも KVM の大きな利点で、Q/K/V/O 以外の追加学習パラメータは:

  • W_merge_gate
    (マージゲートの射影、1 行列のみ)
  • layernorm_s_k
    の γ, β
  • state_temperature
    ,
    bswa_temperature
    (スカラー)
  • s_vlen
    (slot 固有のスケール)
  • 初期 State
    s_k
    ,
    s_v

これだけ。RWKV-7 が time_decay, time_first, receptance, gate, w / a / k / v の LoRA 群などで層ごとに大量のパラメータを追加するのと比べると、KVM は既存の Transformer に最小限の差分で載る 設計です。

論文がアブストで「insignificant number of new parameters」と強調しているのはこの点で、Transformer-to-X 変換系の蒸留パイプラインに組み込む観点でも極めて魅力的です。

4.3 並列化

両者ともチャンク単位での並列化を実現していますが、

  • RWKV-7 は chunked parallel kernel が 必須(実装難度が高い)
  • KVM は チャンク内は通常の SDPA、チャンク境界で離散更新という単純な構造なので、FlexAttention や標準実装にそのまま乗る

「カスタムカーネルなしで Linear RNN 相当の効率」というのは、production 実装の観点では非常に大きい。

5. KVM の優位性まとめ

ここまで見てきた特徴を整理すると、KVM の優位性は4点に集約できます。

5.1 連続的なメモリ-リコールのトレードオフ

ユーザが「fixed」「任意の中間点 を選べる。同じアーキテクチャで Prefill 計算量を

5.2 カスタムカーネル不要

scaled_dot_product_attention
scatter
だけで実装できる。これは AMD GPU (MI300 / MI325X) や consumer GPU でも、CUDA 専用最適化なしで素直に動くことを意味します。

5.3 全レイヤー適用可能

Linear RNN は通常 hybrid(一部レイヤーは softmax attention に残す)構成が必要ですが、KVM は 全レイヤーに置き換え可能。それでいて KV キャッシュメモリは大幅に削減される。

5.4 LRNN との併用も可能

論文では「LRNN レイヤーと組み合わせて、LRNN 側にサブリニアな長文記憶を補完する hybrid 構成」も提案されています。これは RWKV / Gated DeltaNet を主体に置きつつ、KVM レイヤーで long-context recall を補強する方向の研究を呼びそうです。

6. KVM の「異端性」

最後に、なぜこの設計が異端と感じられるかを言語化しておきます。

異端性①:「Linear RNN を作る」のではなく「Transformer の attention に recurrence を仕込む」

近年の Linear RNN 系の流れ(Mamba, RWKV-7, GDN, Kimi Delta, Titans, LaCT)は、基本的に「attention を別の機構で置き換える」アプローチでした。KVM は逆で、Softmax attention をそのまま使い続ける。ただしその attention の対象に、recurrence で更新される圧縮 State を含める。

「Linear RNN 化」ではなく「Block-Recurrent化された Attention」という位置取り。これは Block-Recurrent Transformer や TransformerFAM の系譜ですが、State 更新を 学習可能な K-means にしてしまった点で完全に新しい。

異端性②:温度ゼロ極限のアテンションを「使い切る」

Softmax attention の表現力を信じる流派と、「Linear attention で十分」と考える流派の対立がある中で、KVM は 温度ゼロ極限(= argmax)まで先鋭化することで離散的に分離性を保つ という、両者の中間に独自の位置を切る選択をしました。

「ソフトマックスの良さと、離散割当の良さを、State 構築側でだけ argmax にすることで両取りする」。これは TTT 系(gradient descent を runtime で回す)とも、純粋な linear attention とも違うルートです。

異端性③:明示的に State を拡張する

ほとんどの Linear RNN は「固定サイズ State こそが利点」と考えます。KVM は「サブリニアに拡張する State」という、attention と RNN の中間概念を提示しました。

これは Compressive Transformer や TokenFormer の系譜ですが、KVM の場合は**「最も意外なトークンを残す」というシンプルな規則だけで** sublinear growth を自然に実現します。

7. おわりに

KVM は「RNN か Transformer か」の二択を、勾配ベース optimizer も custom kernel も介さずに、単純な K-means 加算と JIT 正規化だけで スペクトラム化した手法です。

個人的に重要だと思うポイントは:

  • Transformer-to-RWKV 蒸留のような変換パイプラインの観点で、KVM の追加パラメータの少なさは破格。蒸留の adapter としての扱いやすさは未踏領域がある
  • RWKV-7 や Kimi Delta Attention の delta rule との ハイブリッドが future work として論文中で明示されている
  • 「KV キャッシュ問題」に対して Linear RNN を作らない方向の真面目な解として、Production 実装に最も近いポジションにいる

長文 LLM の Prefill / Decode を実用化する道は、これまで「linear attention 系の蒸留」と「sparse attention + 圧縮」の二択でしたが、KVM はそのどちらにも収まらない第三の道として面白いと思います。

詳細は論文と公開コードをぜひ。