{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "!date\n",
        "!python --version"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4rCtQjelA2xv",
        "outputId": "bc344276-c318-44c5-dd72-3893d112f8cc"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mon Apr 21 06:58:06 AM UTC 2025\n",
            "Python 3.11.12\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MZ0AXI9nohgr"
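      },
      "source": [
        "# Example: Building a Sentiment Polarity Classifier (Bag-of-Words Based)\n",
        "- Reference\n",
        "    - [CMU CS11-711 Advanced NLP](http://phontron.com/class/anlp2024/)\n",
        "\n",
        "## Goal of this exercise\n",
        "- Understand the difference between the rule-based approach (rule_based.ipynb) and machine learning.\n",
        "\n",
        "## Implementation approach\n",
        "The feature vector $f(x)$ is a Bag-of-Words encoding of the input text. Note that the implementation below stores per-word occurrence counts, not strictly binary flags. The learner holds a weight $W$ for every word, computes a weighted-sum score, and classifies based on that score.\n",
        "- Feature function $h = f(x)$: bag-of-words\n",
        "- $score = W \\cdot h = W \\cdot f(x)$"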
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_-v2CCY6ohgt"
      },
      "source": [
        "## A. Preparing the dataset (copy and paste)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!curl -O https://ie.u-ryukyu.ac.jp/~tnal/2025/dm/static/r_assesment_sentiment.xlsx"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kl00ZwZVA5HY",
        "outputId": "961de9a4-e2d8-47fc-b458-b81807cc9e8a"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
            "                                 Dload  Upload   Total   Spent    Left  Speed\n",
            "100 53514  100 53514    0     0  17962      0  0:00:02  0:00:02 --:--:-- 17957\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "uFAVexH9ohgt",
        "outputId": "55cf81be-1ef6-40a6-f983-a8b681f2dfad",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 289
        }
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "   title  grade  required     q_id                       comment  \\\n",
              "0  工業数学Ⅰ      1      True  Q21 (1)                          特になし   \n",
              "1  工業数学Ⅰ      1      True  Q21 (2)            正直わかりずらい。むだに間があるし。   \n",
              "2  工業数学Ⅰ      1      True  Q21 (2)          例題を取り入れて理解しやすくしてほしい。   \n",
              "3  工業数学Ⅰ      1      True  Q21 (2)                          特になし   \n",
              "4  工業数学Ⅰ      1      True  Q21 (2)  スライドに書く文字をもう少しわかりやすくして欲しいです。   \n",
              "\n",
              "                                    wakati1  \\\n",
              "0                                     特に なし   \n",
              "1             正直 わかり ず らい 。 むだ に 間 が ある し 。   \n",
              "2            例題 を 取り入れ て 理解 し やすく し て ほしい 。   \n",
              "3                                     特に なし   \n",
              "4  スライド に 書く 文字 を もう少し わかり やすく し て 欲しい です 。   \n",
              "\n",
              "                                     wakati2  sentiment  \n",
              "0                                      特に ない          0  \n",
              "1              正直 わかる ぬ らい 。 むだ に 間 が ある し 。         -1  \n",
              "2          例題 を 取り入れる て 理解 する やすい する て ほしい 。         -1  \n",
              "3                                      特に ない          0  \n",
              "4  スライド に 書く 文字 を もう少し わかる やすい する て 欲しい です 。         -1  "
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-2ba0bcd6-1798-480c-8c11-4fed4a81a23a\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>title</th>\n",
              "      <th>grade</th>\n",
              "      <th>required</th>\n",
              "      <th>q_id</th>\n",
              "      <th>comment</th>\n",
              "      <th>wakati1</th>\n",
              "      <th>wakati2</th>\n",
              "      <th>sentiment</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>工業数学Ⅰ</td>\n",
              "      <td>1</td>\n",
              "      <td>True</td>\n",
              "      <td>Q21 (1)</td>\n",
              "      <td>特になし</td>\n",
              "      <td>特に なし</td>\n",
              "      <td>特に ない</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>工業数学Ⅰ</td>\n",
              "      <td>1</td>\n",
              "      <td>True</td>\n",
              "      <td>Q21 (2)</td>\n",
              "      <td>正直わかりずらい。むだに間があるし。</td>\n",
              "      <td>正直 わかり ず らい 。 むだ に 間 が ある し 。</td>\n",
              "      <td>正直 わかる ぬ らい 。 むだ に 間 が ある し 。</td>\n",
              "      <td>-1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>工業数学Ⅰ</td>\n",
              "      <td>1</td>\n",
              "      <td>True</td>\n",
              "      <td>Q21 (2)</td>\n",
              "      <td>例題を取り入れて理解しやすくしてほしい。</td>\n",
              "      <td>例題 を 取り入れ て 理解 し やすく し て ほしい 。</td>\n",
              "      <td>例題 を 取り入れる て 理解 する やすい する て ほしい 。</td>\n",
              "      <td>-1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>工業数学Ⅰ</td>\n",
              "      <td>1</td>\n",
              "      <td>True</td>\n",
              "      <td>Q21 (2)</td>\n",
              "      <td>特になし</td>\n",
              "      <td>特に なし</td>\n",
              "      <td>特に ない</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>工業数学Ⅰ</td>\n",
              "      <td>1</td>\n",
              "      <td>True</td>\n",
              "      <td>Q21 (2)</td>\n",
              "      <td>スライドに書く文字をもう少しわかりやすくして欲しいです。</td>\n",
              "      <td>スライド に 書く 文字 を もう少し わかり やすく し て 欲しい です 。</td>\n",
              "      <td>スライド に 書く 文字 を もう少し わかる やすい する て 欲しい です 。</td>\n",
              "      <td>-1</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "assesment_df",
              "summary": "{\n  \"name\": \"assesment_df\",\n  \"rows\": 170,\n  \"fields\": [\n    {\n      \"column\": \"title\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 16,\n        \"samples\": [\n          \"\\u5de5\\u696d\\u6570\\u5b66\\u2160\",\n          \"\\u6280\\u8853\\u8005\\u306e\\u502b\\u7406\",\n          \"\\u30a2\\u30eb\\u30b4\\u30ea\\u30ba\\u30e0\\u3068\\u30c7\\u30fc\\u30bf\\u69cb\\u9020\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"grade\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0,\n        \"min\": 1,\n        \"max\": 3,\n        \"num_unique_values\": 3,\n        \"samples\": [\n          1,\n          2,\n          3\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"required\",\n      \"properties\": {\n        \"dtype\": \"boolean\",\n        \"num_unique_values\": 2,\n        \"samples\": [\n          false,\n          true\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"q_id\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 5,\n        \"samples\": [\n          \"Q21 (2)\",\n          \"Q22\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"comment\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 153,\n        \"samples\": [\n          
\"\\u30fb\\u6559\\u79d1\\u66f8\\u304c\\u5fc5\\u8981\\u306a\\u306e\\u304b\\u5fc5\\u8981\\u3067\\u306a\\u3044\\u306e\\u304b\\u304c\\u66d6\\u6627\\u306a\\u307e\\u307e\\u6388\\u696d\\u304c\\u59cb\\u307e\\u308a\\u3001\\u975e\\u5e38\\u306b\\u4e0d\\u5b89\\u3060\\u3063\\u305f\\u305f\\u3081\\u3001\\u6559\\u79d1\\u66f8\\u304c\\u5fc5\\u9808\\u304b\\u305d\\u3046\\u3067\\u306a\\u3044\\u306e\\u304b\\u306f\\u6700\\u521d\\u306b\\u306f\\u3063\\u304d\\u308a\\u3057\\u3066\\u6b32\\u3057\\u3044\\u3002\\n\\u30fb\\u8ab2\\u984c\\u3092\\u51fa\\u3059\\u3060\\u3051\\u51fa\\u3055\\u305b\\u3066\\u304a\\u3044\\u3066\\u3001\\u63a1\\u70b9\\u3082\\u305b\\u305a\\u3001\\u3069\\u3046\\u3044\\u3063\\u305f\\u89e3\\u7b54\\u304c\\u6b63\\u3057\\u3044\\u306e\\u304b\\u3068\\u3044\\u3063\\u305f\\u6307\\u91dd\\u3082\\u51fa\\u3059\\u306e\\u304c\\u3068\\u3066\\u3082\\u9045\\u3044\\u3002\\u8ab2\\u984c\\u306f\\u89e3\\u304f\\u3060\\u3051\\u3067\\u306f\\u77e5\\u8b58\\u306e\\u5b9a\\u7740\\u306b\\u3064\\u306a\\u304c\\u3089\\u306a\\u3044\\u3068\\u601d\\u3044\\u307e\\u3059\\u304c\\u3001\\u305d\\u3053\\u3089\\u3078\\u3093\\u306f\\u3069\\u3046\\u306a\\u3093\\u3067\\u3057\\u3087\\u3046\\u304b\\u3002\\n\\u30fb\\u914d\\u5e03\\u8cc7\\u6599\\u3068\\u3057\\u3066\\u3001\\u904e\\u53bb\\u554f\\u3082\\u914d\\u5e03\\u3057\\u3066\\u304f\\u308c\\u308b\\u3068\\u3068\\u3066\\u3082\\u52a9\\u304b\\u308b\\u306a\\u3001\\u3068\\u601d\\u3044\\u307e\\u3059\\u3002\\u3054\\u691c\\u8a0e\\u304a\\u9858\\u3044\\u3057\\u307e\\u3059\\u3002\",\n          \"\\u30fb\\u4e2d\\u9593\\u30c6\\u30b9\\u30c8\\u3092\\u5ef6\\u671f\\u3057\\u7d9a\\u3051\\u3001\\u6700\\u7d42\\u7684\\u306b\\u4e2d\\u9593\\u30fb\\u671f\\u672b\\u8a66\\u9a13\\u3092\\uff12\\u9031\\u7d9a\\u3051\\u3066\\u3084\\u308b\\u3053\\u3068\\u3068\\u306a\\u308a\\u3001\\u8a08\\u753b\\u6027\\u304c\\u6b20\\u3051\\u3066\\u3044\\u308b\\u3002\\n\\u30fb\\u914d\\u5e03\\u8cc7\\u6599\\u306e\\u8aa4\\u5b57\\u8131\\u5b57\\u304c\\u591a\\u3059\\u304e\\u308b\\u3002\"\n        ],\n        \"semantic_type\": \"\",\n    
    \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"wakati1\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 153,\n        \"samples\": [\n          \"\\u30fb \\u6559\\u79d1\\u66f8 \\u304c \\u5fc5\\u8981 \\u306a \\u306e \\u304b \\u5fc5\\u8981 \\u3067 \\u306a\\u3044 \\u306e \\u304b \\u304c \\u66d6\\u6627 \\u306a \\u307e\\u307e \\u6388\\u696d \\u304c \\u59cb\\u307e\\u308a \\u3001 \\u975e\\u5e38 \\u306b \\u4e0d\\u5b89 \\u3060\\u3063 \\u305f \\u305f\\u3081 \\u3001 \\u6559\\u79d1\\u66f8 \\u304c \\u5fc5\\u9808 \\u304b \\u305d\\u3046 \\u3067 \\u306a\\u3044 \\u306e \\u304b \\u306f \\u6700\\u521d \\u306b \\u306f\\u3063\\u304d\\u308a \\u3057 \\u3066 \\u6b32\\u3057\\u3044 \\u3002 \\n \\u30fb \\u8ab2\\u984c \\u3092 \\u51fa\\u3059 \\u3060\\u3051 \\u51fa\\u3055 \\u305b \\u3066 \\u304a\\u3044 \\u3066 \\u3001 \\u63a1\\u70b9 \\u3082 \\u305b \\u305a \\u3001 \\u3069\\u3046 \\u3044\\u3063 \\u305f \\u89e3\\u7b54 \\u304c \\u6b63\\u3057\\u3044 \\u306e \\u304b \\u3068\\u3044\\u3063\\u305f \\u6307\\u91dd \\u3082 \\u51fa\\u3059 \\u306e \\u304c \\u3068\\u3066\\u3082 \\u9045\\u3044 \\u3002 \\u8ab2\\u984c \\u306f \\u89e3\\u304f \\u3060\\u3051 \\u3067 \\u306f \\u77e5\\u8b58 \\u306e \\u5b9a\\u7740 \\u306b \\u3064\\u306a\\u304c\\u3089 \\u306a\\u3044 \\u3068 \\u601d\\u3044 \\u307e\\u3059 \\u304c \\u3001 \\u305d\\u3053\\u3089 \\u3078\\u3093 \\u306f \\u3069\\u3046 \\u306a \\u3093 \\u3067\\u3057\\u3087 \\u3046 \\u304b \\u3002 \\n \\u30fb \\u914d\\u5e03 \\u8cc7\\u6599 \\u3068\\u3057\\u3066 \\u3001 \\u904e\\u53bb\\u554f \\u3082 \\u914d\\u5e03 \\u3057 \\u3066 \\u304f\\u308c\\u308b \\u3068 \\u3068\\u3066\\u3082 \\u52a9\\u304b\\u308b \\u306a \\u3001 \\u3068 \\u601d\\u3044 \\u307e\\u3059 \\u3002 \\u3054 \\u691c\\u8a0e \\u304a\\u9858\\u3044 \\u3057 \\u307e\\u3059 \\u3002\",\n          \"\\u30fb \\u4e2d\\u9593\\u30c6\\u30b9\\u30c8 \\u3092 \\u5ef6\\u671f \\u3057 \\u7d9a\\u3051 \\u3001 \\u6700\\u7d42\\u7684 \\u306b \\u4e2d\\u9593 \\u30fb 
\\u671f\\u672b\\u8a66\\u9a13 \\u3092 \\uff12 \\u9031 \\u7d9a\\u3051 \\u3066 \\u3084\\u308b \\u3053\\u3068 \\u3068 \\u306a\\u308a \\u3001 \\u8a08\\u753b \\u6027 \\u304c \\u6b20\\u3051 \\u3066 \\u3044\\u308b \\u3002 \\n \\u30fb \\u914d\\u5e03 \\u8cc7\\u6599 \\u306e \\u8aa4\\u5b57\\u8131\\u5b57 \\u304c \\u591a \\u3059\\u304e\\u308b \\u3002\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"wakati2\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 153,\n        \"samples\": [\n          \"\\u30fb \\u6559\\u79d1\\u66f8 \\u304c \\u5fc5\\u8981 \\u3060 \\u306e \\u304b \\u5fc5\\u8981 \\u3060 \\u306a\\u3044 \\u306e \\u304b \\u304c \\u66d6\\u6627 \\u3060 \\u307e\\u307e \\u6388\\u696d \\u304c \\u59cb\\u307e\\u308b \\u3001 \\u975e\\u5e38 \\u306b \\u4e0d\\u5b89 \\u3060 \\u305f \\u305f\\u3081 \\u3001 \\u6559\\u79d1\\u66f8 \\u304c \\u5fc5\\u9808 \\u304b \\u305d\\u3046 \\u3060 \\u306a\\u3044 \\u306e \\u304b \\u306f \\u6700\\u521d \\u306b \\u306f\\u3063\\u304d\\u308a \\u3059\\u308b \\u3066 \\u6b32\\u3057\\u3044 \\u3002 * \\u30fb \\u8ab2\\u984c \\u3092 \\u51fa\\u3059 \\u3060\\u3051 \\u51fa\\u3059 \\u305b\\u308b \\u3066 \\u304a\\u304f \\u3066 \\u3001 \\u63a1\\u70b9 \\u3082 \\u3059\\u308b \\u306c \\u3001 \\u3069\\u3046 \\u3044\\u3046 \\u305f \\u89e3\\u7b54 \\u304c \\u6b63\\u3057\\u3044 \\u306e \\u304b \\u3068\\u3044\\u3063\\u305f \\u6307\\u91dd \\u3082 \\u51fa\\u3059 \\u306e \\u304c \\u3068\\u3066\\u3082 \\u9045\\u3044 \\u3002 \\u8ab2\\u984c \\u306f \\u89e3\\u304f \\u3060\\u3051 \\u3060 \\u306f \\u77e5\\u8b58 \\u306e \\u5b9a\\u7740 \\u306b \\u3064\\u306a\\u304c\\u308b \\u306a\\u3044 \\u3068 \\u601d\\u3046 \\u307e\\u3059 \\u304c \\u3001 \\u305d\\u3053\\u3089 \\u3078\\u3093 \\u306f \\u3069\\u3046 \\u3060 \\u3093 \\u3067\\u3059 \\u3046 \\u304b \\u3002 * \\u30fb \\u914d\\u5e03 \\u8cc7\\u6599 \\u3068\\u3057\\u3066 \\u3001 \\u904e\\u53bb\\u554f \\u3082 \\u914d\\u5e03 
\\u3059\\u308b \\u3066 \\u304f\\u308c\\u308b \\u3068 \\u3068\\u3066\\u3082 \\u52a9\\u304b\\u308b \\u306a \\u3001 \\u3068 \\u601d\\u3046 \\u307e\\u3059 \\u3002 \\u3054 \\u691c\\u8a0e \\u304a\\u9858\\u3044 \\u3059\\u308b \\u307e\\u3059 \\u3002\",\n          \"\\u30fb \\u4e2d\\u9593\\u30c6\\u30b9\\u30c8 \\u3092 \\u5ef6\\u671f \\u3059\\u308b \\u7d9a\\u3051\\u308b \\u3001 \\u6700\\u7d42\\u7684 \\u306b \\u4e2d\\u9593 \\u30fb \\u671f\\u672b\\u8a66\\u9a13 \\u3092 \\uff12 \\u9031 \\u7d9a\\u3051\\u308b \\u3066 \\u3084\\u308b \\u3053\\u3068 \\u3068 \\u306a\\u308b \\u3001 \\u8a08\\u753b \\u6027 \\u304c \\u6b20\\u3051\\u308b \\u3066 \\u3044\\u308b \\u3002 * \\u30fb \\u914d\\u5e03 \\u8cc7\\u6599 \\u306e \\u8aa4\\u5b57\\u8131\\u5b57 \\u304c \\u591a\\u3044 \\u3059\\u304e\\u308b \\u3002\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"sentiment\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0,\n        \"min\": -1,\n        \"max\": 1,\n        \"num_unique_values\": 3,\n        \"samples\": [\n          0,\n          -1\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 3
        }
      ],
      "source": [
        "import pandas as pd\n",
        "\n",
        "filename = \"r_assesment_sentiment.xlsx\"\n",
        "assesment_df = pd.read_excel(filename)\n",
        "assesment_df.head()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "MpuBm4Hiohgv",
        "outputId": "9bd2776a-21f8-49b7-f953-29c13ccf6571",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "x_data[0]='特に なし', type(x_data[0])=<class 'str'>\n",
            "y_data[0]=0, type(y_data[0])=<class 'int'>\n"
          ]
        }
      ],
      "source": [
        "x_data = list(assesment_df['wakati1'])\n",
        "y_data = list(assesment_df['sentiment'])\n",
        "\n",
        "print(f\"{x_data[0]=}, {type(x_data[0])=}\")\n",
        "print(f\"{y_data[0]=}, {type(y_data[0])=}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "aGeqd5tMohgv",
        "outputId": "2406f8f9-6aa8-4c34-835d-959a406286fb",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "len(X_train)=136, len(y_train)=136\n",
            "len(X_test)=34, len(y_test)=34\n",
            "X_train[0]='python の 内容 は 予想 を 上回る ほど の 量 だっ た ので 、 まだ 理解度 が 完璧 と は 言え ない 状況 です 。 夏休み は 復習 を し て 、 ２ 学期 から また 新しい 言語 を 学ん で いき たい と 思い ます 。', y_train[0]=0\n"
          ]
        }
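      ],
      "source": [
        "# Split into training data and test data\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "# train_size = fraction of the data used for training.\n",
        "# random_state = seed for the pseudo-random number generator.\n",
        "#   With a fixed seed, the data is shuffled but the shuffle result is the same every time.\n",
        "#   Results are reproducible, which makes sanity checks and error analysis easier.\n",
        "# shuffle = True to shuffle the data before splitting.\n",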
        "X_train, X_test, y_train, y_test = train_test_split(x_data, y_data, train_size=0.8, random_state=1, shuffle=True)\n",
        "print(f\"{len(X_train)=}, {len(y_train)=}\")\n",
        "print(f\"{len(X_test)=}, {len(y_test)=}\")\n",
        "print(f\"{X_train[0]=}, {y_train[0]=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fWLoqB6Fohgv"
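      },
      "source": [
        "## Loading modules\n",
        "Neither module is strictly required in this notebook, but both are used here because they appear in many implementations.\n",
        "- random: used to shuffle the dataset.\n",
        "- tqdm: used to display a progress bar."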
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "oKbDiGSJohgw"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "import tqdm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vbCbaQFwohgw"
      },
      "source": [
        "## B. Feature extraction (changed)\n",
        "Represent each text with the [Bag-of-Words](https://en.wikipedia.org/wiki/Bag-of-words_model) model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "MtJKFVvTohgw",
        "outputId": "272ca07e-e6a7-4c7a-de9d-f2c49bd43415",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "text='社会人 に 向け て の これから を 考える いい 機会 に なっ た ので よかっ た です 。'\n",
            "features={'社会人': 1.0, 'に': 2.0, '向け': 1.0, 'て': 1.0, 'の': 1.0, 'これから': 1.0, 'を': 1.0, '考える': 1.0, 'いい': 1.0, '機会': 1.0, 'なっ': 1.0, 'た': 2.0, 'ので': 1.0, 'よかっ': 1.0, 'です': 1.0, '。': 1.0}\n"
          ]
        }
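      ],
      "source": [
        "def extract_features(x: str) -> dict[str, float]:\n",
        "    \"\"\"Count how often each space-separated token occurs in x.\"\"\"\n",
        "    features: dict[str, float] = {}\n",
        "    # The text is already tokenized (wakati), so splitting on spaces yields words.\n",
        "    for word in x.split(' '):\n",
        "        features[word] = features.get(word, 0.0) + 1.0\n",
        "    return features\n",
        "\n",
        "# Example run\n",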
        "text = X_train[8]\n",
        "features = extract_features(text)\n",
        "print(f\"{text=}\")\n",
        "print(f\"{features=}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IU8K9im6ohgw"
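      },
      "source": [
        "## C. Building the classifier (slightly modified)\n",
        "In the rule-based approach, we (1) prepared three kinds of weights (good_words, bad_words, and bias), (2) computed a score as the sum of each weight multiplied by the corresponding occurrence count, and (3) estimated the label by thresholding the score.\n",
        "\n",
        "In the BoW-based approach, we (1) prepare a separate weight for every word (initial value 0), (2) compute a score as the sum of each weight multiplied by the corresponding occurrence count, and (3) estimate the label by thresholding the score. Only step (1) changes."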
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "Y69GZWPlohgw",
        "outputId": "fab44eb5-3da8-4d0a-d280-41408e9d3e33",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "X_train[i]='python の 内容 は 予想 を 上回る ほど の 量 だっ た ので 、 まだ 理解度 が 完璧 と は 言え ない 状況 です 。 夏休み は 復習 を し て 、 ２ 学期 から また 新しい 言語 を 学ん で いき たい と 思い ます 。'\n",
            "features={'python': 1.0, 'の': 2.0, '内容': 1.0, 'は': 3.0, '予想': 1.0, 'を': 3.0, '上回る': 1.0, 'ほど': 1.0, '量': 1.0, 'だっ': 1.0, 'た': 1.0, 'ので': 1.0, '、': 2.0, 'まだ': 1.0, '理解度': 1.0, 'が': 1.0, '完璧': 1.0, 'と': 2.0, '言え': 1.0, 'ない': 1.0, '状況': 1.0, 'です': 1.0, '。': 2.0, '夏休み': 1.0, '復習': 1.0, 'し': 1.0, 'て': 1.0, '２': 1.0, '学期': 1.0, 'から': 1.0, 'また': 1.0, '新しい': 1.0, '言語': 1.0, '学ん': 1.0, 'で': 1.0, 'いき': 1.0, 'たい': 1.0, '思い': 1.0, 'ます': 1.0}\n",
            "score=0.0, estimated_label=0, true_label=0\n",
            "---\n",
            "X_train[i]='特に なし'\n",
            "features={'特に': 1.0, 'なし': 1.0}\n",
            "score=0.0, estimated_label=0, true_label=0\n",
            "---\n",
            "X_train[i]='配布 資料 が 教科書 の 内容 に 沿っ て おり 、 わかり やすかっ た 。'\n",
            "features={'配布': 1.0, '資料': 1.0, 'が': 1.0, '教科書': 1.0, 'の': 1.0, '内容': 1.0, 'に': 1.0, '沿っ': 1.0, 'て': 1.0, 'おり': 1.0, '、': 1.0, 'わかり': 1.0, 'やすかっ': 1.0, 'た': 1.0, '。': 1.0}\n",
            "score=0.0, estimated_label=0, true_label=1\n",
            "---\n",
            "X_train[i]='Zoom の 音声 、 資料 画像 の 画質 など 特に 問題 なく 授業 を 受け られ た 。'\n",
            "features={'Zoom': 1.0, 'の': 2.0, '音声': 1.0, '、': 1.0, '資料': 1.0, '画像': 1.0, '画質': 1.0, 'など': 1.0, '特に': 1.0, '問題': 1.0, 'なく': 1.0, '授業': 1.0, 'を': 1.0, '受け': 1.0, 'られ': 1.0, 'た': 1.0, '。': 1.0}\n",
            "score=0.0, estimated_label=0, true_label=1\n",
            "---\n",
            "X_train[i]='たまに 説明 が ない コード が あっ たり し た ので 少し 戸惑っ た 。 いずれ はやっ て いく もの で は ある が 、 、 、'\n",
            "features={'たまに': 1.0, '説明': 1.0, 'が': 3.0, 'ない': 1.0, 'コード': 1.0, 'あっ': 1.0, 'たり': 1.0, 'し': 1.0, 'た': 2.0, 'ので': 1.0, '少し': 1.0, '戸惑っ': 1.0, '。': 1.0, 'いずれ': 1.0, 'はやっ': 1.0, 'て': 1.0, 'いく': 1.0, 'もの': 1.0, 'で': 1.0, 'は': 1.0, 'ある': 1.0, '、': 3.0}\n",
            "score=0.0, estimated_label=0, true_label=-1\n",
            "---\n"
          ]
        }
      ],
      "source": [
        "# 全ての重みを0に初期化\n",
        "feature_weights = {}\n",
        "\n",
        "def run_classifier(features: dict[str, float]) -> int:\n",
        "    '''入力された特徴辞書の極性を推定する。\n",
        "\n",
        "    入力 (features)：特徴辞書。\n",
        "    出力1 (int): 推定ラベル: 良い評価(1)、悪い評価(-1)、どちらでもない(0)。\n",
        "    出力2 (score): 算出スコア。\n",
        "    '''\n",
        "    score = 0\n",
        "    for feat_name, feat_value in features.items():\n",
        "        score = score + feat_value * feature_weights.get(feat_name, 0)\n",
        "    if score > 0:\n",
        "        return 1, score\n",
        "    elif score < 0:\n",
        "        return -1, score\n",
        "    else:\n",
        "        return 0, score\n",
        "\n",
        "for i in range(5):\n",
        "    print(f\"{X_train[i]=}\")\n",
        "    features = extract_features(X_train[i])\n",
        "    print(f\"{features=}\")\n",
        "    estimated_label, score = run_classifier(features)\n",
        "    true_label = y_train[i]\n",
        "    print(f\"{score=}, {estimated_label=}, {true_label=}\")\n",
        "    print(\"---\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gtcxuz0Aohgx"
      },
      "source": [
        "## New: 学習\n",
        "単語に対する重みを学習により求める。学習アルゴリズムは以下の通り。\n",
        "- もし推測結果が正しいなら、何もしない（重みを更新しない）。\n",
        "- もし推測結果が誤りなら、全ての特徴量を ``新しい重み = 現在の重み + y * 特徴量`` で更新する。\n",
        "    - case 1: 正解が1で、-1と誤った場合。\n",
        "        - ``新しい重み = 現在の重み + 特徴量``\n",
        "        - 特徴量が加算される。これにより重みがより正の方向に修正され、スコアが0より大きな値になりやすくなる。\n",
        "    - case 2: 正解が-1で、1と誤った場合。\n",
        "        - ``新しい重み = 現在の重み - 特徴量``\n",
        "        - 特徴量が減算される。これにより重みがより負の方向に修正され、スコアが0より大きな値になりやすくなる。\n",
        "\n",
        "補足\n",
        "- どちらでもない(0)に対する学習は行っていない（省略）。"
      ]
    },
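    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to see why this update helps is to compare the score of the misclassified sample before and after the update. The derivation below is a sketch that treats the weights and features as vectors $W$ and $f(x)$, as in the score definition $score = W * f(x)$; $W'$ (the weights after the update) is notation introduced here. The rule has the same form as the classic perceptron update with step size 1.\n",
        "\n",
        "$$\n",
        "W' = W + y \\, f(x)\n",
        "$$\n",
        "\n",
        "$$\n",
        "W' \\cdot f(x) = W \\cdot f(x) + y \\, \\lVert f(x) \\rVert^2\n",
        "$$\n",
        "\n",
        "Because $\\lVert f(x) \\rVert^2 > 0$ for any non-empty text, the score of that sample rises when $y = 1$ and falls when $y = -1$. For example, the third training sample printed earlier (true label 1) has 15 tokens that each occur once, so $\\lVert f(x) \\rVert^2 = 15$, and updating on it from the all-zero weights raises its score from 0 to 15. In general a single update flips the sign only if $\\lVert f(x) \\rVert^2 > \\lvert W \\cdot f(x) \\rvert$, and later updates on other samples can move the score back, which may be one reason the accuracy fluctuates between epochs."
      ]
    },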
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "jmc0n5Gcohgx"
      },
      "outputs": [],
      "source": [
        "# E. 性能評価関数（コピペ）\n",
        "def calculate_accuracy(x_data: list[str], y_data: list[int]) -> float:\n",
        "    total_number = 0\n",
        "    correct_number = 0\n",
        "    for x, y in zip(x_data, y_data):\n",
        "        y_pred, score = run_classifier(extract_features(x))\n",
        "        total_number += 1\n",
        "        if y == y_pred:\n",
        "            correct_number += 1\n",
        "    return correct_number / float(total_number)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "_jOSeHMsohgx",
        "outputId": "5d9b9717-39a9-48a2-9956-d70b9b490d09",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "before: train_accuracy=0.16912, test_accuracy=0.20588\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 1: 100%|██████████| 136/136 [00:00<00:00, 31922.62it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=1: train_accuracy=0.78676, test_accuracy=0.52941\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 2: 100%|██████████| 136/136 [00:00<00:00, 43168.26it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=2: train_accuracy=0.89706, test_accuracy=0.64706\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 3: 100%|██████████| 136/136 [00:00<00:00, 7456.54it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=3: train_accuracy=0.80147, test_accuracy=0.67647\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 4: 100%|██████████| 136/136 [00:00<00:00, 11287.06it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=4: train_accuracy=0.92647, test_accuracy=0.70588\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 5: 100%|██████████| 136/136 [00:00<00:00, 15123.02it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=5: train_accuracy=0.94118, test_accuracy=0.67647"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 6: 100%|██████████| 136/136 [00:00<00:00, 42011.00it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=6: train_accuracy=0.94118, test_accuracy=0.67647\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 7: 100%|██████████| 136/136 [00:00<00:00, 14828.95it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=7: train_accuracy=0.88971, test_accuracy=0.73529\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 8: 100%|██████████| 136/136 [00:00<00:00, 34317.49it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=8: train_accuracy=0.76471, test_accuracy=0.64706\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 9: 100%|██████████| 136/136 [00:00<00:00, 41737.42it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=9: train_accuracy=0.81618, test_accuracy=0.67647"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Epoch 10: 100%|██████████| 136/136 [00:00<00:00, 44749.77it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "epoch=10: train_accuracy=0.83088, test_accuracy=0.67647\n"
          ]
        }
      ],
      "source": [
        "# 全ての重みを0に初期化\n",
        "feature_weights = {}\n",
        "\n",
        "# 学習前のスコア\n",
        "train_accuracy = calculate_accuracy(X_train, y_train)\n",
        "test_accuracy = calculate_accuracy(X_test, y_test)\n",
        "print(f\"before: {train_accuracy=:.5f}, {test_accuracy=:.5f}\")\n",
        "\n",
        "# 学習\n",
        "NUM_EPOCHS = 10\n",
        "for epoch in range(1, NUM_EPOCHS+1):\n",
        "    # データセットをシャッフル\n",
        "    data_ids = list(range(len(X_train)))\n",
        "    random.shuffle(data_ids)\n",
        "\n",
        "    # サンプルごとの処理\n",
        "    for data_id in tqdm.tqdm(data_ids, desc=f'Epoch {epoch}'):\n",
        "        x = X_train[data_id]\n",
        "        y = y_train[data_id]\n",
        "\n",
        "        if y == 0: # 「どちらでもない(0)」ケースはスキップ。\n",
        "            continue\n",
        "\n",
        "        # 予測\n",
        "        features = extract_features(x)\n",
        "        predicted_y, score = run_classifier(features)\n",
        "\n",
        "        # 予測結果が誤り時の重み更新処理\n",
        "        if predicted_y != y:\n",
        "            for feature in features:\n",
        "                feature_weights[feature] = feature_weights.get(feature, 0) + y * features[feature]\n",
        "                #print(f\"{feature_weights=}\")\n",
        "\n",
        "    train_accuracy = calculate_accuracy(X_train, y_train)\n",
        "    test_accuracy = calculate_accuracy(X_test, y_test)\n",
        "    print(f\"{epoch=}: {train_accuracy=:.5f}, {test_accuracy=:.5f}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SpcHGrlZohgx"
      },
      "source": [
        "## D. 性能評価（コピペ）\n",
        "関数定義は既に定義済みなので、ここでは実行コードのみコピペ。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "5Kc6YnU2ohgx",
        "outputId": "d6848d1d-ed4b-45ac-cc21-bb334412c912",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "{-1: 9, 1: 18, 0: 7}\n",
            "Train accuracy: 0.8308823529411765\n",
            "Dev/test accuracy: 0.6764705882352942\n"
          ]
        }
      ],
      "source": [
        "label_count = {}\n",
        "for y in y_test:\n",
        "    if y not in label_count:\n",
        "        label_count[y] = 0\n",
        "    label_count[y] += 1\n",
        "print(label_count)\n",
        "\n",
        "train_accuracy = calculate_accuracy(X_train, y_train)\n",
        "test_accuracy = calculate_accuracy(X_test, y_test)\n",
        "print(f'Train accuracy: {train_accuracy}')\n",
        "print(f'Dev/test accuracy: {test_accuracy}')"
      ]
    },
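    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The printed accuracies can be checked by hand from the label counts in the output above. This is plain arithmetic on numbers already shown in this notebook; the majority-class baseline is added here only as a reference point.\n",
        "\n",
        "$$\n",
        "\\text{accuracy} = \\frac{\\text{number of correct predictions}}{\\text{number of samples}}\n",
        "$$\n",
        "\n",
        "- The test set has $9 + 18 + 7 = 34$ samples.\n",
        "- Before training all weights are 0, so every prediction is 0 and only the 7 samples labeled 0 are correct: $7 / 34 \\approx 0.20588$, which matches the `before` line of the training output.\n",
        "- Always predicting the most frequent label (1) would give $18 / 34 \\approx 0.52941$.\n",
        "- The final test accuracy $0.67647 \\approx 23 / 34$ means that 23 of the 34 test samples are classified correctly."
      ]
    },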
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SzQwnPhsohgx"
      },
      "source": [
        "## E. 失敗分析（コピペ）"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "TkXOsbd5ohgy",
        "outputId": "a702b1e9-f18a-4ba7-e91c-0b74f610090b",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "特に なし\n",
            "true label: 0\n",
            "predicted label: 1\n",
            "\n",
            "特に なし\n",
            "true label: 0\n",
            "predicted label: 1\n",
            "\n",
            "特に なし\n",
            "true label: 0\n",
            "predicted label: 1\n",
            "\n",
            "特に なし\n",
            "true label: 0\n",
            "predicted label: 1\n",
            "\n",
            "特に なし\n",
            "true label: 0\n",
            "predicted label: 1\n",
            "\n"
          ]
        }
      ],
      "source": [
        "def find_errors(x_data, y_data):\n",
        "    error_ids = []\n",
        "    y_preds = []\n",
        "    for i, (x, y) in enumerate(zip(x_data, y_data)):\n",
        "        pred, score = run_classifier(extract_features(x))\n",
        "        y_preds.append(pred)\n",
        "        if y != y_preds[-1]:\n",
        "            error_ids.append(i)\n",
        "    for _ in range(5):\n",
        "        my_id = random.choice(error_ids)\n",
        "        x, y, y_pred = x_data[my_id], y_data[my_id], y_preds[my_id]\n",
        "        print(f'{x}\\ntrue label: {y}\\npredicted label: {y_pred}\\n')\n",
        "\n",
        "find_errors(X_train, y_train)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.6"
    },
    "colab": {
      "provenance": [],
      "toc_visible": true
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}