rohaniのブログ

ゆるっと自然言語処理奴。ときどき工作系バイト。

サブワード分割手法 BPE(Sennrich, 2016) をPythonで実装してみた

輪講でも度々登場するBPE(Sennrich, 2016)を勉強のために書いてみた。

論文に乗っている Algorithm 1 Learn BPE operations そのまま。→

import re, collections

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!<\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

vocab = {'l o w </w>':5, 'l o w e s t </w>':2, 
         'n e w e r </w>':6, 'w i d e r </w>':3}

num_merges = 10
for i in range(num_merges):
    pairs = get_stats(vocab)
    print('pairs',pairs)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print('best',best)

print('vocab',vocab

文に対してBPEをかけるように変更。

import re, collections

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!<\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def txt2voc(text):
    wl = text.split()
    vocab = {}
    for w in wl:
        w = ' '.join(w)+' </w>'
        vocab[w] = vocab.get(w,0) + 1
    return vocab

def bpe(num_merges, text):
    vocab = txt2voc(text)
    pairs_lst = []
    for i in range(num_merges):
        pairs = get_stats(vocab)
        pairs_lst.append(pairs)
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
    return (pairs_lst, vocab)

if __name__=='__main__':
    num_merges = 10
    text='low low low low low lowest lowest newer newer newer newer newer newer wider wider wider'
    (pairs_lst, vocab) = bpe(num_merges, text)
    for pairs in pairs_lst:
        print(pairs)
    print('vocab',vocab)

BPE

とても簡単、けれど強力なサブワード分割手法。

Sennrich Rico, Haddow Barry, Birch Alexandra, 2016, "Neural Machine Translation of Rare Words with Subword Units", Association for Computational Linguistics, pages 1715--1725"

感想

だんだん頻出する文字列がまとめられていく様子を、もっとかっこよく可視化したいな。

英語は単語がスペース区切りで独立しているので良いけれど、日本語だと前処理として形態素解析が必要。 ...Ngramからのスタートでやったらどうなるだろう?