← 기사 목록
日本語https://qiita.com/tags/llm/feed

【コード解説編】論理ゲートで Transformer を超える実装 (PPL 4.73)

추출된 키워드

33
Transformer·5PPL·5論理ゲート·5DLGN·4HBA·4知識蒸留·4微分可能な論理ゲート層·4Boolean Router·4量子化誤差·3hard collapse·3softmax·3エッジ推論·3Speculative decoding·3boolean-attention·3warm_hold 温度スケジュール·3Hard threshold calibration·3Early stopping·3spectral norm·3BooleanAttentionLayer·3float Value·3LoopedDLGN·3Distilling the Knowledge in a Neural Network·2Deep Differentiable Logic Gate Networks·2KL·2CE·2ChatHBA·2Universal Transformer·2RTX 4060 8GB·2PyTorch 2.1+·2Python 3.10+·2電力制約環境·2born-again networks·2リプシッツ性·2

원문

10,734
【コード解説編】論理ゲートで Transformer を超える実装 (PPL 4.73)

【コード解説編】論理ゲートで Transformer を超える実装 (PPL 4.73)

論理ゲートだけで言語モデルを作って Transformer (PPL 4.86) を 0.13 上回った実装の解説です。
DLGN, HBA, 知識蒸留の 実コード を中心に、再現に必要な要点をまとめます。

物語 / 失敗譚は 物語編 を参照してください。

環境

  • Python 3.10+, PyTorch 2.1+
  • RTX 4060 8GB(CPU でも動作可、学習時間は伸びます)
git clone https://github.com/karumaru-kakikukekodoumei/boolean-attention.git
cd boolean-attention
pip install -r requirements.txt

Step 1. 微分可能な論理ゲート層 (DLGN)

論理ゲートは離散関数で勾配が流れません。2 入力ブール関数は $2^4 = 16$ 種類しかないという事実を使い、16 ゲートを softmax で混合 することで勾配を流します。

import torch
import torch.nn as nn
import torch.nn.functional as F

def all_gates(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """16 種類の論理ゲートを (..., 16) 次元で返す。a, b は [0,1] 連続値想定。"""
    return torch.stack([
        torch.zeros_like(a), a * b, a * (1 - b), a,
        (1 - a) * b, b, a + b - 2*a*b, a + b - a*b,
        1 - (a + b - a*b), 1 - (a + b - 2*a*b), 1 - b, a + (1-b) - a*(1-b),
        1 - a, (1-a) + b - (1-a)*b, 1 - a*b, torch.ones_like(a),
    ], dim=-1)


class DLGNLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, tau: float = 1.0):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.tau = tau
        self.pair_a = nn.Parameter(torch.randn(out_dim, in_dim) * 0.5)
        self.pair_b = nn.Parameter(torch.randn(out_dim, in_dim) * 0.5)
        self.gate_logits = nn.Parameter(torch.randn(out_dim, 16) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        wa = F.softmax(self.pair_a / self.tau, dim=-1)
        wb = F.softmax(self.pair_b / self.tau, dim=-1)
        a = x @ wa.T
        b = x @ wb.T
        gates = all_gates(a, b)
        alpha = F.softmax(self.gate_logits / self.tau, dim=-1)
        return (gates * alpha).sum(dim=-1)

学習が終わったら

argmax(self.gate_logits)
で 1 個に確定すれば、純粋なブーリアン回路に戻ります(hard collapse)。
python src/dlgn_charlm.py
結果Soft PPLHard PPL
DLGN flat (4 層)11.8315.16
Transformer (比較)4.86

論理回路で言語学習はできた、ただし TF には届かず。次の設計へ。

Step 2. (失敗例) LoopedDLGN

DLGN を T 回繰り返す Universal Transformer 風設計。撃沈例として参考までに残します。

python src/looped_dlgn_charlm.py --max-iters=8
Soft PPLHard PPL
v1 (PE なし)11.05754.31

ハードコラプス時に PPL が 754 まで暴騰。反復ごとの量子化誤差が

ε_total ≈ Σ_t ‖f_hard(x⁽ᵗ⁾) - f_soft(x⁽ᵗ⁾)‖

として深さ方向に蓄積するためです。反復系は Boolean と相性が悪い という構造的な学び。

Step 3. HBA — Boolean Router + float Value

Attention のルーターだけを Boolean 化、値集約は float のまま。

import torch.nn.utils as nn_utils

class BooleanAttentionLayer(nn.Module):
    def __init__(self, d: int, tau: float = 0.1):
        super().__init__()
        self.q = nn.Linear(d, d)
        self.k = nn.Linear(d, d)
        self.v = nn.Linear(d, d)
        # bilinear router (Lipschitz 制約に spectral norm)
        self.w_router = nn_utils.spectral_norm(nn.Linear(d, d, bias=False))
        self.tau = tau

    def forward(self, x: torch.Tensor, causal_mask: torch.Tensor) -> torch.Tensor:
        Q, K, V = self.q(x), self.k(x), self.v(x)  # [B, T, d]
        # Q · W · K^T
        logits = Q @ self.w_router.weight @ K.transpose(-1, -2)  # [B, T, T]
        logits = logits.masked_fill(causal_mask, float("-inf"))

        if self.training:
            router = torch.tanh(logits / self.tau)  # 連続近似
        else:
            router = torch.sign(logits)  # 推論時離散

        attn = F.softmax(router / self.tau, dim=-1)
        return attn @ V  # V は float のまま

ポイント:

  • ルーターは離散値 (-1, +1) に確定
  • 値の集約は float なので 量子化誤差が深さ方向に伝播しない
  • spectral norm で router 重みのリプシッツ性を担保(発散防止)
python src/hba_charlm.py --epochs=60
HBA v1Best PPLFinal PPL
Ep12 / Ep605.40 9.75

TF (4.86) まで 0.54 差 まで肉薄。ただし過学習が課題。

Step 4. HBA v2 — 安定化 4 点セット

# 1. Best checkpoint
if val_ppl < best_ppl:
    best_ppl = val_ppl
    best_state = {k: v.clone() for k, v in model.state_dict().items()}
    best_epoch = ep
    bad_count = 0
else:
    bad_count += 1

# 2. Early stopping
if bad_count >= patience:
    print(f"early stop at ep {ep}")
    break

# 3. Hard threshold calibration
def calibrate_hard_threshold(model, val_loader, taus=(0.05, 0.08, 0.1, 0.15, 0.2)):
    best = (None, float("inf"))
    for tau in taus:
        model.set_inference_tau(tau)
        ppl = evaluate(model, val_loader)
        if ppl < best[1]:
            best = (tau, ppl)
    return best

# 4. warm_hold 温度スケジュール
def temperature_schedule(epoch: int) -> float:
    if epoch < 5:  return 1.0           # warm: 柔らかく
    if epoch < 15: return 0.5           # hold: 中間
    return max(0.1, 0.5 * 0.95**(epoch - 15))  # decay
python src/hba_charlm.py --epochs=40 --early-stop --calibrate
HBA v2Soft PPLHard PPLTrain time
結果5.32 6.54 4.7 min

LoopedDLGN の Hard PPL 754 と比べて 115 倍の改善

Step 5. 知識蒸留で TF 越え

教師 (TF) → 生徒 (HBA v2 構造) に蒸留。ハイブリッド損失で CE と KL を併用。

def distill_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.3,
    T: float = 8.0,
) -> torch.Tensor:
    ce = F.cross_entropy(student_logits, targets)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    )
    return alpha * ce + (1 - alpha) * (T * T) * kl
python src/hba_distill_charlm.py --epochs=30 --teacher-ckpt=teacher_tf.pt
Soft PPL
Teacher (TF)4.86
Student (HBA distilled) 4.73
逆転幅-0.13

論理回路ベースのモデルが Transformer を逆転。born-again networks (Furlanello et al. 2018) として知られる現象です。

ハマりどころ: 温度整合性のバグ

初期実装で訓練 eval と最終比較で温度 $\tau$ が違っていて、PPL が 4.71 vs 8.72 と乖離するバグに数日とられました。

Bad

# 訓練 eval は固定 tau=1.0、最終比較は final_tau=0.1 と別物
def evaluate(model, loader):
    model.set_inference_tau(1.0)
    ...

# 最終比較
model.set_inference_tau(0.1)  # ← 急に厳しい τ にする
final_ppl = evaluate(model, test_loader)

Good

# 訓練 eval は「現在のスケジューラ τ」で評価
def evaluate(model, loader, tau: float):
    model.set_inference_tau(tau)
    ...

# 最終比較は best epoch 時点の実 τ を逆引き
best_tau = temperature_schedule(best_epoch)
model.load_state_dict(best_state)
model.set_inference_tau(best_tau)
final_ppl = evaluate(model, test_loader, best_tau)

これで再現性のある PPL 4.73 が出るようになりました。

再現手順まとめ

# 1. ベースライン
python src/dlgn_charlm.py        # PPL 11.83

# 2. 失敗パス (任意)
python src/looped_dlgn_charlm.py # PPL 754 で爆死を体感

# 3. HBA v2
python src/hba_charlm.py --early-stop --calibrate  # PPL 5.32

# 4. 蒸留 (要 teacher checkpoint)
python src/train_teacher.py
python src/hba_distill_charlm.py  # PPL 4.73 → TF 越え

学習ログは

results/
、学習済み ChatHBA は
checkpoints/
にあります。

応用先

HBA は 特化用途で実用性あり という結論:

  • Speculative decoding のドラフトモデル— 大きい教師モデルとの並用で軽量ルーティング
  • エッジ推論— CPU/MCU で動く軽量 LM
  • 電力制約環境— GPU を持たないシステム

まとめ

Stepやったこと結果
1DLGN 層16 ゲート softmax 混合で勾配 OK
2DLGN flatPPL 11.83 (TF 4.86 未達)
3LoopedDLGNPPL 754 で構造的に詰む
4HBA v1PPL 5.40 (TF まで 0.54 差)
5HBA v2PPL 5.32 / Hard 6.54
6知識蒸留PPL 4.73 (TF 4.86 → 0.13 上回る)

リンク