pytorchでBERTの日本語学習済みモデルを利用する - 文章埋め込み編
2019-06-05

https://www.pexels.com/photo/red-and-white-umbrella-during-night-time-39079/

概要

BERT (Bidirectional Encoder Representations from Transformers) は、NAACL2019で論文が発表される前から大きな注目を浴びていた強力な言語モデルです。これまで提案されてきたELMoやOpenAI-GPTと比較して、双方向コンテキストを同時に学習するモデルを提案し、大規模コーパスを用いた事前学習とタスク固有のfine-tuningを組み合わせることで、各種タスクでSOTAを達成しました。

そのように事前学習によって強力な言語モデルを獲得しているBERTですが、今回は日本語の学習済みBERTモデルを利用して、文章埋め込み (Sentence Embedding) を計算してみようと思います。


環境

今回は京都大学の黒橋・河原研究室が公開している「BERT日本語Pretrainedモデル」を利用します。

BERTの実装は、pytorchで書かれたpytorch-pretrained-BERTがベースになります。また形態素解析器は、学習済みモデルに合わせるためJUMAN++を利用します。

方法

今回はBertWithJumanModelという、トークナイズとBERTによる推論を行うクラスを自作しています。ソースコード自体は下記レポジトリにあり、また各ステップでの計算方法を本記事の後半で解説しています。

In []: from bert_juman import BertWithJumanModel

In []: bert = BertWithJumanModel("/path/to/Japanese_L-12_H-768_A-12_E-30_BPE")

In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([ 2.22642735e-01, -2.40221739e-01,  1.09303640e-02, -1.02307117e+00,
        1.78834641e+00, -2.73566216e-01, -1.57942638e-01, -7.98571169e-01,
       -2.77438164e-02, -8.05811465e-01,  3.46736580e-01, -7.20409870e-01,
        1.03382647e-01, -5.33944130e-01, -3.25344890e-01, -1.02880754e-01,
        2.26500735e-01, -8.97880018e-01,  2.52314955e-01, -7.09809303e-01,
[...]        

これでBERTによる文章埋め込みのベクトルが得られました。あとは、後続のタスクに利用したり、文章ベクトルとして類似度計算などに利用できます。

また、BERTの隠れ層の位置や、プーリングの計算方法も選択できるようにしています。このあたりの設計はhanxiao/bert-as-service を参考にしています。

In []: bert.get_sentence_embedding("吾輩は猫である。",
   ...:                             pooling_layer=-1,
   ...:                             pooling_strategy="REDUCE_MAX")
   ...:
Out[]:
array([ 1.2089624 ,  0.6267309 ,  0.7243419 , -0.12712255,  1.8050476 ,
        0.43929055,  0.605848  ,  0.5058241 ,  0.8335829 , -0.26000524,
[...]        

解説

上記のBertWithJumanModelクラスの内部を順に解説していきます。そのまま上から実行しても動作するように記載しているので、途中の動作が気になる方は試してみて下さい。

1. 学習済みモデルをpytorch-pretrained-bertで読み込む

まず始めに配布されている学習済みモデルなどをpytorch-pretrained-BERTから読み込みます。

import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel

model = BertModel.from_pretrained("/path/to/Japanese_L-12_H-768_A-12_E-30_BPE/")
bert_tokenizer = BertTokenizer("/path/to/Japanese_L-12_H-768_A-12_E-30_BPE/vocab.txt",
                               do_lower_case=False, do_basic_tokenize=False)

モデルは黒橋・河原研究室の配布サイトからダウンロードし解凍します。BertModelfrom_pretrained()で解凍先のパスを指定することで、モデルをロードすることができます。必要なファイルはpytorch_model.binvocab.txtのみです。

なお、モデル配布ページではpytorch-pretrained-BERT内のtokenization.pyの特定行をコメントアウトするように指示されていますが、BertTokenizer()で引数をdo_basic_tokenize=Falseとすれば対応は不要です。


2. テキストを分かち書きして対応するid列に変換する

次に、与えられたテキストを分かち書きしてトークンに分割したのちに、対応するidに変換します。pytorch-pretrained-BERTは日本語の分かち書きに対応していないため、前者はJuman++によるトークナイザを自作し、後者はBertTokenizer()を利用します。

# Jumanによるトークナイザ
from pyknp import Juman

class JumanTokenizer():
    def __init__(self):
        self.juman = Juman()

    def tokenize(self, text):
        result = self.juman.analysis(text)
        return [mrph.midasi for mrph in result.mrph_list()]

分かち書きしたトークン列は、英語のトークナイズに対応できるようにスペース区切りで結合し、最後にBertTokenizer()でid列に変換します。この際に、テキストの最初と最後に[CLS]および[SEP]トークンを付与します。また、学習済みモデルのmax_seq_lengthが128に設定されているため、トークン列の長さを調整しています。

juman_tokenizer = JumanTokenizer()

tokens = juman_tokenizer.tokenize(text)
bert_tokens = bert_tokenizer.tokenize(" ".join(tokens))
ids = bert_tokenizer.convert_tokens_to_ids(["[CLS]"] + bert_tokens[:126] + ["[SEP]"])
tokens_tensor = torch.tensor(ids).reshape(1, -1)

例えば「我輩は猫である。」という文章は、以下のようにトークン化されid列に変換されます。ちなみに我輩という単語は辞書中に存在せずかつサブワードとしても分割できないことから、未知語としてid:1([UNK])に変換されています。

# text
  吾輩は猫である。
# tokens
  ['[CLS]', '吾輩', 'は', '猫', 'である', '。', '[SEP]']
# tokens_tensor
  tensor([[   2,    1,    9, 4817,   32,    7,    3]])

3. BERTのモデルに入力し、特徴ベクトルを得る

id列に変換したベクトルをBERTのモデルに入力し、それぞれの隠れ層から出力される特徴ベクトルを得ます。

model.eval()
with torch.no_grad():
    all_encoder_layers, _ = model(tokens_tensor)

all_encoder_layersには全12層から出力される特徴ベクトルが格納されています。

4. 特徴ベクトルから文章埋め込みを得る

最後に、BERTの隠れ層から得られた特徴ベクトルから文章埋め込みを計算します。BERTでは各入力のトークンに対応するhidden_size次元 (本モデルでは768次元) のベクトルが得られるため、それを文章埋め込みとしての固定次元のベクトルで表現する必要があります。

ここではhanxiao/bert-as-serviceでの計算方法を参考に、SWEMと同じ方法でベクトルを時間方向にaverage-poolingしています。また利用するBERTの隠れ層は、最終層ではなくその一つ前の層を選択しています。

pooling_layer = -2
embedding = all_encoder_layers[pooling_layer].numpy()[0]
np.mean(embedding, axis=0)

hanxiao/bert-as-serviceでは、この計算方法はいくつか選択できるようになっており、本家に習ってBertWithJumanModelのコード内でもREDUCE_MEANREDUCE_MAXなどの方法を指定できるようにしています。これらの手法は基本的にSWEMの概念と同じですので、詳細は「SWEM: 単語埋め込みのみを使うシンプルな文章埋め込み」の記事を参考ください。

参考

このエントリーをはてなブックマークに追加