← 기사 목록
日本語https://zenn.dev/topics/llm/feed

LLM解説シリーズ:Self-Attentionを数式と実装から理解する

추출된 키워드

37
LLM·5Scaled Dot-Product Attention·5Transformer·5Self-Attention·5Attention Is All You Need·4Multi-Head Attention·4Query·4Key·4Value·4attention weight·3KV Cache·3causal mask·3Positional Encoding·3RNN·3tensor·3mask·3softmax·3BLEU·2WMT 2014 English-to-French·2WMT 2014 English-to-German·2torch·2自己回帰モデル·2RoPE·2GPT系·2残差接続·2Feed Forward·2Encoder-Decoder構造·2GQA/MQA·2FlashAttention·2Ashish Vaswani·1Noam Shazeer·1Niki Parmar·1Jakob Uszkoreit·1Llion Jones·1Aidan N. Gomez·1Lukasz Kaiser·1Illia Polosukhin·1

원문

10,388
LLM解説シリーズ:Self-Attentionを数式と実装から理解する

LLM解説シリーズ:Self-Attentionを数式と実装から理解する

はじめに

この記事は、以下の論文を読んだ技術メモです。

  • 論文タイトル: Attention Is All You Need
  • 論文リンク: https://arxiv.org/abs/1706.03762
  • 著者: Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
  • 初版公開日: 2017年6月12日

詳細な背景説明、Q/K/Vの直感、計算量の整理は個人ブログ側にまとめています。

👉 完全版はこちら:Self-Attentionとは?Scaled Dot-Product Attentionを数式と実装でやさしく解説

3行まとめ

Self-Attentionが文中のtoken同士の関係を重みとして計算する流れ

  • Self-Attention(同じ系列内のtoken同士が互いを参照する仕組み)は、TransformerやLLMの中核にある計算です。
  • Scaled Dot-Product Attentionは、、softmax、Valueの重み付き和として実装できます。QK^T / \sqrt{d_k}
  • 実装では、tensor形状、softmax前のmask、の計算量をセットで理解するのが重要です。O(n^2)

何の論文か

Attention Is All You Needは、Transformerを提案した論文です。

前回のTransformer編では、Encoder-Decoder構造、Feed Forward、Positional Encoding、残差接続まで含めた全体像を扱いました。

今回は、その中でもSelf-AttentionとMulti-Head Attentionに絞ります。

元論文のAttentionは次の式で定義されます。

記号役割
Query今のtokenが何を参照したいか
Key各tokenが照合用に持つ手がかり
Valueattention weightで実際に混ぜる情報
Keyの次元数

何が新しいのか

RNNの逐次処理とSelf-Attentionの直接参照の違い

Self-Attentionのポイントは、同じ系列内のtoken同士を直接比較できることです。

RNN(前の状態を次へ渡しながら系列を処理するニューラルネットワーク)では、遠いtokenの情報は複数ステップを通って届きます。

Self-Attentionでは、1層の中で任意のtokenペアのスコアを計算できます。

観点RNNSelf-Attention
処理逐次的層内で並列化しやすい
遠いtoken経路が長くなりやすい直接スコアを計算できる
位置情報処理順に含まれるPositional Encodingなどが必要
弱点長さ方向の逐次性

技術的に面白い点

softmax前にscaleとmaskが入る

Scaled Dot-Product AttentionでQK^T、スケール、softmax、Valueの重み付き和を計算する流れ

スコアは

ただし、次元数が大きいと内積値も大きくなりやすいため、

その後、見てはいけない位置をmaskし、softmaxでattention weightへ変換します。

scores = Q @ K^T
scores = scores / sqrt(d_k)
scores = apply_mask(scores)
weights = softmax(scores)
output = weights @ V

maskをsoftmax前に入れる理由は、不可視位置へ確率質量を流さないためです。

softmax後に単純に0を掛けると、残った重みの合計が1でなくなり、attention weightとしての正規化が崩れます。

softmax前に不可視位置のscoreを非常に小さい値へ落としておけば、見える位置だけで再正規化できます。

GPT系のような自己回帰モデルでは、未来tokenを見ないcausal maskが必要になります。

tensor形状は
(B, H, T, D)
で見る

Self-Attentionで入力XからQ/K/V、scores、outputへ変換されるtensor形状

実装では、batch sizeを

tensor形状意味
Query
(B, H, T, D)
各位置が探す手がかり
Key
(B, H, T, D)
各位置の照合用表現
Value
(B, H, T, D)
実際に混ぜる表現
scores
(B, H, T, T)
全tokenペアの相性
output
(B, H, T, D)
attention後の表現

scoresが

(T, T)
を含むため、系列長が伸びると計算量とメモリが二乗で増えます。

Multi-Head Attentionは複数の見方を並列に持つ

Multi-Head Attentionが複数のheadで異なる関係を並列に見る流れ

Multi-Head Attentionは、複数のAttention headを並列に計算して結合します。

headごとに別の投影行列を持つため、同じtoken列でも異なる表現空間で関係を見られます。

たとえば、あるheadは近いtoken、別のheadは主語と述語、さらに別のheadは指示語の参照を見やすくなる可能性があります。

実装者視点で気になった点

最小実装

import logging
from typing import Optional

import torch
from torch import Tensor

logger = logging.getLogger(__name__)


def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Optional[Tensor] = None,
) -> Tensor:
    """Compute scaled dot-product attention.

    Args:
        query: Query tensor with shape `(batch, heads, query_length, head_dim)`.
        key: Key tensor with shape `(batch, heads, key_length, head_dim)`.
        value: Value tensor with shape `(batch, heads, key_length, head_dim)`.
        attention_mask: Optional boolean mask broadcastable to
            `(batch, heads, query_length, key_length)`. `True` means visible.

    Returns:
        Tensor with shape `(batch, heads, query_length, head_dim)`.

    Raises:
        ValueError: If key/value lengths or query/key head dimensions differ.

    Example:
        >>> q = torch.randn(2, 8, 4, 64)
        >>> k = torch.randn(2, 8, 4, 64)
        >>> v = torch.randn(2, 8, 4, 64)
        >>> scaled_dot_product_attention(q, k, v).shape
        torch.Size([2, 8, 4, 64])
    """
    if key.size(-2) != value.size(-2):
        logger.error("key_length and value_length must match")
        raise ValueError("key and value must have the same sequence length")

    if query.size(-1) != key.size(-1):
        logger.error("query and key head dimensions must match")
        raise ValueError("query and key must have the same head dimension")

    head_dim: int = query.size(-1)
    scores: Tensor = query @ key.transpose(-2, -1)
    scores = scores / (head_dim**0.5)

    if attention_mask is not None:
        logger.debug("applying attention mask before softmax")
        # softmax前に消すことで、不可視位置へ確率質量が流れないようにするためです。
        scores = scores.masked_fill(~attention_mask, torch.finfo(scores.dtype).min)

    weights: Tensor = torch.softmax(scores, dim=-1)
    logger.debug("computed attention weights", extra={"shape": tuple(weights.shape)})
    return weights @ value

causal mask

def build_causal_mask(sequence_length: int, device: torch.device) -> Tensor:
    """Build a causal attention mask.

    Args:
        sequence_length: Number of tokens in the sequence.
        device: Device where the mask should be allocated.

    Returns:
        Boolean tensor with shape `(1, 1, sequence_length, sequence_length)`.
        `True` means the key position is visible.

    Raises:
        ValueError: If `sequence_length` is less than 1.

    Example:
        >>> build_causal_mask(3, torch.device("cpu"))[0, 0].int()
        tensor([[1, 0, 0],
                [1, 1, 0],
                [1, 1, 1]], dtype=torch.int32)
    """
    if sequence_length < 1:
        logger.error("sequence_length must be positive")
        raise ValueError("sequence_length must be positive")

    token_positions: Tensor = torch.arange(sequence_length, device=device)
    visible: Tensor = token_positions[:, None] >= token_positions[None, :]
    logger.debug("built causal mask", extra={"sequence_length": sequence_length})
    return visible.unsqueeze(0).unsqueeze(0)

実験結果をどう読むか

論文では、Transformer bigがWMT 2014 English-to-Germanで28.4 BLEU、English-to-Frenchで41.8 BLEUを報告しています。

ただし、これはSelf-Attention単体のablationではなく、Transformer全体の結果です。

Self-Attention目線では、次のように読むのがよさそうです。

観点読み方
並列化RNNの逐次処理を避け、学習を並列化しやすくした
長距離依存任意のtokenペアを直接比較できる
表現力複数headで異なる関係を見られる
限界長文では

現代LLMの対話性能をBLEUだけで説明することはできませんが、Self-Attention中心の構造が後のLLMにつながる重要な基礎になったと読めます。

よくある誤解

誤解正確な見方
Self-Attentionだけで語順が分かるPositional EncodingやRoPEなどの位置情報が別途必要
attention weightは完全な説明になる参照傾向は見えるが、因果的説明としては慎重に読む
maskはsoftmax後でよいsoftmax前に不可視位置を極小値にするのが基本
Multi-Headは同じ計算の単純な繰り返しheadごとに別の投影行列を持つ
KV Cacheで学習時の KV Cacheは主に自己回帰推論でKey/Valueを再利用する仕組み

個人的な所感

Self-Attentionは、LLMを「文章を読んでいる」ように見せる重要な部品ですが、実体はかなり素直な行列計算です。

QueryとKeyで参照先を決め、Valueを重み付きで混ぜる。

この見方に慣れると、KV Cache、FlashAttention、GQA/MQAのような後続技術もかなり読みやすくなります。

特に、scoresの形状が

(B, H, T, T)
になることを体で覚えると、長文LLMの難しさが一気に現実味を持ちます。

詳細版

より詳しい背景、Q/K/Vの具体例、計算量、関連技術とのつながりは個人ブログにまとめています。

👉 完全版はこちら:Self-Attentionとは?Scaled Dot-Product Attentionを数式と実装でやさしく解説