メインコンテンツまでスキップ

HuggingFace Transformers形式のチェックポイントの行列を比較する

本説明書では、HuggingFace Transformers形式チェックポイントの行列比較について解説します。

前提条件

本手順は以下の構成を前提としています。

比較スクリプトの使用方法

比較スクリプトscripts/compare_weights.pyによって、二つのモデルのstate_dictを比較することができます。 使い方は以下です。

python3 compare_weights.py <Reference Checkpoint> <Test Checkpoint>
# 参照元となるチェックポイントを第一引数に、テスト対象となるチェックポイントを第二引数に指定

なお、HuggingFace Transformersライブラリでは、wrapperによって自動でkeyの不一致の補完や、定数テンソルの再計算が行われる場合があるため、state_dictは完全一致している必要はありません。

使用方法の具体的な例

具体的な例として、チェックポイントの形式変換(Megatron-LM, HuggingFace Transformers)で得られたチェックポイントが変換前後で変わっていないことを確認する手順を紹介します。

Llama2の形式変換確認

行列(state_dict)が変換前後で変わっていないことを確認します。 なお、比較できるのは

  • HuggingFaceからダウンロードしたLlama-2-7b-hf
  • HuggingFaceからダウンロードしたLlama-2-7b-hfをMegatron-LM形式に変換したあと、再度HuggingFace Transformers形式に戻したチェックポイント

です。 HuggingFace Transformers形式のチェックポイントとMegatron-LM形式のチェックポイントの比較ではない点にご注意ください。

$ . ../.venv/bin/activate
$ python3 ../scripts/compare_weights.py \
Llama-2-7b-hf \
Llama-2-7b-mega-hf
Load Llama-2-7b-hf: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00, 3.01s/it]
Load Llama-2-7b-mega-hf: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 3.09s/it]
Compare: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 291/291 [00:37<00:00, 7.75it/s]
Tensors only in the base model
- model.layers.0.self_attn.rotary_emb.inv_freq
- model.layers.1.self_attn.rotary_emb.inv_freq

~~(途中省略)~~

- model.layers.8.self_attn.rotary_emb.inv_freq
- model.layers.9.self_attn.rotary_emb.inv_freq
Tensors only in the target model
Nothing
Shape mismatched tensors
Nothing
Value mismatched tensors
Nothing

Total tensors: 323
Tensors only in the base model: 32
Tensors only in the target model: 0
Shape mismatched tensors: 0
Value mismatched tensors: 0

上記のように表示されれば、変換は正常に行われています。

self_attn.rotary_emb.inv_freqについて補足

Tensors only in the base modelとしてmodel.layers.XXXX.self_attn.rotary_emb.inv_freqが出力されますが、こちらは問題ありません。

model.layers.XXXX.self_attn.rotary_emb.inv_freqは、Position Embeddingの層であり、再算出可能なものなのでチェックポイントに含める必要はありません。

Llama 3の形式変換の確認

Llama 3を比較します。

$ . ../.venv/bin/activate
$ python3 ../scripts/compare_weights.py \
Meta-Llama-3-8B \
Meta-Llama-3-8B-mega-hf
Load Meta-Llama-3-8B: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00, 1.01it/s]
Load Meta-Llama-3-8B-mega-hf: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00, 3.46s/it]
Compare: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 291/291 [00:40<00:00, 7.20it/s]
Tensors only in the base model
Nothing
Tensors only in the target model
Nothing
Shape mismatched tensors
Nothing
Value mismatched tensors
Nothing

Total tensors: 291
Tensors only in the base model: 0
Tensors only in the target model: 0
Shape mismatched tensors: 0
Value mismatched tensors: 0

上記のように表示されれば、変換は正常に行われています。

GPT-2の形式変換の確認

GPT-2を比較します。

$ . ../.venv/bin/activate
$ python3 ../scripts/compare_weights.py \
--base-prefix transformer. \
gpt2-medium \
gpt2-medium-mega
Load gpt2-medium: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 183.31it/s]
Load gpt2-medium-mega: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9.42it/s]
Compare: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 316/316 [00:04<00:00, 75.20it/s]
Tensors only in the base model
Nothing
Tensors only in the target model
- lm_head.weight
- transformer.h.0.attn.masked_bias
- transformer.h.1.attn.masked_bias

~~(途中省略)~~

- transformer.h.8.attn.masked_bias
- transformer.h.9.attn.masked_bias
Shape mismatched tensors
- transformer.wte.weight
Value mismatched tensors
Nothing

Total tensors: 341
Tensors only in the base model: 0
Tensors only in the target model: 25
Shape mismatched tensors: 1
Value mismatched tensors: 0

上記のように表示されれば、変換は正常に行われています。

transformer.wte.weightについて補足

Shape mismatched tensorsとしてtransformer.wte.weightが出力されますが、こちらは問題ありません。

HuggingFace Transformers形式からMegatron-LM形式へ変換する際に、Vocab sizeが変わったことが原因です。(いくつかのdummy tokenが追加されました)

lm_head.weightについて補足

Tensors only in the target modelとしてlm_head.weightが出力されますが、こちらは問題ありません。

これはMegatron-LM形式からHuggingFace Transformers形式へ変換する際に追加されるTensorで、元のモデルに含まれている必要はありません。

attn.masked_biasについて補足

Tensors only in the target modelとしてtransformer.h.XXXX.attn.masked_biasが出力されますが、こちらは問題ありません。

これらはMegatron-LM形式からHuggingFace Transformers形式へ変換する際に追加されるダミーTensorで、元のモデルに含まれている必要はありません。

state_dictkeyの差分について補足

/openai-community/gpt2-mediumモデルのstate_dictkeyは以下のようになっています。

[
"h.0.ln_1.weight",
"h.0.ln_1.bias",
"h.0.attn.bias",
"h.0.attn.c_attn.weight",
"h.0.attn.c_attn.bias",
"h.0.attn.c_proj.weight",
"h.0.attn.c_proj.bias",
"h.0.ln_2.weight",
"h.0.ln_2.bias",
"h.0.mlp.c_fc.weight",
"h.0.mlp.c_fc.bias",
"h.0.mlp.c_proj.weight",
"h.0.mlp.c_proj.bias",
...
]

一方で、Megatron-LM形式から、HuggingFace Transformers形式への変換後は以下のようなkeyで出力されます。違いは接頭語transformer.の有無です。

[
"transformer.h.0.ln_1.weight",
"transformer.h.0.ln_1.bias",
"transformer.h.0.attn.bias",
"transformer.h.0.attn.masked_bias",
"transformer.h.0.attn.c_attn.weight",
"transformer.h.0.attn.c_attn.bias",
"transformer.h.0.attn.c_proj.weight",
"transformer.h.0.attn.c_proj.bias",
"transformer.h.0.ln_2.weight",
"transformer.h.0.ln_2.bias",
"transformer.h.0.mlp.c_fc.weight",
"transformer.h.0.mlp.c_fc.bias",
"transformer.h.0.mlp.c_proj.weight",
"transformer.h.0.mlp.c_proj.bias",
...
]

前者を「フォーマットA」、後者を「フォーマットB」とします。

transformersにある変換スクリプトでは、フォーマットBにのみ対応しています。 そのためチェックポイントの形式変換(Megatron-LM, HuggingFace Transformers)の「HuggingFace Transformers形式からMegatron-LM形式へ変換する」では、修正スクリプト(scripts/pretrain_megatron_lm/convert_checkpoint/fix_statedict_key_prefix_gpt2.py)を適用し、/openai-community/gpt2-mediumモデル(フォーマットA)をフォーマットBに変換しています。

比較の際は--base-prefix transformer.の指定を追加することで、比較元のkeyをフォーマットB相当にしたうえで比較しています。