{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"fine-turning.ipynb","provenance":[],"collapsed_sections":[],"toc_visible":true,"authorship_tag":"ABX9TyOLlAdQ2Ug0SFBNwGE/ZS0d"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"UedTzbMFg8S7"},"source":["# A simple fine-tuning example\n","- Goal\n","  - Train a classifier on the [20 newsgroups text dataset](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset).\n","- Approach\n","  - Pretrain fastText on part of the English Wikipedia corpus, so that the embedding model is built from a different source than the classification data.\n","    - Pretraining on Wikipedia here is mainly a way to show how to prepare data for fastText training. In practice, rather than pretraining on Wikipedia yourself, it is better to download and use the pretrained models published at [FastText](https://fasttext.cc).\n","  - Vectorize the 20 newsgroups articles with the pretrained fastText model.\n","  - Prepare TF-IDF vectors as a baseline for comparison.\n","  - Train naive Bayes, SVM, and neural-network (NN) classifiers. Naive Bayes basically assumes count-like features and so cannot be applied to fastText vectors; it is used with TF-IDF only."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Gi4xg_nAJ5Xi","executionInfo":{"status":"ok","timestamp":1622628591595,"user_tz":-540,"elapsed":243,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"37f61bf2-13a0-4add-b76b-72472569f601"},"source":["!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:09:51 UTC 2021\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"z2fRtSAyxKp-"},"source":["## Pretraining"]},{"cell_type":"markdown","metadata":{"id":"lCUiutDriMiR"},"source":["### Environment setup\n","- Use [gensim](https://radimrehurek.com/gensim/) for the fastText model."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"n1N_-9Zi2u97","executionInfo":{"status":"ok","timestamp":1622628600976,"user_tz":-540,"elapsed":8541,"user":{"displayName":"TOMA 
Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"7b80153b-fdc8-44e1-82d2-45324f704366"},"source":["!pip install --upgrade gensim"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Collecting gensim\n","\u001b[?25l  Downloading https://files.pythonhosted.org/packages/44/52/f1417772965652d4ca6f901515debcd9d6c5430969e8c02ee7737e6de61c/gensim-4.0.1-cp37-cp37m-manylinux1_x86_64.whl (23.9MB)\n","\u001b[K     |████████████████████████████████| 23.9MB 159kB/s \n","\u001b[?25hRequirement already satisfied, skipping upgrade: scipy>=0.18.1 in /usr/local/lib/python3.7/dist-packages (from gensim) (1.4.1)\n","Requirement already satisfied, skipping upgrade: numpy>=1.11.3 in /usr/local/lib/python3.7/dist-packages (from gensim) (1.19.5)\n","Requirement already satisfied, skipping upgrade: smart-open>=1.8.1 in /usr/local/lib/python3.7/dist-packages (from gensim) (5.0.0)\n","Installing collected packages: gensim\n","  Found existing installation: gensim 3.6.0\n","    Uninstalling gensim-3.6.0:\n","      Successfully uninstalled gensim-3.6.0\n","Successfully installed gensim-4.0.1\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"O6q4nSdxifdA"},"source":["### Preparing the dataset\n","- Download an [English Wikipedia dump](https://dumps.wikimedia.org/enwiki/latest/) and use it as the pretraining corpus. The full dump exceeds 15 GB even when compressed and takes a long time to download, so one of the smaller partial dumps is used here.\n","- As the bzcat output shows, the dump is written in XML."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vPcG1mAI2W6o","executionInfo":{"status":"ok","timestamp":1622628653035,"user_tz":-540,"elapsed":52062,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"3b84e714-5e5b-4c4c-96b4-4de196b6da4f"},"source":["#!curl -O https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2\n","!curl -O 
https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles1.xml-p1p41242.bz2"],"execution_count":null,"outputs":[{"output_type":"stream","text":["  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n","                                 Dload  Upload   Total   Spent    Left  Speed\n","100  237M  100  237M    0     0  4668k      0  0:00:52  0:00:52 --:--:-- 4760k\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"E0A8hag-DAup","executionInfo":{"status":"ok","timestamp":1622628653036,"user_tz":-540,"elapsed":17,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"bbbe3f2e-ac76-4bd7-9f9b-d39fa3ffa43b"},"source":["!bzcat enwiki-latest-pages-articles1.xml-p1p41242.bz2 | head"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<mediawiki xmlns=\"http://www.mediawiki.org/xml/export-0.10/\" xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" xsi:schemaLocation=\"http://www.mediawiki.org/xml/export-0.10/ http://www.mediawiki.org/xml/export-0.10.xsd\" version=\"0.10\" xml:lang=\"en\">\n","  <siteinfo>\n","    <sitename>Wikipedia</sitename>\n","    <dbname>enwiki</dbname>\n","    <base>https://en.wikipedia.org/wiki/Main_Page</base>\n","    <generator>MediaWiki 1.37.0-wmf.5</generator>\n","    <case>first-letter</case>\n","    <namespaces>\n","      <namespace key=\"-2\" case=\"first-letter\">Media</namespace>\n","      <namespace key=\"-1\" case=\"first-letter\">Special</namespace>\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"oKVSFGGji34A"},"source":["### Environment setup 2\n","- [multiprocessing](https://docs.python.org/ja/3/library/multiprocessing.html) is used to check the number of CPUs (cores) in the runtime environment.\n","- [gensim.corpora.wikicorpus](https://radimrehurek.com/gensim/corpora/wikicorpus.html) is used to extract only the article text from the Wikipedia dump.\n","- 
[gensim.models.fasttext](https://radimrehurek.com/gensim/models/fasttext.html) provides the FastText model."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NqRIiq8l3JtV","executionInfo":{"status":"ok","timestamp":1622628654097,"user_tz":-540,"elapsed":1066,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"7db89479-f6e5-4a7a-8f40-e93436a70912"},"source":["import multiprocessing\n","from gensim.corpora.wikicorpus import WikiCorpus\n","from gensim.models.fasttext import FastText as FT_gensim"],"execution_count":null,"outputs":[{"output_type":"stream","text":["/usr/local/lib/python3.7/dist-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n","  warnings.warn(msg)\n"],"name":"stderr"}]},{"cell_type":"markdown","metadata":{"id":"ToUOvs4UkAKJ"},"source":["Extract the article text from the dump. `sentences` holds the full text of every article in the downloaded file: 15025 documents, the first of which contains 50 words."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"kANgs9Lan0cY","executionInfo":{"status":"ok","timestamp":1622628756942,"user_tz":-540,"elapsed":102856,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"84ef1db0-70ed-429d-b7d8-314d57e07687"},"source":["!date\n","\n","!curl -O https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles11.xml-p6899367p7054859.bz2\n","wikipedia_data = \"./enwiki-latest-pages-articles11.xml-p6899367p7054859.bz2\"\n","\n","# expand and extract\n","print(\"get texts from {}\".format(wikipedia_data))\n","wiki = WikiCorpus(wikipedia_data, dictionary={})\n","sentences = list(wiki.get_texts())\n","\n","# 
check the output\n","print(len(sentences))\n","print(len(sentences[0]))\n","print(sentences[0][0:5])\n","\n","!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:10:53 UTC 2021\n","  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n","                                 Dload  Upload   Total   Spent    Left  Speed\n","100 45.4M  100 45.4M    0     0  4338k      0  0:00:10  0:00:10 --:--:-- 4689k\n","get texts from ./enwiki-latest-pages-articles11.xml-p6899367p7054859.bz2\n","15025\n","50\n","['waterhouse', 'swamp', 'rat', 'scapteromys', 'tumidus']\n","Wed Jun  2 10:12:36 UTC 2021\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"uYpTtjuBv7QW"},"source":["### Pretraining with fastText\n","- First, build the vocabulary (the list of words) with build_vocab().\n","- Then train, passing the corpus, its basic statistics (document and word counts), and the number of epochs."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"iwBkobZwEcCb","executionInfo":{"status":"ok","timestamp":1622629485746,"user_tz":-540,"elapsed":728809,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"ce982721-e07d-4c1f-a0b0-894eb0b24268"},"source":["!date\n","\n","# fastText\n","ft_model = FT_gensim(vector_size=200, window=10, min_count=10, workers=max(1, multiprocessing.cpu_count() - 1))\n","\n","# build the vocabulary\n","print(\"building vocab...\")\n","ft_model.build_vocab(sentences)\n","\n","# train the model\n","print(\"training model...\")\n","ft_model.train(\n","    sentences,\n","    epochs = ft_model.epochs,\n","    total_examples = ft_model.corpus_count,\n","    total_words = ft_model.corpus_total_words\n",")\n","\n","print(\"training done.\")\n","\n","!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:12:36 UTC 2021\n","building vocab...\n","training model...\n","training done.\n","Wed Jun  2 10:24:45 UTC 2021\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"XWzvQVy_waHJ"},"source":["### 
Inspecting the pretrained model\n","- Both single words and whole sentences can be vectorized. A string that is not in the vocabulary, including a whole sentence, is treated as a single token whose vector is composed from its character n-grams.\n","- \"hoge\" is not in the trained vocabulary (False) but still gets a vector, because subwords handle out-of-vocabulary words.\n","- With vectors available, similar words can be looked up as well."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"q7RQqAcM6K8B","executionInfo":{"status":"ok","timestamp":1622629485746,"user_tz":-540,"elapsed":13,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"1ce4e773-0015-43bb-a333-79b982ee2cea"},"source":["# sanity check\n","print(ft_model.wv['artificial'].shape)\n","print(ft_model.wv['artificial'][:5])\n","print(ft_model.wv[\"more like funchuck,Gave this\"][:5])\n","\n","print(\"===========\")\n","print(\"hoge\" in ft_model.wv.key_to_index)\n","print(ft_model.wv[\"hoge\"][:5])\n","\n","print(\"===========\")\n","print(ft_model.wv.most_similar(\"computer\"))\n","print(ft_model.wv.most_similar(\"programming\"))\n","print(ft_model.wv.most_similar(\"apple\"))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(200,)\n","[-0.6973012  -0.26773152  0.46666357 -0.25891605  1.3251537 ]\n","[-0.17531493 -0.01367316 -0.04481015 -0.02450226  0.121394  ]\n","===========\n","False\n","[ 0.26534227  1.1729028   0.04152503 -2.065968    0.08280751]\n","===========\n","[('compute', 0.9225346446037292), ('computers', 0.874762773513794), ('computing', 0.8344529867172241), ('compulsory', 0.8319492936134338), ('computable', 0.7980512380599976), ('computerized', 0.794980525970459), ('compulsive', 0.7739043235778809), ('computed', 0.7671758532524109), ('computational', 0.762458324432373), ('computability', 0.7551568746566772)]\n","[('programme', 0.9124019145965576), ('programmes', 0.9014513492584229), ('programmer', 0.8962664604187012), ('programmers', 0.8907719254493713), ('programmed', 0.8655120134353638), ('programmable', 0.8621218204498291), ('program', 0.8598142266273499), ('programs', 0.8576140999794006), ('programmatic', 0.8526832461357117), ('prog', 0.8027951121330261)]\n","[('appleby', 0.9099407196044922), 
('apples', 0.7911590933799744), ('applause', 0.7831224799156189), ('downloadable', 0.78069007396698), ('applicable', 0.7683167457580566), ('appleseed', 0.7469122409820557), ('app', 0.7379465699195862), ('appleyard', 0.7336268424987793), ('apply', 0.7324667572975159), ('pineapple', 0.7182422876358032)]\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"75ltLfSbw4a7"},"source":["## Fine-tuning\n","Pretraining with fastText is done. Next, use the pretrained model for the actual goal: classifying 20 newsgroups articles."]},{"cell_type":"markdown","metadata":{"id":"GtruEPCYxUlN"},"source":["### Preparing the dataset\n","- Load the 20 newsgroups dataset."]},{"cell_type":"code","metadata":{"id":"fI_I4I8--HH-","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1622629497299,"user_tz":-540,"elapsed":11557,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"dabfeed3-ac1e-4238-e50f-5e862667cf85"},"source":["# fine-tuning stage.\n","# Prepare the dataset.\n","# Downloading takes a while; fetch_20newsgroups caches the data locally\n","# and reuses the cached copy on later runs.\n","\n","from sklearn.datasets import fetch_20newsgroups\n","#categories = ['alt.atheism', 'sci.space']\n","categories = ['comp.os.ms-windows.misc',  'comp.sys.mac.hardware',  'misc.forsale']\n","newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)\n","train_text = newsgroups_train.data\n","train_label = newsgroups_train.target\n","newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)\n","test_text = newsgroups_test.data\n","test_label = newsgroups_test.target"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Downloading 20news dataset. 
This may take a few minutes.\n","Downloading dataset from https://ndownloader.figshare.com/files/5975967 (14 MB)\n"],"name":"stderr"}]},{"cell_type":"markdown","metadata":{"id":"xZNv8LuaxjOG"},"source":["### Vectorizing with the pretrained model"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"gLnYJcprFi43","executionInfo":{"status":"ok","timestamp":1622629528344,"user_tz":-540,"elapsed":31056,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"58e4d61f-b82e-4ebc-c81b-096ef1585c56"},"source":["!date\n","# Convert each text to a vector with the pretrained fastText model.\n","def sentence2vector(sentences, model):\n","    vectors = []\n","    for sent in sentences:\n","        # wv[sent] treats the whole text as a single token and composes its vector from character n-grams\n","        vectors.append(model.wv[sent])\n","    return vectors\n","\n","ft_train_vectors = sentence2vector(train_text, ft_model)\n","ft_test_vectors = sentence2vector(test_text, ft_model)\n","\n","!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:24:56 UTC 2021\n","Wed Jun  2 10:25:28 UTC 2021\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"C5kikKNtxqJF"},"source":["### Training the classifiers (fastText version)"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"jDxHPF4zPXkM","executionInfo":{"status":"ok","timestamp":1622629544730,"user_tz":-540,"elapsed":16391,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"8d531db2-b03b-4e5c-e824-49a472d980fd"},"source":["!date\n","\n","#from sklearn.naive_bayes import MultinomialNB\n","from sklearn import svm\n","from sklearn.neural_network import MLPClassifier\n","\n","#clf1 = MultinomialNB()\n","clf2 = svm.SVC(gamma='scale')\n","clf3 = MLPClassifier(max_iter=500)\n","clfs = {\"SVM\":clf2, \"NN\":clf3}\n","\n","ft_scores = []\n","for name, clf in clfs.items():\n","  clf.fit(ft_train_vectors, train_label)\n","  score = clf.score(ft_test_vectors, test_label)\n","  ft_scores.append(score)\n","  print(\"ft_score = {} by 
{}\".format(score,name))\n","\n","!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:25:28 UTC 2021\n","ft_score = 0.7681779298545766 by SVM\n","ft_score = 0.7895637296834902 by NN\n","Wed Jun  2 10:25:44 UTC 2021\n"],"name":"stdout"},{"output_type":"stream","text":["/usr/local/lib/python3.7/dist-packages/sklearn/neural_network/_multilayer_perceptron.py:571: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (500) reached and the optimization hasn't converged yet.\n","  % self.max_iter, ConvergenceWarning)\n"],"name":"stderr"}]},{"cell_type":"markdown","metadata":{"id":"gO3b-VVDxxdZ"},"source":["### Training the classifiers (TF-IDF version)"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"brDF3Er5GCoW","executionInfo":{"status":"ok","timestamp":1622629545473,"user_tz":-540,"elapsed":746,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"17e7a8cd-a2b5-47a9-e2ea-ed7a203b68cc"},"source":["# Baseline experiment without pretraining, for comparison.\n","# Build vectors with BoW + TF-IDF.\n","\n","from sklearn.feature_extraction.text import TfidfVectorizer\n","vectorizer = TfidfVectorizer()\n","tfidf_train_vectors = vectorizer.fit_transform(newsgroups_train.data)\n","print(\"train_vectors.shape=\", tfidf_train_vectors.shape)\n","print(\"len(train_label)=\",len(train_label))\n","\n","tfidf_test_vectors = vectorizer.transform(newsgroups_test.data)\n","print(\"test_vectors.shape=\", tfidf_test_vectors.shape)\n","print(\"len(test_label)=\",len(test_label))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["train_vectors.shape= (1754, 54317)\n","len(train_label)= 1754\n","test_vectors.shape= (1169, 54317)\n","len(test_label)= 1169\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"AsIvZs7oe1PT","executionInfo":{"status":"ok","timestamp":1622629643962,"user_tz":-540,"elapsed":98500,"user":{"displayName":"TOMA 
Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"16a4a6c4-c4f1-4b74-e414-d168e3aa490b"},"source":["!date\n","\n","from sklearn.naive_bayes import MultinomialNB\n","from sklearn import svm\n","from sklearn.neural_network import MLPClassifier\n","\n","clf1 = MultinomialNB()\n","clf2 = svm.SVC(gamma='scale')\n","clf3 = MLPClassifier(max_iter=500)\n","clfs = {\"NB\":clf1, \"SVM\":clf2, \"NN\":clf3}\n","\n","tfidf_scores = []\n","for name, clf in clfs.items():\n","  clf.fit(tfidf_train_vectors, train_label)\n","  score = clf.score(tfidf_test_vectors, test_label)\n","  tfidf_scores.append(score)\n","  print(\"tfidf_scores = {} by {}\".format(score,name))\n","\n","!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:25:45 UTC 2021\n","tfidf_scores = 0.9024807527801539 by NB\n","tfidf_scores = 0.9230111206159111 by SVM\n","tfidf_scores = 0.9084687767322498 by NN\n","Wed Jun  2 10:27:23 UTC 2021\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rw9KrgsrHvnW","executionInfo":{"status":"ok","timestamp":1622629644422,"user_tz":-540,"elapsed":465,"user":{"displayName":"TOMA Naruaki","photoUrl":"","userId":"11747312442870110137"}},"outputId":"77725488-1339-4798-9ec8-0c1a76e5ea35"},"source":["!date"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Wed Jun  2 10:27:23 UTC 2021\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"j7q9l-kHGHpu"},"source":[""],"execution_count":null,"outputs":[]}]}