{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ダウンロードしたデータセットに機械学習を適用する流れ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 達成目標\n",
    "- 学習器に対して\n",
    "    - 機械学習のタスク種別がclassification/regression/clusteringのどれに該当するかを判断することができる。\n",
    "    - 判断したタスクに用いることができる学習器（モデル）を選択することができる。\n",
    "        - 参考\n",
    "            - [Choosing the right estimator](https://scikit-learn.org/stable/tutorial/machine_learning_map/index.html)\n",
    "            - [User Guide](https://scikit-learn.org/stable/user_guide.html)\n",
    "            - [Examplex](https://scikit-learn.org/stable/auto_examples/index.html)\n",
    "    - 選択したモデルのAPIを参照し、手動調整が必要なハイパーパラメータを確認することができる。\n",
    "    - 選択したモデルのAPIを参照し、評価方法(score関数)を確認することができる。\n",
    "- データセットに対して\n",
    "    - ダウンロードしたデータセットの中身を目視確認し、保存されている形式を理解することができる。\n",
    "    - 代表的な保存形式で保存されているデータセットに対して、pd.DataFrame形式として読み込むことができる。\n",
    "        - 代表的な保存形式: csv, tsv\n",
    "    - DataFrameから必要な行や列を指定してデータを抜き出すことができる。\n",
    "        - 特徴ベクトルと教師データを分けたり、データセット全体を学習用・テスト用に分けることができる。\n",
    "- 機械学習に対して\n",
    "    - 学習データを学習器に与えて学習させ、テストデータで学習結果の適切さを評価することができる。\n",
    "    - 学習データとテストデータに分割することができる。[交差検定(cross-validation)](https://scikit-learn.org/stable/modules/cross_validation.html)するとベター。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 今回の流れ\n",
    "- Machine Learning Repositoryで公開されている代表的なデータ[Iris Data Set](https://archive.ics.uci.edu/ml/datasets/Iris)に対し、[LinearSVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html)で分類学習してみる。5分割検定で評価する。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### データセットの準備\n",
    "データセットのページには iris.data, iris.names が用意されていることを確認。この2つのファイルをダウンロードして、中身を覗いてみよう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100  4551  100  4551    0     0   7157      0 --:--:-- --:--:-- --:--:--  7212\n",
      "5.1,3.5,1.4,0.2,Iris-setosa\n",
      "4.9,3.0,1.4,0.2,Iris-setosa\n",
      "4.7,3.2,1.3,0.2,Iris-setosa\n",
      "4.6,3.1,1.5,0.2,Iris-setosa\n",
      "5.0,3.6,1.4,0.2,Iris-setosa\n",
      "5.4,3.9,1.7,0.4,Iris-setosa\n",
      "4.6,3.4,1.4,0.3,Iris-setosa\n",
      "5.0,3.4,1.5,0.2,Iris-setosa\n",
      "4.4,2.9,1.4,0.2,Iris-setosa\n",
      "4.9,3.1,1.5,0.1,Iris-setosa\n"
     ]
    }
   ],
   "source": [
    "# iris.data のダウンロード\n",
    "!curl -O https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\n",
    "\n",
    "# head コマンドにより、冒頭数行を確認する\n",
    "!head iris.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- どうやら iris.data には「各値がカンマで区切られて列挙されたCSV形式」でデータセットが保存されているらしい。冒頭4つが特徴量で、最後の1つはラベルデータっぽい。CSVなので恐らく [pd.read_csv](https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html) を使えば DataFrame 形式で読み込めるだろう。\n",
    "- 次に iris.names を確認してみよう"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100  2998  100  2998    0     0   4690      0 --:--:-- --:--:-- --:--:--  4721\n",
      "1. Title: Iris Plants Database\n",
      "\tUpdated Sept 21 by C.Blake - Added discrepency information\n",
      "\n",
      "2. Sources:\n",
      "     (a) Creator: R.A. Fisher\n",
      "     (b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\n",
      "     (c) Date: July, 1988\n",
      "\n",
      "3. Past Usage:\n",
      "   - Publications: too many to mention!!!  Here are a few.\n"
     ]
    }
   ],
   "source": [
    "!curl -O https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names\n",
    "!head iris.names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- どうやら iris.names はこのデータセットに関する説明が書かれているらしい。今回は使わない（説明ページで確認済み）ので、無視することにしよう。\n",
    "- iris.data だけで揃うようなので、これを読み込んで処理する準備をしよう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>5.1</th>\n",
       "      <th>3.5</th>\n",
       "      <th>1.4</th>\n",
       "      <th>0.2</th>\n",
       "      <th>Iris-setosa</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.4</td>\n",
       "      <td>3.9</td>\n",
       "      <td>1.7</td>\n",
       "      <td>0.4</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>144</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>149 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     5.1  3.5  1.4  0.2     Iris-setosa\n",
       "0    4.9  3.0  1.4  0.2     Iris-setosa\n",
       "1    4.7  3.2  1.3  0.2     Iris-setosa\n",
       "2    4.6  3.1  1.5  0.2     Iris-setosa\n",
       "3    5.0  3.6  1.4  0.2     Iris-setosa\n",
       "4    5.4  3.9  1.7  0.4     Iris-setosa\n",
       "..   ...  ...  ...  ...             ...\n",
       "144  6.7  3.0  5.2  2.3  Iris-virginica\n",
       "145  6.3  2.5  5.0  1.9  Iris-virginica\n",
       "146  6.5  3.0  5.2  2.0  Iris-virginica\n",
       "147  6.2  3.4  5.4  2.3  Iris-virginica\n",
       "148  5.9  3.0  5.1  1.8  Iris-virginica\n",
       "\n",
       "[149 rows x 5 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "df = pd.read_csv('iris.data')\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- pd.read_csvをそのまま使うと、1行目のデータを見出しとして処理してしまっている。これを避けるため、見出しなしとして再読み込みし直すか、もしくは別途見出しを用意して再読込し直す必要がありそうだ。ここでは見出しを用意することにしよう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length</th>\n",
       "      <th>sepal width</th>\n",
       "      <th>petal length</th>\n",
       "      <th>petal width</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>Iris-setosa</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>Iris-virginica</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length  sepal width  petal length  petal width           class\n",
       "0             5.1          3.5           1.4          0.2     Iris-setosa\n",
       "1             4.9          3.0           1.4          0.2     Iris-setosa\n",
       "2             4.7          3.2           1.3          0.2     Iris-setosa\n",
       "3             4.6          3.1           1.5          0.2     Iris-setosa\n",
       "4             5.0          3.6           1.4          0.2     Iris-setosa\n",
       "..            ...          ...           ...          ...             ...\n",
       "145           6.7          3.0           5.2          2.3  Iris-virginica\n",
       "146           6.3          2.5           5.0          1.9  Iris-virginica\n",
       "147           6.5          3.0           5.2          2.0  Iris-virginica\n",
       "148           6.2          3.4           5.4          2.3  Iris-virginica\n",
       "149           5.9          3.0           5.1          1.8  Iris-virginica\n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feature_names = ['sepal length', 'sepal width', 'petal length', 'petal width', 'class']\n",
    "df = pd.read_csv('iris.data', names=feature_names)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 学習に向けて特徴ベクトルと教師データを分けておく。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X.shape =  (150, 4)\n",
      "y.shape =  (150,)\n",
      "sepal length    5.1\n",
      "sepal width     3.5\n",
      "petal length    1.4\n",
      "petal width     0.2\n",
      "Name: 0, dtype: float64\n",
      "Iris-setosa\n"
     ]
    }
   ],
   "source": [
    "X = df[['sepal length', 'sepal width', 'petal length', 'petal width']]\n",
    "y = df['class']\n",
    "print('X.shape = ', X.shape)\n",
    "print('y.shape = ', y.shape)\n",
    "print(X.loc[0])\n",
    "print(y[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 学習器の用意\n",
    "- [LinearSVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html)を参考にimportしておく。\n",
    "- ハイパーパラメータのうち C について 0.5, 1.0, 1,5 の3通りを試してみよう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "Cs = [0.5, 1.0, 1.5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 交差検定を利用して学習＆評価\n",
    "今回は[cross_val_score](https://scikit-learn.org/stable/modules/cross_validation.html#computing-cross-validated-metrics)を使ってみよう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C = 0.5: scores=[0.96666667 0.96666667 0.9        0.96666667 1.        ], average=0.960\n",
      "C = 1.0: scores=[0.93333333 0.96666667 0.93333333 0.93333333 1.        ], average=0.953\n",
      "C = 1.5: scores=[0.9        1.         0.93333333 1.         0.9       ], average=0.947\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n",
      "/Users/tnal/.venv/dm/lib/python3.8/site-packages/sklearn/svm/_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "k_folds = 5 # 5分割検定\n",
    "\n",
    "for c in Cs:\n",
    "    model = LinearSVC(C=c)\n",
    "    scores = cross_val_score(model, X, y, cv=KFold(n_splits=k_folds, shuffle=True))\n",
    "    average = scores.mean()\n",
    "    print(f'C = {c}: scores={scores}, average={average:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "880b2a8c90f9e6beae80b56829e3f671fedd58b6d14887184ddce26124cedfbd"
  },
  "kernelspec": {
   "display_name": "Python 3.8.9 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
