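10. A simple fine-tuning example
Goal
Train a classifier on the 20 newsgroups text dataset.
Approach
Prepare a pretrained English model (en_core_web_sm) with spaCy.
Use the pretrained model to convert each 20 newsgroups article into a vector.
Train classifiers on those vectors, using an SVM and a neural network (NN).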
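10.1. Pretraining
10.1.1. Environment setup
Install spaCy and the pretrained model. The output below shows that ginza 4.0.6, spaCy 2.3.7, and en_core_web_sm 2.3.1 were already installed.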
Requirement already up-to-date: ginza in /usr/local/lib/python3.7/dist-packages (4.0.6)
Requirement already satisfied, skipping upgrade: spacy<3.0.0,>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from ginza) (2.3.7)
Requirement already satisfied: en_core_web_sm==2.3.1 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm==2.3.1 in /usr/local/lib/python3.7/dist-packages (2.3.1)
Requirement already satisfied: spacy<2.4.0,>=2.3.0 in /usr/local/lib/python3.7/dist-packages (from en_core_web_sm==2.3.1) (2.3.7)
✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
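10.1.2. Checking the pretrained model
Both single words and whole sentences can be converted to vectors.
Once text is vectorized, similarity between words can be checked as well. As the warning below points out, en_core_web_sm ships without word vectors, so these similarity values come from context-sensitive tensors and may not be very meaningful.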
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:13: UserWarning: [W007] The model you're using has no word vectors loaded, so the result of the Doc.similarity method will be based on the tagger, parser and NER, which may not give useful similarity judgements. This may happen if you're using one of the small models, e.g. `en_core_web_sm`, which don't ship with word vectors and only use context-sensitive tensors. You can always add your own word vectors, or use one of the larger models instead if available.
del sys.path[0]
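The check described above can be sketched roughly as follows. This is a minimal illustration, not the notebook's original cell: it assumes `nlp` is a loaded spaCy pipeline (for example the result of `spacy.load('en_core_web_sm')`), and the helper names `cosine_similarity` and `compare` are hypothetical. To my understanding, `Doc.similarity` also amounts to a cosine similarity between the two vectors, so this makes the computation explicit.

```python
import numpy as np


def cosine_similarity(u, v):
    """Cosine of the angle between two 1-D vectors (0.0 if either is all zeros)."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    denom = np.linalg.norm(u) * np.linalg.norm(v)
    return float(u @ v / denom) if denom else 0.0


def compare(nlp, text_a, text_b):
    """Vectorize two texts with a spaCy pipeline and compare them.

    `nlp` is assumed to be a loaded spaCy pipeline. Doc.vector is
    available for a single word and for a whole sentence alike.
    """
    doc_a, doc_b = nlp(text_a), nlp(text_b)
    return cosine_similarity(doc_a.vector, doc_b.vector)
```

With a real pipeline, a call such as `compare(nlp, "dog", "cat")` returns a value between -1 and 1. With en_core_web_sm the number should be read with the caution given in the warning above.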
10.2. Fine-tuning
The pretrained model is now ready. We use it for the actual goal: training a classifier on 20 newsgroups.
10.2.1. Preparing the dataset
Load the 20 newsgroups dataset.
# Fine-tuning stage.
# Prepare the dataset.
# This step also takes a while, so the converted dataset is saved to a specified location.
# Reusing an already saved dataset is supported as well.
from sklearn.datasets import fetch_20newsgroups
#categories = ['alt.atheism', 'sci.space']
categories = ['comp.os.ms-windows.misc', 'comp.sys.mac.hardware', 'misc.forsale']
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
train_text = newsgroups_train.data
train_label = newsgroups_train.target
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
test_text = newsgroups_test.data
test_label = newsgroups_test.target
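The classifiers in 10.2.3 are fitted on `train_vectors` and `test_vectors`, but the step that builds them is not shown here. A minimal sketch of how they could be produced, assuming one `Doc.vector` per article: the helper names `texts_to_vectors` and `load_or_compute_vectors`, the batch size, and the `.npy` cache file are all assumptions for illustration, not the notebook's actual code.

```python
import os

import numpy as np


def texts_to_vectors(nlp, texts, batch_size=64):
    """Return an (n_texts, dim) array with one document vector per text.

    `nlp` is assumed to be a loaded spaCy pipeline; `texts` must be non-empty.
    """
    # nlp.pipe streams the texts through the pipeline in batches,
    # which is usually faster than calling nlp(text) in a loop.
    return np.vstack([doc.vector for doc in nlp.pipe(texts, batch_size=batch_size)])


def load_or_compute_vectors(path, nlp, texts):
    """Load cached vectors from `path` (a .npy file) or compute and save them."""
    if os.path.exists(path):
        return np.load(path)
    vectors = texts_to_vectors(nlp, texts)
    # np.save appends ".npy" if it is missing, so pass a path ending in ".npy".
    np.save(path, vectors)
    return vectors
```

With a loaded pipeline `nlp`, one would then call, for example, `train_vectors = load_or_compute_vectors("train_vectors.npy", nlp, train_text)` and the same for `test_text`. Caching matters because running every article through the pipeline is the slow part.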
10.2.3. Classification training
!date
#from sklearn.naive_bayes import MultinomialNB
from sklearn import svm
from sklearn.neural_network import MLPClassifier
#clf1 = MultinomialNB()
clf2 = svm.SVC(gamma='scale')
clf3 = MLPClassifier(max_iter=1000)
clfs = {"SVM":clf2, "NN":clf3}
scores = []
# Fit each classifier on the training vectors and report test accuracy.
for name, clf in clfs.items():
    clf.fit(train_vectors, train_label)
    score = clf.score(test_vectors, test_label)
    scores.append(score)
    print("score = {} by {}".format(score,name))
!date
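The `score` printed above is the mean accuracy on the test set. If a per-class breakdown is wanted, something along these lines could be used. This is only a sketch: the `evaluate` helper is hypothetical and not part of the original notebook, and it assumes the classifier has already been fitted.

```python
from sklearn.metrics import accuracy_score, classification_report


def evaluate(clf, test_vectors, test_label, target_names=None):
    """Return (accuracy, per-class report text) for an already fitted classifier."""
    pred = clf.predict(test_vectors)
    # For classifiers this matches clf.score(test_vectors, test_label).
    acc = accuracy_score(test_label, pred)
    report = classification_report(test_label, pred, target_names=target_names)
    return acc, report
```

For the classifiers above, `evaluate(clf2, test_vectors, test_label, newsgroups_test.target_names)` would return the accuracy together with precision, recall, and F1 for each of the three newsgroups.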