10. A simple fine-tuning example

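  • Goal

  • Approach

    • Prepare a pretrained English model (en_core_web_sm) with spaCy.

    • Use the pretrained model to vectorize articles from the 20 newsgroups text dataset.

    • Train SVM and NN (neural network) classifiers on those vectors.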

!date
Wed Jun 16 07:34:31 UTC 2021

10.1. Pretraining

10.1.1. Environment setup

  • Install spaCy (pulled in here as a dependency of ginza) and the pretrained model.

!pip install -U ginza
!python -m spacy download en_core_web_sm
#!python -m spacy download en_core_web_lg
Requirement already up-to-date: ginza in /usr/local/lib/python3.7/dist-packages (4.0.6)
Requirement already satisfied, skipping upgrade: SudachiDict-core>=20200330; python_version >= "3.5" in /usr/local/lib/python3.7/dist-packages (from ginza) (20210608)
Requirement already satisfied, skipping upgrade: spacy<3.0.0,>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from ginza) (2.3.7)
Requirement already satisfied: en_core_web_sm==2.3.1 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm==2.3.1 in /usr/local/lib/python3.7/dist-packages (2.3.1)
Requirement already satisfied: spacy<2.4.0,>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from en_core_web_sm==2.3.1) (2.3.7)
✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
import spacy
nlp = spacy.load("en_core_web_sm")

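10.1.2. Inspecting the pretrained model

  • Both individual words and whole texts can be vectorized.

  • With vectors in hand, the similarity between words can be compared. Note that en_core_web_sm ships without static word vectors, so these similarities are computed from context-sensitive tensors and may not be meaningful.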

# Sanity check 1: vectorize a single word
nlp = spacy.load("en_core_web_sm")
doc = nlp('artificial')  # nlp() returns a Doc, even for a single word
print(doc.vector.shape)
print(doc.vector[:5])
print(doc.vector_norm)
(96,)
[-0.64146984 -1.346977   -1.4614831  -2.7170322   4.683545  ]
16.45753783733204
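
For an input with several tokens, spaCy's Doc.vector defaults to the average of the per-token vectors, and vector_norm is the L2 norm of that vector. The following is a minimal sketch of those two operations in plain Python. The 3-dimensional token vectors are made-up values for illustration, not output from en_core_web_sm, and the helper names mean_vector and l2_norm are hypothetical.

```python
import math

# Hypothetical per-token vectors (toy values, not real model output).
token_vectors = [
    [1.0, 2.0, 2.0],
    [3.0, 0.0, 4.0],
]

def mean_vector(vectors):
    """Average a list of equal-length vectors component by component."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def l2_norm(vector):
    """Euclidean (L2) length of a vector."""
    return math.sqrt(sum(x * x for x in vector))

doc_vector = mean_vector(token_vectors)
print(doc_vector)           # [2.0, 1.0, 3.0]
print(l2_norm(doc_vector))  # sqrt(14), about 3.74
```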
# Sanity check 2: pairwise similarity
words = ['apple', 'banana', 'car']

docs = []
for word in words:
  docs.append(nlp(word))  # each entry is a Doc

for doc1 in docs:
  for doc2 in docs:
    if doc1 is doc2:
      continue
    else:
      sim = doc1.similarity(doc2)
      print("similarity({}, {}) = {}".format(doc1.text, doc2.text, sim))
similarity(apple, banana) = 0.7039913704898891
similarity(apple, car) = 0.5905734774861556
similarity(banana, apple) = 0.7039913704898891
similarity(banana, car) = 0.5619708709107428
similarity(car, apple) = 0.5905734774861556
similarity(car, banana) = 0.5619708709107428
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:13: UserWarning: [W007] The model you're using has no word vectors loaded, so the result of the Doc.similarity method will be based on the tagger, parser and NER, which may not give useful similarity judgements. This may happen if you're using one of the small models, e.g. `en_core_web_sm`, which don't ship with word vectors and only use context-sensitive tensors. You can always add your own word vectors, or use one of the larger models instead if available.
  del sys.path[0]
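
By default, similarity returns the cosine similarity of the two vectors, which is why each pair above prints the same value in both directions. A minimal sketch in plain Python follows. The 2-dimensional vectors are made-up values for illustration only, and cosine_similarity is a hypothetical helper, not a spaCy function.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between u and v: dot(u, v) / (|u| * |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Toy vectors, not real model output.
a = [1.0, 0.0]
b = [1.0, 1.0]
c = [0.0, 2.0]

print(cosine_similarity(a, b))  # about 0.707 (45 degrees apart)
print(cosine_similarity(a, c))  # 0.0 (orthogonal)
print(cosine_similarity(b, a))  # same as (a, b): the measure is symmetric
```

As the warning above notes, en_core_web_sm has no static word vectors, so the values it produces come from context-sensitive tensors and should be read with caution.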

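10.2. Fine-tuning

The pretrained model is now ready. We use it for the task we actually care about: training a classifier on 20 newsgroups articles. The pretrained model's weights stay fixed here; only the downstream classifiers are trained on its output vectors.

10.2.1. Preparing the dataset

  • Load the 20 newsgroups dataset.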

# Fine-tuning stage.
# Prepare the dataset.
# fetch_20newsgroups downloads the data on first use and caches it locally,
# so later runs reuse the saved copy.

from sklearn.datasets import fetch_20newsgroups
#categories = ['alt.atheism', 'sci.space']
categories = ['comp.os.ms-windows.misc',  'comp.sys.mac.hardware',  'misc.forsale']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
train_text = newsgroups_train.data
train_label = newsgroups_train.target
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
test_text = newsgroups_test.data
test_label = newsgroups_test.target

10.2.2. Vectorizing with the pretrained model

  • Store the vectors produced by en_core_web_sm in train_vectors and test_vectors.

!date
# Convert each article to a vector with the pretrained model
def sentence2vector(sentences, model):
    vectors = []
    for sent in sentences:
        # Use the model passed in as an argument, not the global nlp
        vectors.append(model(sent).vector)
    return vectors

train_vectors = sentence2vector(train_text, nlp)
test_vectors = sentence2vector(test_text, nlp)

!date
Wed Jun 16 07:34:43 UTC 2021
Wed Jun 16 07:37:35 UTC 2021
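
Vectorizing took just under three minutes in the run above (07:34:43 to 07:37:35), so it can be worth saving the result and reloading it on later runs. The sketch below shows one possible approach using pickle. The helper load_or_compute, the file name train_vectors.pkl, and the stand-in function slow_vectorize are hypothetical names introduced for illustration; they are not part of the original notebook.

```python
import os
import pickle
import tempfile

def load_or_compute(path, compute):
    """Return the cached object at `path`, or call `compute()` and cache it."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    result = compute()
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result

# Stand-in for the expensive step, e.g. sentence2vector(train_text, nlp).
calls = []
def slow_vectorize():
    calls.append(1)  # record each real computation
    return [[0.1, 0.2], [0.3, 0.4]]

cache_path = os.path.join(tempfile.mkdtemp(), "train_vectors.pkl")
first = load_or_compute(cache_path, slow_vectorize)   # computes and saves
second = load_or_compute(cache_path, slow_vectorize)  # loads from disk
print(first == second, len(calls))  # True 1
```

In the notebook this could be used as load_or_compute("train_vectors.pkl", lambda: sentence2vector(train_text, nlp)). Only unpickle files that you created yourself.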

10.2.3. Training the classifiers

!date

#from sklearn.naive_bayes import MultinomialNB
from sklearn import svm
from sklearn.neural_network import MLPClassifier

#clf1 = MultinomialNB()
clf2 = svm.SVC(gamma='scale')
clf3 = MLPClassifier(max_iter=1000)
clfs = {"SVM":clf2, "NN":clf3}

scores = []
for name, clf in clfs.items():
  clf.fit(train_vectors, train_label)
  score = clf.score(test_vectors, test_label)
  scores.append(score)
  print("score = {} by {}".format(score,name))

!date
Wed Jun 16 07:37:35 UTC 2021
score = 0.6706586826347305 by SVM
score = 0.6595380667236954 by NN
Wed Jun 16 07:37:49 UTC 2021
!date
Wed Jun 16 07:37:49 UTC 2021
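
The score method of scikit-learn classifiers returns the mean accuracy on the given data. To judge whether a score of roughly 0.67 is good for three classes, it helps to compare it with a majority-class baseline. A minimal sketch in plain Python follows. The label lists are made-up values, not the notebook's data, and accuracy and majority_baseline are hypothetical helpers.

```python
from collections import Counter

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

def majority_baseline(y_train, y_test):
    """Accuracy of always predicting the most frequent training label."""
    most_common_label, _ = Counter(y_train).most_common(1)[0]
    return accuracy(y_test, [most_common_label] * len(y_test))

# Toy labels for three classes (0, 1, 2).
y_train = [0, 0, 0, 1, 1, 2]
y_test = [0, 1, 2, 0, 1, 0]
y_pred = [0, 1, 1, 0, 2, 0]

print(accuracy(y_test, y_pred))            # 4 of 6 correct
print(majority_baseline(y_train, y_test))  # 3 of 6 correct -> 0.5
```

If the three categories are roughly balanced, the baseline is near 1/3, so both classifiers above do learn something from the en_core_web_sm vectors.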