← 기사 목록
日本語https://zenn.dev/topics/llm/feed

論文メモ:BERTからEmbeddingを整理する

추출된 키워드

38
Embedding·5BERT·5Transformer Encoder·4文脈化Embedding·4contextualized representation·4token embedding·4segment embedding·4position embedding·4Bidirectional Encoder Representations from Transformers·4Masked Language Modeling·4Next Sentence Prediction·4BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding·3attention mask·3Fine-tuning·3token ID·3tokenizer·3Self-Attention·3MLM·3静的Embedding·3NSP·3MASK·3SEP·3CLS·3WordPiece·3nn.LayerNorm·2nn.Dropout·2nn.Embedding·2torch·2GPT·2Kristina Toutanova·2Kenton Lee·2Ming-Wei Chang·2Jacob Devlin·2padding token·2Sentence-BERT·2RAG·2GloVe·2word2vec·2

원문

11,686
論文メモ:BERTからEmbeddingを整理する

論文メモ:BERTからEmbeddingを整理する

はじめに

この記事は、以下の論文を読んだ技術メモです。

  • 論文タイトル: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
  • 論文リンク: https://arxiv.org/abs/1810.04805
  • 著者: Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
  • 初版公開日: 2018年10月11日

詳細な背景説明、BERTの事前学習、文脈化Embeddingの整理は個人ブログ側にまとめています。

👉 完全版はこちら:Embeddingとは?BERT論文から単語・文章をベクトルで表す仕組みを解説

3行まとめ

文章がTokenization、token ID、Embedding、Transformer Encoderを通って文脈化Embeddingになる流れ

  • Embedding(埋め込み)は、token IDをニューラルネットワークが扱える連続ベクトルへ変換する層です。
  • BERTはtoken embedding、segment embedding、position embeddingを足し合わせてTransformer Encoderへ入力します。
  • BERTで重要なのは、入力Embeddingそのものより、左右の文脈を反映したcontextualized representation(文脈化表現)です。

何の論文か

BERT論文は、Transformer Encoderを使って双方向の言語表現を事前学習する手法を提案した論文です。

BERTは、Bidirectional Encoder Representations from Transformersの略です。

名前の通り、DecoderではなくEncoder側のTransformerを使い、文の左側と右側の両方を見ながら表現を作ります。

項目内容
モデルTransformer Encoder
tokenizerWordPiece
代表的な特殊token
[CLS]
,
[SEP]
,
[MASK]
事前学習目的Masked Language Modeling, Next Sentence Prediction
主な用途分類、抽出、質問応答、文ペア判定

LLMシリーズの流れで見ると、前回のTokenization編で得たtoken ID列が、今回のEmbedding層に渡されます。

text -> token -> token ID -> embedding vector -> Transformer

Embeddingを数式で見る

BERTの入力表現がtoken embedding、segment embedding、position embeddingの和で作られることを示す図

語彙サイズを

token ID

これだけなら、Embeddingは「IDをベクトルに引く表」です。

ただし、BERTではこのtoken embeddingに、segment embeddingとposition embeddingを足します。

成分役割
どのtokenかを表す
文A/Bのどちらに属するかを表す
系列内の位置を表す

TransformerのSelf-Attention(系列内のtoken同士の関係を見る仕組み)は、単体では語順を直接持ちません。

そのため、position embeddingを足して、tokenの順序をモデルへ渡します。

BERTの入力表現

BERTの入力は、単文でも文ペアでも同じ形式にそろえられます。

文ペアの場合は次のような並びです。

[CLS] sentence A [SEP] sentence B [SEP]

segment embeddingは、sentence Aとsentence Bを区別するために使われます。

tokens:   [CLS] my dog is cute [SEP] he likes playing [SEP]
segment:    A   A  A   A   A    A    B    B      B     B
position:   0   1  2   3   4    5    6    7      8     9
特殊token使い方
[CLS]
分類用の代表位置として使う
[SEP]
文の終端や文ペアの境界に置く
[MASK]
MLMで予測対象を隠す

実装時には、token ID、segment ID、position IDを同じshapeで用意し、各Embeddingを足し合わせます。

静的Embeddingとの違い

静的EmbeddingとBERTの文脈化Embeddingの違いをbankの例で比較した図

BERTのEmbeddingを理解するとき、入力Embeddingと出力表現を分けると混乱しにくいです。

観点静的EmbeddingBERTの文脈化表現
代表例word2vec, GloVeBERTの各層出力
同じ単語の表現原則同じ文脈で変わる
多義語意味が1ベクトルに混ざりやすい周囲のtokenに応じて分かれやすい
主な用途類似語、初期特徴量分類、抽出、QA、再ランキング

たとえば

bank
は、金融機関の意味でも川岸の意味でも使われます。

BERTでは、入力のtoken embeddingは同じでも、Transformer層で周囲の語と相互作用した後の表現は変わります。

これがcontextualized representation(文脈化表現)です。

Masked Language Modeling

BERTが双方向に文脈を見るためには、通常の左から右の次token予測では不十分です。

全文を見せたまま予測すると、答えのtoken自身が見えてしまうからです。

そこでBERTは、Masked Language Modeling(入力の一部を隠し、元のtokenを当てる事前学習タスク)を使います。

論文では、入力tokenの15%を予測対象にします。

処理割合意図
[MASK]
に置換
80%tokenを明示的に隠す
ランダムtokenに置換10%ノイズに頑健にする
そのまま残す10%事前学習とFine-tuningの入力差を弱める

損失は、mask対象になった位置だけで計算します。

この学習により、BERTは左右の文脈を使って欠けたtokenを推定する表現を獲得します。

Next Sentence Prediction

BERT論文では、Next Sentence Prediction(2文が連続しているかを判定する事前学習タスク)も使われています。

文Aと文Bを入力し、50%は実際に連続する文、50%はランダムな文Bにします。

分類には

[CLS]
の最終表現を使います。
学習目的何を学ぶか
MLMtoken周辺の双方向文脈
NSP文ペアの関係

後続研究ではNSPの必要性が再検討されています。

そのため、ここでは「BERT論文時点の設計」として理解するのがよさそうです。

実装者視点で気になる点

tokenizerとEmbeddingはセットで管理する

Embedding行列の行はtoken IDに対応します。

そのため、tokenizerを差し替えると、同じIDが別のtokenを指す可能性があります。

Fine-tuning済みモデルでは、tokenizer、Embedding、出力層をセットで保存・配布する必要があります。

[CLS]
を万能な文Embeddingと思わない

BERTの

[CLS]
表現は分類タスクには便利です。

ただし、類似度検索やRAG(検索で外部文書を補う生成方式)でそのまま最適とは限りません。

文検索では、Sentence-BERTのように文類似度向けに学習されたモデルを使うことが多いです。

Paddingとattention maskを合わせる

batch処理では、系列長をそろえるためにpadding tokenを入れます。

このとき、padding位置までAttentionしてしまうと不要な情報が混ざります。

実装では、attention maskでpadding位置を無視する必要があります。

実装例:BERT風Embedding層

以下は、token embedding、segment embedding、position embeddingを足し合わせる最小例です。

import logging

import torch
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class BertStyleInputEmbedding(nn.Module):
    """Build BERT-style input embeddings from token, segment, and position IDs.

    Args:
        vocab_size: Number of tokens in the tokenizer vocabulary.
        hidden_size: Embedding dimension used by the Transformer encoder.
        max_position_embeddings: Maximum sequence length supported by the model.
        segment_vocab_size: Number of segment IDs. BERT uses two for sentence A/B.

    Returns:
        A module that maps ID tensors to summed embedding tensors.

    Raises:
        ValueError: If one of the size arguments is not positive.

    Examples:
        >>> module = BertStyleInputEmbedding(30522, 768, 512, 2)
        >>> token_ids = torch.tensor([[101, 2023, 2003, 102]])
        >>> segment_ids = torch.zeros_like(token_ids)
        >>> module(token_ids, segment_ids).shape
        torch.Size([1, 4, 768])
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        segment_vocab_size: int = 2,
    ) -> None:
        super().__init__()
        if vocab_size <= 0:
            raise ValueError("vocab_size must be positive")
        if hidden_size <= 0:
            raise ValueError("hidden_size must be positive")
        if max_position_embeddings <= 0:
            raise ValueError("max_position_embeddings must be positive")
        if segment_vocab_size <= 0:
            raise ValueError("segment_vocab_size must be positive")

        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.segment_embeddings = nn.Embedding(segment_vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, token_ids: Tensor, segment_ids: Tensor | None = None) -> Tensor:
        """Return summed BERT-style embeddings.

        Args:
            token_ids: Tensor shaped `(batch, seq_len)` containing tokenizer IDs.
            segment_ids: Optional tensor shaped like `token_ids`. If omitted, all tokens use segment 0.

        Returns:
            Tensor shaped `(batch, seq_len, hidden_size)`.

        Raises:
            ValueError: If `token_ids` is not 2D or if `segment_ids` has a mismatched shape.

        Examples:
            >>> module = BertStyleInputEmbedding(100, 16, 32)
            >>> ids = torch.tensor([[1, 2, 3]])
            >>> module(ids).shape
            torch.Size([1, 3, 16])
        """
        if token_ids.ndim != 2:
            raise ValueError("token_ids must have shape (batch, seq_len)")

        batch_size, seq_len = token_ids.shape
        if segment_ids is None:
            segment_ids = torch.zeros_like(token_ids)
        if segment_ids.shape != token_ids.shape:
            raise ValueError("segment_ids must have the same shape as token_ids")

        position_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        position_ids = position_ids.expand(batch_size, seq_len)

        logger.debug("build embeddings: batch=%d seq_len=%d", batch_size, seq_len)
        # BERTの入力仕様にそろえるため、3種類のID情報を同じhidden_sizeで加算する。
        embeddings = (
            self.token_embeddings(token_ids)
            + self.segment_embeddings(segment_ids)
            + self.position_embeddings(position_ids)
        )
        logger.info("created input embeddings: shape=%s", tuple(embeddings.shape))
        return self.dropout(self.layer_norm(embeddings))

よくある誤解

誤解正確な見方
Embeddingは単語の意味辞書である学習されたベクトルであり、意味は文脈やタスクの中で現れる
BERTの入力Embeddingだけで文脈が分かる文脈化されるのはTransformer層を通過した後
[CLS]
は常に最高の文Embeddingである
分類には便利だが、検索には専用モデルが有利なことが多い
BERTはGPTと同じ生成モデルであるBERTはEncoder型で、理解・分類・抽出に向く

個人的な所感

Embeddingは「単語をベクトルにする処理」と説明されがちですが、BERTを読むと、それだけでは足りないと分かります。

実務で効くのは、tokenizer、Embedding行列、position、attention mask、Fine-tuning後の出力層が一体で動くという見方です。

特にRAGや類似度検索を扱うときは、「BERTの出力を使えば何でも良いEmbeddingになる」とは考えず、どの学習目的で作られた表現なのかを見る必要があります。

詳細版

より詳しい解説、BERTの事前学習目的、Embedding行列のサイズ感、関連技術との比較は個人ブログにまとめています。

👉 完全版はこちら:Embeddingとは?BERT論文から単語・文章をベクトルで表す仕組みを解説