概要
Flairは、Pytorchで書かれた自然言語処理用のフレームワークです。固有表現抽出や品詞タグ付け、文章分類などの機能を提供しているほか、文章埋め込み (Sentence Embedding) も簡単に計算することができます。以前に本ブログで紹介したSWEMも扱うことができたので、ここで使い方を紹介したいと思います。
記事:SWEM: 単語埋め込みのみを使うシンプルな文章埋め込み - Out-of-the-box
方法
単語ベクトルの読み込み
まずFlairで学習済みの単語埋め込みベクトルを読み込みます。あらかじめ学習済み単語ベクトルのファイルを用意しておく必要はなく、以下のコードを初めて動かす際に自動でウェブからダウンロードされます。日本語の場合は、fastTextが提供しているja-wiki-fasttext-300d-1M
が選択されます。
from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings, Sentence
from flair.data import build_japanese_tokenizer
ja_embedding = WordEmbeddings("ja")
ここでダウンロードしたファイルは$HOME/.flair/embeddings/
に保存されます。
文章埋め込みの選択
次に、文章埋め込みの手法を選択します。SWEMは各単語ベクトルに対して各種Poolingの操作を行うことで文章埋め込みを計算するため、DocumentPoolEmbeddings()
を利用します。この引数には、さきほど読み込んだWordEmbeddings
のインスタンスを選択します。
document_embeddings = DocumentPoolEmbeddings([ja_embedding])
文章埋め込みを計算する
最後に文章埋め込みを計算します。まず対象となる文章のSentence
オブジェクトを作成し、それを上記で作成したdocument_embeddings.embed()
で埋め込みます。
sentence = Sentence("吾輩は猫である。名前はまだ無い。",
use_tokenizer=build_japanese_tokenizer("MeCab"))
document_embeddings.embed(sentence)
そして、文章埋め込みのベクトルはsentnece.get_embedding()
から取得することができます。
In []: sentence.get_embedding()
Out[]:
tensor([ 2.1660e+00, -1.9450e+00, -1.9782e+00, -1.0372e+01, -7.4274e-01,
-1.6262e+00, 2.3832e+00, 1.3668e+00, 4.2834e+00, -3.4007e+00,
[...]
3.6956e+00, -4.1554e+00, 4.7224e+00, 4.1686e+00, -4.3685e+00],
grad_fn=<CatBackward>)
カスタマイズ
指定した学習済みの単語ベクトルを使う
Flairがデフォルトで指定している学習済みの単語ベクトルではなく、ローカルにあるファイルを指定することもできます。ファイル形式はgensim
のバイナリフォーマットで用意する必要があります。
own_embedding = WordEmbeddings("path/to/own_vector.bin")
flair/CLASSIC_WORD_EMBEDDINGS.md at master · flairNLP/flair
Poolingの方法を変更する
Poolingの方法には、デフォルトのaverage pooling(pooling="mean"
)の他に、max pooling(pooling="max"
)とmin pooling(pooling="min"
)も用意されています。
document_embeddings = DocumentPoolEmbeddings([ja_embedding], pooling="max")
Poolingの方法を組み合わせる
SWEMには、average-poolingとmax poolingを組み合わせたSWEM-concat
という手法があります。flairではStackedEmbeddings()
を使うことで、複数のEmbeddingを組み合わせることができます。
from flair.embeddings import StackedEmbeddings
average_embedding = DocumentPoolEmbeddings([ja_embedding], pooling="mean")
max_embedding = DocumentPoolEmbeddings([ja_embedding], pooling="max")
document_embeddings = StackedEmbeddings([average_embedding,
max_embedding])
document_embeddings.embed(sentence)
In []: sentence.get_embedding()
Out[]:
tensor([ 2.1660e+00, -1.9450e+00, -1.9782e+00, -1.0372e+01, -7.4274e-01,
-1.6262e+00, 2.3832e+00, 1.3668e+00, 4.2834e+00, -3.4007e+00,
[...]
4.5165e+00, -2.9375e+00, 5.7923e+00, 5.0611e+00, -3.1531e+00],
grad_fn=<CatBackward>)
In []: sentence.get_embedding().shape
Out[]: torch.Size([600])