{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rR_Sipz7OuD_"
      },
      "source": [
        "# Multilingual NMT with the Pretrained Model mT5: An Example\n",
        "Let's walk through training a machine translation model with [mT5, a T5 model covering 101 languages](https://huggingface.co/docs/transformers/model_doc/mt5), one of the models available in Transformers.\n",
        "\n",
        "- Overall flow\n",
        "    - Notes\n",
        "    - Environment setup\n",
        "    - Preparing the required modules (import)\n",
        "    - Preparing the Tokenizer and Model (the pretrained model)\n",
        "    - Tips for using Transformers\n",
        "    - Checking how the tokenizer behaves\n",
        "    - Adding task-specific tokens\n",
        "    - Preparing the dataset\n",
        "    - Fine-tuning strategy\n",
        "        - Dataset preprocessing\n",
        "        - Checking the preprocessing\n",
        "    - Fine-tuning\n",
        "        - Preparing the parameters and the model evaluation function\n",
        "        - Fine-tuning\n",
        "    - Loss over the course of training\n",
        "    - Translating with the fine-tuned model\n",
        "- References\n",
        "    - [Hugging Face official documentation](https://huggingface.co/docs/transformers/index)\n",
        "    - [Transformers Notebooks](https://github.com/nlp-with-transformers/notebooks)\n",
        "    - Original of this notebook: [Multilingual NMT mt5](https://github.com/ejmejm/multilingual-nmt-mt5)\n",
        "        - Includes a video walkthrough (in English)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tB0U0vedLoqF"
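      },
      "source": [
        "## Notes\n",
        "- A dedicated virtual environment is recommended.\n",
        "    - The machine learning library [Hugging Face Transformers](https://huggingface.co/docs/transformers/installation) depends heavily on specific versions of related libraries, so building a dedicated virtual environment is recommended. When trying this on your own machine, create a separate environment with venv.\n",
        "- Running in a CUDA environment is recommended.\n",
        "    - Because the goal here is only to confirm that everything works, the sequence length and the number of epochs are set small. Even so, one epoch takes about an hour on CPU. If you are short on time, run the notebook in a CUDA environment (Google Colab).\n",
        "- Running on CPU requires editing part of the code.\n",
        "    - Every place in the code that calls ``.cuda()`` requires a CUDA environment. Removing those calls lets the code run on CPU as well."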
      ]
    },
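    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ``.cuda()`` edits described above can also be avoided by choosing the device once. The following is a minimal sketch and not part of the original notebook: the helper name ``pick_device`` is hypothetical, and it relies only on the standard PyTorch call ``torch.cuda.is_available()``.\n",
        "\n",
        "```python\n",
        "def pick_device():\n",
        "    # Return 'cuda' when PyTorch can see a GPU, otherwise 'cpu'.\n",
        "    try:\n",
        "        import torch\n",
        "    except ImportError:\n",
        "        # PyTorch is not installed yet; 'cpu' is the only safe answer.\n",
        "        return 'cpu'\n",
        "    return 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "\n",
        "\n",
        "device = pick_device()\n",
        "print('device =', device)\n",
        "\n",
        "# Hypothetical usage: write model.to(device) instead of model.cuda(),\n",
        "# and tensor.to(device) instead of tensor.cuda().\n",
        "```\n",
        "\n",
        "With this pattern the same cells could run on both CPU and GPU, at the cost of departing slightly from the original code, which calls ``.cuda()`` directly."
      ]
    },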
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XfyEUmaoLoqG"
      },
      "source": [
        "## Environment setup\n",
        "With the virtual environment activated, run the cell below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HC6D5XxBLoqG",
        "outputId": "1779f10a-dd5e-4379-f570-b85c56dc0ed4"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
            "Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.20.1)\n",
            "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.7/dist-packages (0.1.96)\n",
            "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (2.3.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.8.1)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)\n",
            "Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.12.1)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2022.6.2)\n",
            "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.4)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.1)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.1.1)\n",
            "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n",
            "Requirement already satisfied: dill<0.3.6 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.5.1)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from datasets) (3.8.1)\n",
            "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n",
            "Requirement already satisfied: fsspec[http]>=2021.05.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (2022.5.0)\n",
            "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.18.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.13)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (3.0.0)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.25.11)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.6.15)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.3.0)\n",
            "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)\n",
            "Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (0.13.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (6.0.2)\n",
            "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.0.2)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.2.0)\n",
            "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.7.2)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)\n",
            "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)\n",
            "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n"
          ]
        }
      ],
      "source": [
        "!pip install transformers sentencepiece datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rMRFG8sfLoqH"
      },
      "source": [
        "## Preparing the required modules"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F5m4W5h7cmDn"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "#from google.colab import drive\n",
        "#from IPython.display import display\n",
        "#from IPython.html import widgets\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import torch\n",
        "from torch import optim\n",
        "from torch.nn import functional as F\n",
        "from transformers import AdamW, AutoModelForSeq2SeqLM, AutoTokenizer\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "from tqdm import tqdm_notebook\n",
        "\n",
        "sns.set()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1oMSfUL5LoqI"
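      },
      "source": [
        "## Preparing the Tokenizer and Model\n",
        "Broadly speaking, there are three ways to use Transformers: (1) pretrain a model from scratch, (2) fine-tune a pretrained model, or (3) use a model that has already been fine-tuned.\n",
        "\n",
        "Here we try (2). The end goal is machine translation (Japanese to English and English to Japanese). For this task we use [mT5](https://huggingface.co/docs/transformers/model_doc/mt5) as the pretrained model. mT5 is based on the “Text-to-Text Transfer Transformer” (T5) model developed by Google and was pretrained on multilingual data (101 languages).\n",
        "\n",
        "- References\n",
        "  - For the officially provided pretrained models, see [MODELS](https://huggingface.co/docs/transformers/model_doc/albert).\n",
        "  - Models [published by the community](https://huggingface.co/models), including general developers, can be searched here.\n",
        "\n",
        "```{note}\n",
        "When using a pretrained model, note that you must use the same tokenizer (roughly, a word segmenter) that was used during pretraining.\n",
        "```"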
      ]
    },
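    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the note about matching tokenizers concrete, here is a small sketch with two toy vocabularies. Everything in it (``vocab_x``, ``vocab_y``, ``encode``, ``decode``) is hypothetical and written only for illustration; it does not use the real mT5 tokenizer.\n",
        "\n",
        "```python\n",
        "# Two hypothetical toy vocabularies (not real tokenizers).\n",
        "vocab_x = {'<unk>': 0, 'this': 1, 'is': 2, 'a': 3, 'test': 4}\n",
        "vocab_y = {'<unk>': 0, 'a': 1, 'is': 2, 'test': 3}\n",
        "\n",
        "\n",
        "def encode(words, vocab):\n",
        "    # Map each word to its id, falling back to <unk> for unknown words.\n",
        "    return [vocab.get(word, vocab['<unk>']) for word in words]\n",
        "\n",
        "\n",
        "def decode(ids, vocab):\n",
        "    # Invert the vocabulary and map ids back to words.\n",
        "    id_to_word = {idx: word for word, idx in vocab.items()}\n",
        "    return [id_to_word.get(idx, '<unk>') for idx in ids]\n",
        "\n",
        "\n",
        "words = ['this', 'is', 'a', 'test']\n",
        "ids_x = encode(words, vocab_x)\n",
        "ids_y = encode(words, vocab_y)\n",
        "print(ids_x)                   # ids a model trained with X expects\n",
        "print(ids_y)                   # ids produced by the mismatched vocabulary\n",
        "print(decode(ids_y, vocab_x))  # how a model trained with X would read them\n",
        "```\n",
        "\n",
        "The last line prints ``['<unk>', 'is', 'this', 'a']``: the same sentence, encoded with the wrong vocabulary, reaches the model as a different sentence."
      ]
    },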
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TmkmzYNOQ9xC",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "07b29c4b-87e7-45c8-9296-ae06d916cb96"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/transformers/convert_slow_tokenizer.py:435: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
            "  \"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option\"\n"
          ]
        }
      ],
      "source": [
        "model_repo = 'google/mt5-small'   # pretrained model\n",
        "model_path = 'mt5_translation.pt' # file name used when saving the fine-tuned model\n",
        "max_seq_len = 20 # maximum number of tokens; try increasing it if you have compute and training time to spare\n",
        "\n",
        "# Prepare the tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_repo)\n",
        "\n",
        "# Prepare the pretrained model\n",
        "model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)\n",
        "\n",
        "# If the environment supports CUDA, the line below enables faster execution.\n",
        "# Comment it out to run on CPU.\n",
        "model = model.cuda()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4zq6AHu4LoqK"
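      },
      "source": [
        "## Tips for using Transformers\n",
        "- You must use the tokenizer that was used during training.\n",
        "    - Suppose a tokenizer X assigns ``id = 1`` to ``this``. For a model trained with tokenizer X, ``1 == this``. If a different tokenizer is used at fine-tuning time, the ids may be shifted, or the token may not exist at all. To avoid this problem, the tokenizer used for pretraining and for fine-tuning must be the same.\n",
        "- The model exchanges encoded sequence data.\n",
        "    - Both the input and the output of the model are encoded sequences. Each segmented token is assigned a unique token_id (this is called encoding); the model receives the sequence of token_ids and returns its result as a sequence of token_ids as well. Run the cells that follow to confirm this.\n",
        "- The sequence length must be fixed.\n",
        "    - For example, when a model trained with a fixed length of 10 is given a shorter sequence, the missing positions must be filled in. The token used for this is called the padding token. Conversely, a sequence longer than 10 must be truncated beforehand, either to length 10 or to length 9 to leave room for the end-of-sequence token."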
      ]
    },
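    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The third tip above (padding and truncating to a fixed length) can be sketched in plain Python. This is an illustration under stated assumptions and not the notebook's own preprocessing: the function name ``to_fixed_length`` is hypothetical, the ids 0 and 1 are the ``<pad>`` and ``</s>`` ids of the mT5 vocabulary, and ``example_ids`` is the encoding of 'This is test.' shown in the tokenizer check of this notebook.\n",
        "\n",
        "```python\n",
        "PAD_ID = 0  # id of <pad> in the mT5 vocabulary\n",
        "EOS_ID = 1  # id of </s> in the mT5 vocabulary\n",
        "\n",
        "\n",
        "def to_fixed_length(token_ids, max_len):\n",
        "    # Return exactly max_len ids: truncated, closed with </s>, then padded.\n",
        "    if max_len < 1:\n",
        "        raise ValueError('max_len must be at least 1')\n",
        "    # Keep at most max_len - 1 tokens so the end-of-sequence token always fits.\n",
        "    body = list(token_ids)[:max_len - 1]\n",
        "    body.append(EOS_ID)\n",
        "    # Fill the remaining positions with the padding token.\n",
        "    return body + [PAD_ID] * (max_len - len(body))\n",
        "\n",
        "\n",
        "# ids of 'This is test.' without the trailing </s>\n",
        "example_ids = [1494, 339, 2978, 260]\n",
        "\n",
        "print(to_fixed_length(example_ids, 10))  # shorter than 10: padded\n",
        "print(to_fixed_length(example_ids, 3))   # longer than 3: truncated\n",
        "```\n",
        "\n",
        "The first call prints ``[1494, 339, 2978, 260, 1, 0, 0, 0, 0, 0]`` and the second prints ``[1494, 339, 1]``. In practice the Transformers tokenizer can do this itself when called with ``padding='max_length'``, ``truncation=True`` and ``max_length``."
      ]
    },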
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ywjkwLOZLoqK"
      },
      "source": [
        "## Checking how the tokenizer behaves"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bELaZKe0LoqL",
        "outputId": "8d4de1da-839a-4ed1-e0d2-e085b2b79931"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "token_ids =  tensor([[1494,  339, 2978,  260,    1]], device='cuda:0')\n",
            "-----------\n",
            "model_out =  tensor([[     0, 250099,      1]], device='cuda:0')\n",
            "-----------\n",
            "output_text =  <pad> <extra_id_0></s>\n"
          ]
        }
      ],
      "source": [
        "# Example text\n",
        "example_input_str = 'This is test.'\n",
        "\n",
        "# Encode with tokenizer.encode().\n",
        "token_ids = tokenizer.encode(\n",
        "    example_input_str,          # text to feed in\n",
        "    return_tensors='pt').cuda() # return a PyTorch tensor, moved to the GPU\n",
        "print('token_ids = ', token_ids)\n",
        "print('-----------')\n",
        "\n",
        "# As a test, feed the ids to the model and receive the result.\n",
        "model_out = model.generate(token_ids)\n",
        "print('model_out = ', model_out)\n",
        "print('-----------')\n",
        "\n",
        "# The raw ids are hard to read, so convert the output back to text.\n",
        "output_text = tokenizer.convert_tokens_to_string(\n",
        "    tokenizer.convert_ids_to_tokens(model_out[0]))\n",
        "print('output_text = ', output_text)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "id": "cAlK942ALoqM",
        "outputId": "9aed1557-f30d-456c-89ef-a403c7f90b5a"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'This is test.</s>'"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            }
          },
          "metadata": {},
          "execution_count": 5
        }
      ],
      "source": [
        "# Convert token_ids back to text.\n",
        "# id=1 becomes </s>, the special token that marks the end of a sequence.\n",
        "\n",
        "tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([1494,  339, 2978,  260,    1]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Hv11uIYbLoqN",
        "outputId": "bcd25a91-4d9b-48af-805a-0415c41dc4af"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "250100\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[('<pad>', 0),\n",
              " ('</s>', 1),\n",
              " ('<unk>', 2),\n",
              " ('<0x00>', 3),\n",
              " ('<0x01>', 4),\n",
              " ('<0x02>', 5),\n",
              " ('<0x03>', 6),\n",
              " ('<0x04>', 7),\n",
              " ('<0x05>', 8),\n",
              " ('<0x06>', 9),\n",
              " ('<0x07>', 10),\n",
              " ('<0x08>', 11),\n",
              " ('<0x09>', 12),\n",
              " ('<0x0A>', 13),\n",
              " ('<0x0B>', 14),\n",
              " ('<0x0C>', 15),\n",
              " ('<0x0D>', 16),\n",
              " ('<0x0E>', 17),\n",
              " ('<0x0F>', 18),\n",
              " ('<0x10>', 19),\n",
              " ('<0x11>', 20),\n",
              " ('<0x12>', 21),\n",
              " ('<0x13>', 22),\n",
              " ('<0x14>', 23),\n",
              " ('<0x15>', 24),\n",
              " ('<0x16>', 25),\n",
              " ('<0x17>', 26),\n",
              " ('<0x18>', 27),\n",
              " ('<0x19>', 28),\n",
              " ('<0x1A>', 29),\n",
              " ('<0x1B>', 30),\n",
              " ('<0x1C>', 31),\n",
              " ('<0x1D>', 32),\n",
              " ('<0x1E>', 33),\n",
              " ('<0x1F>', 34),\n",
              " ('<0x20>', 35),\n",
              " ('<0x21>', 36),\n",
              " ('<0x22>', 37),\n",
              " ('<0x23>', 38),\n",
              " ('<0x24>', 39),\n",
              " ('<0x25>', 40),\n",
              " ('<0x26>', 41),\n",
              " ('<0x27>', 42),\n",
              " ('<0x28>', 43),\n",
              " ('<0x29>', 44),\n",
              " ('<0x2A>', 45),\n",
              " ('<0x2B>', 46),\n",
              " ('<0x2C>', 47),\n",
              " ('<0x2D>', 48),\n",
              " ('<0x2E>', 49),\n",
              " ('<0x2F>', 50),\n",
              " ('<0x30>', 51),\n",
              " ('<0x31>', 52),\n",
              " ('<0x32>', 53),\n",
              " ('<0x33>', 54),\n",
              " ('<0x34>', 55),\n",
              " ('<0x35>', 56),\n",
              " ('<0x36>', 57),\n",
              " ('<0x37>', 58),\n",
              " ('<0x38>', 59),\n",
              " ('<0x39>', 60),\n",
              " ('<0x3A>', 61),\n",
              " ('<0x3B>', 62),\n",
              " ('<0x3C>', 63),\n",
              " ('<0x3D>', 64),\n",
              " ('<0x3E>', 65),\n",
              " ('<0x3F>', 66),\n",
              " ('<0x40>', 67),\n",
              " ('<0x41>', 68),\n",
              " ('<0x42>', 69),\n",
              " ('<0x43>', 70),\n",
              " ('<0x44>', 71),\n",
              " ('<0x45>', 72),\n",
              " ('<0x46>', 73),\n",
              " ('<0x47>', 74),\n",
              " ('<0x48>', 75),\n",
              " ('<0x49>', 76),\n",
              " ('<0x4A>', 77),\n",
              " ('<0x4B>', 78),\n",
              " ('<0x4C>', 79),\n",
              " ('<0x4D>', 80),\n",
              " ('<0x4E>', 81),\n",
              " ('<0x4F>', 82),\n",
              " ('<0x50>', 83),\n",
              " ('<0x51>', 84),\n",
              " ('<0x52>', 85),\n",
              " ('<0x53>', 86),\n",
              " ('<0x54>', 87),\n",
              " ('<0x55>', 88),\n",
              " ('<0x56>', 89),\n",
              " ('<0x57>', 90),\n",
              " ('<0x58>', 91),\n",
              " ('<0x59>', 92),\n",
              " ('<0x5A>', 93),\n",
              " ('<0x5B>', 94),\n",
              " ('<0x5C>', 95),\n",
              " ('<0x5D>', 96),\n",
              " ('<0x5E>', 97),\n",
              " ('<0x5F>', 98),\n",
              " ('<0x60>', 99),\n",
              " ('<0x61>', 100),\n",
              " ('<0x62>', 101),\n",
              " ('<0x63>', 102),\n",
              " ('<0x64>', 103),\n",
              " ('<0x65>', 104),\n",
              " ('<0x66>', 105),\n",
              " ('<0x67>', 106),\n",
              " ('<0x68>', 107),\n",
              " ('<0x69>', 108),\n",
              " ('<0x6A>', 109),\n",
              " ('<0x6B>', 110),\n",
              " ('<0x6C>', 111),\n",
              " ('<0x6D>', 112),\n",
              " ('<0x6E>', 113),\n",
              " ('<0x6F>', 114),\n",
              " ('<0x70>', 115),\n",
              " ('<0x71>', 116),\n",
              " ('<0x72>', 117),\n",
              " ('<0x73>', 118),\n",
              " ('<0x74>', 119),\n",
              " ('<0x75>', 120),\n",
              " ('<0x76>', 121),\n",
              " ('<0x77>', 122),\n",
              " ('<0x78>', 123),\n",
              " ('<0x79>', 124),\n",
              " ('<0x7A>', 125),\n",
              " ('<0x7B>', 126),\n",
              " ('<0x7C>', 127),\n",
              " ('<0x7D>', 128),\n",
              " ('<0x7E>', 129),\n",
              " ('<0x7F>', 130),\n",
              " ('<0x80>', 131),\n",
              " ('<0x81>', 132),\n",
              " ('<0x82>', 133),\n",
              " ('<0x83>', 134),\n",
              " ('<0x84>', 135),\n",
              " ('<0x85>', 136),\n",
              " ('<0x86>', 137),\n",
              " ('<0x87>', 138),\n",
              " ('<0x88>', 139),\n",
              " ('<0x89>', 140),\n",
              " ('<0x8A>', 141),\n",
              " ('<0x8B>', 142),\n",
              " ('<0x8C>', 143),\n",
              " ('<0x8D>', 144),\n",
              " ('<0x8E>', 145),\n",
              " ('<0x8F>', 146),\n",
              " ('<0x90>', 147),\n",
              " ('<0x91>', 148),\n",
              " ('<0x92>', 149),\n",
              " ('<0x93>', 150),\n",
              " ('<0x94>', 151),\n",
              " ('<0x95>', 152),\n",
              " ('<0x96>', 153),\n",
              " ('<0x97>', 154),\n",
              " ('<0x98>', 155),\n",
              " ('<0x99>', 156),\n",
              " ('<0x9A>', 157),\n",
              " ('<0x9B>', 158),\n",
              " ('<0x9C>', 159),\n",
              " ('<0x9D>', 160),\n",
              " ('<0x9E>', 161),\n",
              " ('<0x9F>', 162),\n",
              " ('<0xA0>', 163),\n",
              " ('<0xA1>', 164),\n",
              " ('<0xA2>', 165),\n",
              " ('<0xA3>', 166),\n",
              " ('<0xA4>', 167),\n",
              " ('<0xA5>', 168),\n",
              " ('<0xA6>', 169),\n",
              " ('<0xA7>', 170),\n",
              " ('<0xA8>', 171),\n",
              " ('<0xA9>', 172),\n",
              " ('<0xAA>', 173),\n",
              " ('<0xAB>', 174),\n",
              " ('<0xAC>', 175),\n",
              " ('<0xAD>', 176),\n",
              " ('<0xAE>', 177),\n",
              " ('<0xAF>', 178),\n",
              " ('<0xB0>', 179),\n",
              " ('<0xB1>', 180),\n",
              " ('<0xB2>', 181),\n",
              " ('<0xB3>', 182),\n",
              " ('<0xB4>', 183),\n",
              " ('<0xB5>', 184),\n",
              " ('<0xB6>', 185),\n",
              " ('<0xB7>', 186),\n",
              " ('<0xB8>', 187),\n",
              " ('<0xB9>', 188),\n",
              " ('<0xBA>', 189),\n",
              " ('<0xBB>', 190),\n",
              " ('<0xBC>', 191),\n",
              " ('<0xBD>', 192),\n",
              " ('<0xBE>', 193),\n",
              " ('<0xBF>', 194),\n",
              " ('<0xC0>', 195),\n",
              " ('<0xC1>', 196),\n",
              " ('<0xC2>', 197),\n",
              " ('<0xC3>', 198),\n",
              " ('<0xC4>', 199),\n",
              " ('<0xC5>', 200),\n",
              " ('<0xC6>', 201),\n",
              " ('<0xC7>', 202),\n",
              " ('<0xC8>', 203),\n",
              " ('<0xC9>', 204),\n",
              " ('<0xCA>', 205),\n",
              " ('<0xCB>', 206),\n",
              " ('<0xCC>', 207),\n",
              " ('<0xCD>', 208),\n",
              " ('<0xCE>', 209),\n",
              " ('<0xCF>', 210),\n",
              " ('<0xD0>', 211),\n",
              " ('<0xD1>', 212),\n",
              " ('<0xD2>', 213),\n",
              " ('<0xD3>', 214),\n",
              " ('<0xD4>', 215),\n",
              " ('<0xD5>', 216),\n",
              " ('<0xD6>', 217),\n",
              " ('<0xD7>', 218),\n",
              " ('<0xD8>', 219),\n",
              " ('<0xD9>', 220),\n",
              " ('<0xDA>', 221),\n",
              " ('<0xDB>', 222),\n",
              " ('<0xDC>', 223),\n",
              " ('<0xDD>', 224),\n",
              " ('<0xDE>', 225),\n",
              " ('<0xDF>', 226),\n",
              " ('<0xE0>', 227),\n",
              " ('<0xE1>', 228),\n",
              " ('<0xE2>', 229),\n",
              " ('<0xE3>', 230),\n",
              " ('<0xE4>', 231),\n",
              " ('<0xE5>', 232),\n",
              " ('<0xE6>', 233),\n",
              " ('<0xE7>', 234),\n",
              " ('<0xE8>', 235),\n",
              " ('<0xE9>', 236),\n",
              " ('<0xEA>', 237),\n",
              " ('<0xEB>', 238),\n",
              " ('<0xEC>', 239),\n",
              " ('<0xED>', 240),\n",
              " ('<0xEE>', 241),\n",
              " ('<0xEF>', 242),\n",
              " ('<0xF0>', 243),\n",
              " ('<0xF1>', 244),\n",
              " ('<0xF2>', 245),\n",
              " ('<0xF3>', 246),\n",
              " ('<0xF4>', 247),\n",
              " ('<0xF5>', 248),\n",
              " ('<0xF6>', 249),\n",
              " ('<0xF7>', 250),\n",
              " ('<0xF8>', 251),\n",
              " ('<0xF9>', 252),\n",
              " ('<0xFA>', 253),\n",
              " ('<0xFB>', 254),\n",
              " ('<0xFC>', 255),\n",
              " ('<0xFD>', 256),\n",
              " ('<0xFE>', 257),\n",
              " ('<0xFF>', 258),\n",
              " ('▁', 259),\n",
              " ('.', 260),\n",
              " (',', 261),\n",
              " ('a', 262),\n",
              " ('s', 263),\n",
              " ('-', 264),\n",
              " ('e', 265),\n",
              " ('i', 266),\n",
              " (':', 267),\n",
              " ('o', 268),\n",
              " ('▁de', 269),\n",
              " ('t', 270),\n",
              " (')', 271),\n",
              " ('n', 272),\n",
              " ('u', 273),\n",
              " ('▁(', 274),\n",
              " ('/', 275),\n",
              " ('y', 276),\n",
              " (\"'\", 277),\n",
              " ('en', 278),\n",
              " ('и', 279),\n",
              " ('l', 280),\n",
              " ('▁in', 281),\n",
              " ('m', 282),\n",
              " ('▁la', 283),\n",
              " ('com', 284),\n",
              " ('d', 285),\n",
              " ('r', 286),\n",
              " ('▁the', 287),\n",
              " ('▁to', 288),\n",
              " ('▁en', 289),\n",
              " ('_', 290),\n",
              " ('?', 291),\n",
              " ('、', 292),\n",
              " ('’', 293),\n",
              " ('▁na', 294),\n",
              " ('er', 295),\n",
              " (';', 296),\n",
              " ('c', 297),\n",
              " ('▁A', 298),\n",
              " ('es', 299),\n",
              " ('▁v', 300),\n",
              " ('▁di', 301),\n",
              " ('...', 302),\n",
              " ('▁se', 303),\n",
              " ('▁of', 304),\n",
              " ('▁and', 305),\n",
              " ('。', 306),\n",
              " ('▁|', 307),\n",
              " ('а', 308),\n",
              " ('!', 309),\n",
              " ('▁на', 310),\n",
              " ('\"', 311),\n",
              " ('(', 312),\n",
              " ('▁\"', 313),\n",
              " ('k', 314),\n",
              " ('▁в', 315),\n",
              " ('b', 316),\n",
              " ('▁c', 317),\n",
              " ('g', 318),\n",
              " ('▁que', 319),\n",
              " ('▁S', 320),\n",
              " ('an', 321),\n",
              " ('▁–', 322),\n",
              " ('▁www', 323),\n",
              " ('е', 324),\n",
              " ('p', 325),\n",
              " ('▁m', 326),\n",
              " ('▁sa', 327),\n",
              " ('3', 328),\n",
              " ('x', 329),\n",
              " ('▁b', 330),\n",
              " ('▁d', 331),\n",
              " ('▁for', 332),\n",
              " ('▁1', 333),\n",
              " ('h', 334),\n",
              " ('▁un', 335),\n",
              " ('▁I', 336),\n",
              " ('os', 337),\n",
              " ('2', 338),\n",
              " ('▁is', 339),\n",
              " ('▁le', 340),\n",
              " ('▁و', 341),\n",
              " ('▁do', 342),\n",
              " ('،', 343),\n",
              " ('▁at', 344),\n",
              " ('ed', 345),\n",
              " ('te', 346),\n",
              " ('ing', 347),\n",
              " ('in', 348),\n",
              " ('=', 349),\n",
              " ('▁da', 350),\n",
              " ('▁on', 351),\n",
              " ('▁M', 352),\n",
              " ('1', 353),\n",
              " ('у', 354),\n",
              " ('▁đ', 355),\n",
              " ('▁2', 356),\n",
              " ('A', 357),\n",
              " ('as', 358),\n",
              " ('▁“', 359),\n",
              " ('z', 360),\n",
              " ('é', 361),\n",
              " ('▁el', 362),\n",
              " ('▁P', 363),\n",
              " ('▁B', 364),\n",
              " ('”', 365),\n",
              " ('▁T', 366),\n",
              " ('f', 367),\n",
              " ('de', 368),\n",
              " ('à', 369),\n",
              " ('ng', 370),\n",
              " ('▁C', 371),\n",
              " ('ar', 372),\n",
              " ('▁og', 373),\n",
              " ('▁за', 374),\n",
              " ('▁no', 375),\n",
              " ('ه', 376),\n",
              " ('na', 377),\n",
              " ('।', 378),\n",
              " ('v', 379),\n",
              " ('re', 380),\n",
              " ('▁3', 381),\n",
              " ('▁h', 382),\n",
              " ('▁et', 383),\n",
              " ('▁je', 384),\n",
              " ('j', 385),\n",
              " ('▁il', 386),\n",
              " ('▁#', 387),\n",
              " ('▁с', 388),\n",
              " ('і', 389),\n",
              " ('▁be', 390),\n",
              " ('://', 391),\n",
              " ('▁2018', 392),\n",
              " ('▁per', 393),\n",
              " ('▁th', 394),\n",
              " ('▁si', 395),\n",
              " ('я', 396),\n",
              " ('▁z', 397),\n",
              " ('▁die', 398),\n",
              " ('S', 399),\n",
              " ('▁te', 400),\n",
              " ('▁не', 401),\n",
              " ('▁ال', 402),\n",
              " ('D', 403),\n",
              " ('▁«', 404),\n",
              " ('ne', 405),\n",
              " ('ی', 406),\n",
              " ('da', 407),\n",
              " ('▁k', 408),\n",
              " ('|', 409),\n",
              " ('4', 410),\n",
              " ('о', 411),\n",
              " ('▁K', 412),\n",
              " ('▁du', 413),\n",
              " ('▁w', 414),\n",
              " ('▁E', 415),\n",
              " ('▁me', 416),\n",
              " ('is', 417),\n",
              " ('▁are', 418),\n",
              " ('▁4', 419),\n",
              " ('í', 420),\n",
              " ('▁p', 421),\n",
              " ('ta', 422),\n",
              " ('の', 423),\n",
              " ('C', 424),\n",
              " ('▁по', 425),\n",
              " ('▁del', 426),\n",
              " ('▁ka', 427),\n",
              " ('5', 428),\n",
              " ('et', 429),\n",
              " ('▁5', 430),\n",
              " ('▁D', 431),\n",
              " ('▁ja', 432),\n",
              " ('ы', 433),\n",
              " ('▁V', 434),\n",
              " ('▁para', 435),\n",
              " ('»', 436),\n",
              " ('\",\"', 437),\n",
              " ('us', 438),\n",
              " (']', 439),\n",
              " ('▁al', 440),\n",
              " ('▁N', 441),\n",
              " ('▁der', 442),\n",
              " ('▁O', 443),\n",
              " ('on', 444),\n",
              " ('ة', 445),\n",
              " ('▁да', 446),\n",
              " ('▁H', 447),\n",
              " ('▁ne', 448),\n",
              " ('8', 449),\n",
              " ('▁con', 450),\n",
              " ('6', 451),\n",
              " ('B', 452),\n",
              " ('▁er', 453),\n",
              " ('ul', 454),\n",
              " ('▁by', 455),\n",
              " ('▁у', 456),\n",
              " ('▁yang', 457),\n",
              " ('▁L', 458),\n",
              " ('▁De', 459),\n",
              " ('0', 460),\n",
              " ('▁an', 461),\n",
              " ('ja', 462),\n",
              " ('\\xad', 463),\n",
              " ('▁van', 464),\n",
              " ('▁ה', 465),\n",
              " ('▁za', 466),\n",
              " ('】【', 467),\n",
              " ('le', 468),\n",
              " ('▁dan', 469),\n",
              " ('em', 470),\n",
              " ('á', 471),\n",
              " ('▁und', 472),\n",
              " ('al', 473),\n",
              " ('è', 474),\n",
              " ('▁10', 475),\n",
              " ('to', 476),\n",
              " ('ي', 477),\n",
              " ('E', 478),\n",
              " ('ka', 479),\n",
              " ('▁...', 480),\n",
              " ('w', 481),\n",
              " ('▁på', 482),\n",
              " (').', 483),\n",
              " ('ly', 484),\n",
              " ('▁po', 485),\n",
              " ('▁The', 486),\n",
              " ('7', 487),\n",
              " ('\":\"', 488),\n",
              " ('▁G', 489),\n",
              " ('T', 490),\n",
              " ('▁[', 491),\n",
              " ('la', 492),\n",
              " ('的', 493),\n",
              " ('li', 494),\n",
              " ('9', 495),\n",
              " ('▁ma', 496),\n",
              " ('▁0', 497),\n",
              " ('▁des', 498),\n",
              " ('▁med', 499),\n",
              " ('▁til', 500),\n",
              " ('▁La', 501),\n",
              " ('kan', 502),\n",
              " ('it', 503),\n",
              " ('▁ki', 504),\n",
              " ('no', 505),\n",
              " ('),', 506),\n",
              " ('м', 507),\n",
              " ('َ', 508),\n",
              " ('▁در', 509),\n",
              " ('▁so', 510),\n",
              " ('M', 511),\n",
              " ('▁som', 512),\n",
              " ('▁ke', 513),\n",
              " ('▁with', 514),\n",
              " ('▁F', 515),\n",
              " ('ni', 516),\n",
              " ('▁su', 517),\n",
              " ('▁και', 518),\n",
              " ('▁por', 519),\n",
              " ('▁les', 520),\n",
              " ('▁you', 521),\n",
              " ('si', 522),\n",
              " ('at', 523),\n",
              " ('ti', 524),\n",
              " ('id', 525),\n",
              " ('▁av', 526),\n",
              " ('▁as', 527),\n",
              " ('▁ya', 528),\n",
              " ('▁ve', 529),\n",
              " ('▁den', 530),\n",
              " ('▁R', 531),\n",
              " ('▁ב', 532),\n",
              " ('▁that', 533),\n",
              " ('▁tr', 534),\n",
              " ('は', 535),\n",
              " ('が', 536),\n",
              " ('do', 537),\n",
              " ('N', 538),\n",
              " ('ia', 539),\n",
              " ('\\\\', 540),\n",
              " ('ce', 541),\n",
              " ('▁om', 542),\n",
              " ('й', 543),\n",
              " ('▁се', 544),\n",
              " ('F', 545),\n",
              " ('&', 546),\n",
              " ('L', 547),\n",
              " ('▁م', 548),\n",
              " ('▁&', 549),\n",
              " ('▁د', 550),\n",
              " ('▁det', 551),\n",
              " ('▁от', 552),\n",
              " ('ó', 553),\n",
              " ('▁به', 554),\n",
              " ('▁pa', 555),\n",
              " ('▁من', 556),\n",
              " ('K', 557),\n",
              " ('на', 558),\n",
              " ('P', 559),\n",
              " ('▁ha', 560),\n",
              " ('V', 561),\n",
              " ('▁ch', 562),\n",
              " ('▁In', 563),\n",
              " ('▁W', 564),\n",
              " ('▁„', 565),\n",
              " ('I', 566),\n",
              " ('▁var', 567),\n",
              " ('▁ni', 568),\n",
              " ('se', 569),\n",
              " ('▁6', 570),\n",
              " ('ra', 571),\n",
              " ('ل', 572),\n",
              " ('▁una', 573),\n",
              " ('を', 574),\n",
              " ('▁في', 575),\n",
              " ('▁ta', 576),\n",
              " ('▁http', 577),\n",
              " ('COM', 578),\n",
              " ('am', 579),\n",
              " ('ה', 580),\n",
              " ('▁U', 581),\n",
              " ('R', 582),\n",
              " ('▁з', 583),\n",
              " ('▁re', 584),\n",
              " ('▁op', 585),\n",
              " ('ن', 586),\n",
              " ('т', 587),\n",
              " ('▁har', 588),\n",
              " ('ο', 589),\n",
              " ('H', 590),\n",
              " ('“', 591),\n",
              " ('ek', 592),\n",
              " ('▁ag', 593),\n",
              " ('▁ng', 594),\n",
              " ('▁los', 595),\n",
              " ('{', 596),\n",
              " ('▁och', 597),\n",
              " ('▁2017', 598),\n",
              " ('▁WWW', 599),\n",
              " ('に', 600),\n",
              " ('▁ku', 601),\n",
              " ('ir', 602),\n",
              " ('▁pe', 603),\n",
              " ('un', 604),\n",
              " ('х', 605),\n",
              " ('um', 606),\n",
              " ('▁2019', 607),\n",
              " ('je', 608),\n",
              " ('▁it', 609),\n",
              " ('▁до', 610),\n",
              " ('을', 611),\n",
              " ('ʻ', 612),\n",
              " ('www', 613),\n",
              " ('▁ب', 614),\n",
              " ('▁li', 615),\n",
              " ('но', 616),\n",
              " ('▁7', 617),\n",
              " ('▁»', 618),\n",
              " ('▁ir', 619),\n",
              " ('▁kan', 620),\n",
              " ('G', 621),\n",
              " ('▁het', 622),\n",
              " ('▁ho', 623),\n",
              " ('▁par', 624),\n",
              " ('▁vi', 625),\n",
              " ('・', 626),\n",
              " ('で', 627),\n",
              " ('▁20', 628),\n",
              " ('▁të', 629),\n",
              " ('▁8', 630),\n",
              " ('▁or', 631),\n",
              " ('ا', 632),\n",
              " ('م', 633),\n",
              " ('ie', 634),\n",
              " ('▁В', 635),\n",
              " ('ت', 636),\n",
              " ('ом', 637),\n",
              " ('W', 638),\n",
              " ('▁was', 639),\n",
              " ('την', 640),\n",
              " ('▁के', 641),\n",
              " ('▁En', 642),\n",
              " ('▁af', 643),\n",
              " ('▁12', 644),\n",
              " ('me', 645),\n",
              " ('O', 646),\n",
              " ('nya', 647),\n",
              " ('ma', 648),\n",
              " ('의', 649),\n",
              " ('ki', 650),\n",
              " ('▁cu', 651),\n",
              " ('μ', 652),\n",
              " ('▁No', 653),\n",
              " ('▁2016', 654),\n",
              " ('▁es', 655),\n",
              " ('▁een', 656),\n",
              " ('ки', 657),\n",
              " ('▁mi', 658),\n",
              " ('Ð', 659),\n",
              " ('10', 660),\n",
              " ('▁—', 661),\n",
              " ('ku', 662),\n",
              " ('\":', 663),\n",
              " ('▁J', 664),\n",
              " ('px', 665),\n",
              " ('일', 666),\n",
              " ('▁ל', 667),\n",
              " ('ни', 668),\n",
              " ('>', 669),\n",
              " ('▁15', 670),\n",
              " ('▁‘', 671),\n",
              " ('▁ver', 672),\n",
              " ('▁um', 673),\n",
              " ('▁man', 674),\n",
              " ('▁ko', 675),\n",
              " ('+', 676),\n",
              " ('▁nh', 677),\n",
              " ('η', 678),\n",
              " ('ка', 679),\n",
              " ('ny', 680),\n",
              " ('α', 681),\n",
              " ('▁od', 682),\n",
              " ('▁wa', 683),\n",
              " ('▁ge', 684),\n",
              " ('ов', 685),\n",
              " ('н', 686),\n",
              " ('ten', 687),\n",
              " ('▁С', 688),\n",
              " ('▁מ', 689),\n",
              " ('▁ph', 690),\n",
              " ('▁>', 691),\n",
              " ('▁men', 692),\n",
              " ('▁ber', 693),\n",
              " ('▁του', 694),\n",
              " ('▁از', 695),\n",
              " ('il', 696),\n",
              " ('ch', 697),\n",
              " ('▁bir', 698),\n",
              " ('▁το', 699),\n",
              " ('▁να', 700),\n",
              " ('el', 701),\n",
              " ('▁from', 702),\n",
              " ('▁nu', 703),\n",
              " ('ko', 704),\n",
              " ('st', 705),\n",
              " ('ë', 706),\n",
              " ('▁lo', 707),\n",
              " ('ủ', 708),\n",
              " ('▁az', 709),\n",
              " ('▁dem', 710),\n",
              " ('mi', 711),\n",
              " ('▁va', 712),\n",
              " ('▁att', 713),\n",
              " ('▁this', 714),\n",
              " ('ur', 715),\n",
              " ('▁nie', 716),\n",
              " ('#', 717),\n",
              " ('▁gi', 718),\n",
              " ('▁tu', 719),\n",
              " ('di', 720),\n",
              " ('å', 721),\n",
              " ('ات', 722),\n",
              " ('or', 723),\n",
              " ('▁em', 724),\n",
              " ('と', 725),\n",
              " ('ת', 726),\n",
              " ('▁Na', 727),\n",
              " ('▁am', 728),\n",
              " ('▁из', 729),\n",
              " ('▁11', 730),\n",
              " ('▁pro', 731),\n",
              " ('▁în', 732),\n",
              " ('▁30', 733),\n",
              " ('▁che', 734),\n",
              " ('для', 735),\n",
              " ('▁Z', 736),\n",
              " ('ru', 737),\n",
              " ('▁can', 738),\n",
              " ('ya', 739),\n",
              " ('▁ang', 740),\n",
              " ('ai', 741),\n",
              " ('▁f', 742),\n",
              " ('ga', 743),\n",
              " ('▁+', 744),\n",
              " ('za', 745),\n",
              " ('▁Se', 746),\n",
              " ('이', 747),\n",
              " ('ю', 748),\n",
              " ('▁mit', 749),\n",
              " ('ca', 750),\n",
              " ('▁all', 751),\n",
              " ('▁של', 752),\n",
              " ('ke', 753),\n",
              " ('\",', 754),\n",
              " ('°', 755),\n",
              " ('▁tak', 756),\n",
              " ('ने', 757),\n",
              " ('▁bu', 758),\n",
              " ('▁bo', 759),\n",
              " ('▁zu', 760),\n",
              " ('ą', 761),\n",
              " ('ή', 762),\n",
              " ('▁pour', 763),\n",
              " ('▁Le', 764),\n",
              " ('[', 765),\n",
              " ('▁ت', 766),\n",
              " ('▁ter', 767),\n",
              " ('▁با', 768),\n",
              " ('ci', 769),\n",
              " ('▁és', 770),\n",
              " ('co', 771),\n",
              " ('▁your', 772),\n",
              " ('om', 773),\n",
              " ('▁9', 774),\n",
              " ('▁کے', 775),\n",
              " ('▁not', 776),\n",
              " ('их', 777),\n",
              " ('▁к', 778),\n",
              " ('▁din', 779),\n",
              " ('im', 780),\n",
              " ('q', 781),\n",
              " ('ă', 782),\n",
              " ('▁have', 783),\n",
              " ('▁mai', 784),\n",
              " ('▁{', 785),\n",
              " ('▁pre', 786),\n",
              " ('▁we', 787),\n",
              " ('▁Re', 788),\n",
              " ('▁El', 789),\n",
              " ('▁he', 790),\n",
              " ('ς', 791),\n",
              " ('▁•', 792),\n",
              " ('và', 793),\n",
              " ('Y', 794),\n",
              " ('▁von', 795),\n",
              " ('▁là', 796),\n",
              " ('ې', 797),\n",
              " ('▁ar', 798),\n",
              " ('▁16', 799),\n",
              " ('▁las', 800),\n",
              " ('ú', 801),\n",
              " ('app', 802),\n",
              " ('▁کی', 803),\n",
              " ('▁au', 804),\n",
              " ('▁при', 805),\n",
              " ('U', 806),\n",
              " ('th', 807),\n",
              " ('▁}', 808),\n",
              " ('▁2014', 809),\n",
              " ('▁ba', 810),\n",
              " ('be', 811),\n",
              " ('▁18', 812),\n",
              " ('X', 813),\n",
              " ('▁2015', 814),\n",
              " ('▁2013', 815),\n",
              " ('▁(1)', 816),\n",
              " ('ой', 817),\n",
              " ('▁14', 818),\n",
              " ('▁qu', 819),\n",
              " ('ِ', 820),\n",
              " ('ha', 821),\n",
              " ('▁می', 822),\n",
              " ('man', 823),\n",
              " ('▁met', 824),\n",
              " ('are', 825),\n",
              " ('▁nga', 826),\n",
              " ('▁das', 827),\n",
              " ('▁της', 828),\n",
              " ('‘', 829),\n",
              " ('▁है', 830),\n",
              " ('ية', 831),\n",
              " ('то', 832),\n",
              " ('ь', 833),\n",
              " ('va', 834),\n",
              " ('ba', 835),\n",
              " ('】', 836),\n",
              " ('▁bi', 837),\n",
              " ('日', 838),\n",
              " ('한', 839),\n",
              " ('▁24', 840),\n",
              " ('ر', 841),\n",
              " ('ى', 842),\n",
              " ('▁est', 843),\n",
              " ('▁में', 844),\n",
              " ('lar', 845),\n",
              " ('▁2012', 846),\n",
              " ('▁dengan', 847),\n",
              " ('年', 848),\n",
              " ('▁13', 849),\n",
              " ('▁με', 850),\n",
              " ('▁untuk', 851),\n",
              " ('▁Y', 852),\n",
              " (');', 853),\n",
              " ('▁ini', 854),\n",
              " ('▁ש', 855),\n",
              " ('▁ist', 856),\n",
              " ('ve', 857),\n",
              " ('▁ا', 858),\n",
              " ('▁im', 859),\n",
              " ('this', 860),\n",
              " ('est', 861),\n",
              " ('▁online', 862),\n",
              " ('न', 863),\n",
              " ('▁А', 864),\n",
              " ('▁sur', 865),\n",
              " ('J', 866),\n",
              " ('▁У', 867),\n",
              " ('ך', 868),\n",
              " ('은', 869),\n",
              " ('ado', 870),\n",
              " ('▁ti', 871),\n",
              " ('ہ', 872),\n",
              " ('에', 873),\n",
              " ('ri', 874),\n",
              " ('▁för', 875),\n",
              " ('tu', 876),\n",
              " ('▁25', 877),\n",
              " ('lo', 878),\n",
              " ('」', 879),\n",
              " ('den', 880),\n",
              " ('%', 881),\n",
              " ('▁א', 882),\n",
              " ('د', 883),\n",
              " ('▁את', 884),\n",
              " ('▁có', 885),\n",
              " ('▁pas', 886),\n",
              " ('=\"', 887),\n",
              " ('▁ein', 888),\n",
              " ('ou', 889),\n",
              " ('▁mu', 890),\n",
              " ('月', 891),\n",
              " ('▁что', 892),\n",
              " ('ого', 893),\n",
              " ('*', 894),\n",
              " ('ի', 895),\n",
              " ('ים', 896),\n",
              " ('р', 897),\n",
              " ('▁will', 898),\n",
              " ('▁fa', 899),\n",
              " ('net', 900),\n",
              " ('▁για', 901),\n",
              " ('д', 902),\n",
              " ('ê', 903),\n",
              " ('▁*', 904),\n",
              " ('ُ', 905),\n",
              " ('ada', 906),\n",
              " ('▁qui', 907),\n",
              " ('ới', 908),\n",
              " ('г', 909),\n",
              " ('▁over', 910),\n",
              " ('▁17', 911),\n",
              " ('▁από', 912),\n",
              " ('ها', 913),\n",
              " (',\"', 914),\n",
              " ('ā', 915),\n",
              " ('▁را', 916),\n",
              " ('▁со', 917),\n",
              " ('та', 918),\n",
              " ('▁ser', 919),\n",
              " ('л', 920),\n",
              " ('que', 921),\n",
              " ('▁так', 922),\n",
              " ('▁про', 923),\n",
              " ('ể', 924),\n",
              " ('ok', 925),\n",
              " ('▁To', 926),\n",
              " ('▁σ', 927),\n",
              " ('▁და', 928),\n",
              " ('가', 929),\n",
              " ('ό', 930),\n",
              " ('ción', 931),\n",
              " ('ak', 932),\n",
              " ('ị', 933),\n",
              " ('▁که', 934),\n",
              " ('▁non', 935),\n",
              " ('ן', 936),\n",
              " ('▁је', 937),\n",
              " ('ro', 938),\n",
              " ('「', 939),\n",
              " ('ag', 940),\n",
              " ('ان', 941),\n",
              " ('على', 942),\n",
              " ('▁आ', 943),\n",
              " ('ите', 944),\n",
              " ('да', 945),\n",
              " ('с', 946),\n",
              " ('▁się', 947),\n",
              " ('▁€', 948),\n",
              " ('▁mo', 949),\n",
              " ('▁است', 950),\n",
              " ('▁·', 951),\n",
              " ('ý', 952),\n",
              " ('▁این', 953),\n",
              " ('Р', 954),\n",
              " ('▁if', 955),\n",
              " ('▁für', 956),\n",
              " ('не', 957),\n",
              " ('▁como', 958),\n",
              " ('▁X', 959),\n",
              " ('▁ca', 960),\n",
              " ('▁är', 961),\n",
              " ('ní', 962),\n",
              " ('▁19', 963),\n",
              " ('▁co', 964),\n",
              " ('▁כ', 965),\n",
              " ('▁100', 966),\n",
              " ('ere', 967),\n",
              " ('▁að', 968),\n",
              " ('wa', 969),\n",
              " ('▁cho', 970),\n",
              " ('▁voor', 971),\n",
              " ('▁2020', 972),\n",
              " ('▁میں', 973),\n",
              " ('و', 974),\n",
              " ('▁की', 975),\n",
              " ('ji', 976),\n",
              " ('▁Đ', 977),\n",
              " ('も', 978),\n",
              " ('▁pri', 979),\n",
              " ('▁este', 980),\n",
              " ('▁2011', 981),\n",
              " ('▁ce', 982),\n",
              " ('▁О', 983),\n",
              " ('▁է', 984),\n",
              " ('ik', 985),\n",
              " ('ት', 986),\n",
              " ('▁21', 987),\n",
              " ('는', 988),\n",
              " ('ку', 989),\n",
              " ('ж', 990),\n",
              " ('ے', 991),\n",
              " ('▁во', 992),\n",
              " ('ç', 993),\n",
              " ('ে', 994),\n",
              " ('п', 995),\n",
              " ('र', 996),\n",
              " ('Z', 997),\n",
              " ('▁од', 998),\n",
              " ('▁ob', 999),\n",
              " ...]"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "# Inspect the full token list.\n",
        "# <pad>: padding token, used to bring sequences to the same length.\n",
        "# </s>: end-of-sequence token.\n",
        "# <unk>: unknown token; every out-of-vocabulary piece maps to this.\n",
        "\n",
        "print(len(tokenizer.vocab))\n",
        "sorted(tokenizer.vocab.items(), key=lambda x: x[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BzUnT7kpLoqN",
        "outputId": "28aed289-8aab-4a0a-88c7-a3a40b8a116a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "token_ids =  tensor([[1494,  339, 2978,  260,    1,    0,    0,    0,    0,    0,    0,    0,\n",
            "            0,    0,    0,    0,    0,    0,    0,    0]], device='cuda:0')\n",
            "len(token_ids[0]) =  20\n",
            "tokens =  ['▁This', '▁is', '▁test', '.', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n",
            "len(tokens) = 20\n"
          ]
        }
      ],
      "source": [
        "# Example of padding a short token sequence.\n",
        "# The padding parameter selects the padding strategy; with 'max_length', the sequence is padded up to the length given by max_length.\n",
        "\n",
        "example_input_str = 'This is test.'\n",
        "\n",
        "token_ids = tokenizer.encode(\n",
        "    example_input_str, return_tensors='pt', padding='max_length',\n",
        "    truncation=True, max_length=max_seq_len).cuda()\n",
        "print('token_ids = ', token_ids)\n",
        "print('len(token_ids[0]) = ', len(token_ids[0]))\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(token_ids[0])\n",
        "print('tokens = ', tokens)\n",
        "print('len(tokens) =', len(tokens))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-moGg5IuLoqO",
        "outputId": "e289c182-d41d-4f9b-dbd7-2bf8fb0328f5"
      },
      "outputs": [],
      "source": [
        "# Example of truncating a long token sequence.\n",
        "# Set the truncation parameter to True.\n",
        "\n",
        "long_example_input_str = 'The mT5 model was presented in mT5: A massively multilingual pre-trained text-to-text transformer by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.'\n",
        "length = len(long_example_input_str.split())\n",
        "print('length = ', length)\n",
        "\n",
        "token_ids = tokenizer.encode(\n",
        "    long_example_input_str, return_tensors='pt', padding='max_length',\n",
        "    truncation=True, max_length=max_seq_len).cuda()\n",
        "print('token_ids = ', token_ids)\n",
        "print('len(token_ids[0]) = ', len(token_ids[0]))\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(token_ids[0])\n",
        "print('tokens = ', tokens)\n",
        "print('len(tokens) = ', len(tokens))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yMZ5G7d7LoqO"
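      },
      "source": [
        "## Adding task-specific tokens\n",
        "To make the requested translation direction explicit to the model, we add dedicated tokens. Concretely, inputs are formatted as follows.\n",
        "- To request translation into Japanese: ``<jp> This is test.``\n",
        "- To request translation into English: ``<en> これはテストです。``"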
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eMprIAm-LoqO",
        "outputId": "916882dd-279e-4586-d211-30aa1adef2cb"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "token_ids =  tensor([[1042, 3889,  669, 1494,  339, 2978,  260,    1,    0,    0,    0,    0,\n",
            "            0,    0,    0,    0,    0,    0,    0,    0]], device='cuda:0')\n",
            "tokens =  ['▁<', 'jp', '>', '▁This', '▁is', '▁test', '.', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n"
          ]
        }
      ],
      "source": [
        "# Check the tokenization.\n",
        "# <jp> and <en> are not in the vocabulary and get split into several tokens, so they need to be added as new tokens.\n",
        "\n",
        "example_input_str = '<jp> This is test.'\n",
        "\n",
        "token_ids = tokenizer.encode(\n",
        "    example_input_str, return_tensors='pt', padding='max_length',\n",
        "    truncation=True, max_length=max_seq_len).cuda()\n",
        "print('token_ids = ', token_ids)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(token_ids[0])\n",
        "print('tokens = ', tokens)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bKu-ZfaXLoqP",
        "outputId": "605a7a8c-1bbc-46c8-f0f0-bcdbdbc41180"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "token_ids =  tensor([[  1042,    278,    669, 144591,  80822,   1252,    306,      1,      0,\n",
            "              0,      0,      0,      0,      0,      0,      0,      0,      0,\n",
            "              0,      0]], device='cuda:0')\n",
            "tokens =  ['▁<', 'en', '>', '▁これは', 'テスト', 'です', '。', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n"
          ]
        }
      ],
      "source": [
        "example_input_str = '<en> これはテストです。'\n",
        "\n",
        "token_ids = tokenizer.encode(\n",
        "    example_input_str, return_tensors='pt', padding='max_length',\n",
        "    truncation=True, max_length=max_seq_len).cuda()\n",
        "print('token_ids = ', token_ids)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(token_ids[0])\n",
        "print('tokens = ', tokens)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ebR54sZjLoqP",
        "outputId": "0b8f8943-869d-46c6-f70f-afe494ed1b45"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Vocabulary size before adding =  250100\n",
            "Vocabulary size after adding =  250102\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Embedding(250102, 512)"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ],
      "source": [
        "# Prepare the tokens to add as a dict (language code -> token).\n",
        "LANG_TOKEN_MAPPING = {\n",
        "    'ja': '<jp>',\n",
        "    'en': '<en>'\n",
        "}\n",
        "\n",
        "print('Vocabulary size before adding = ', len(tokenizer.vocab))\n",
        "\n",
        "# Add them with tokenizer.add_special_tokens.\n",
        "# The model must also be told that the vocabulary has grown.\n",
        "special_tokens_dict = {'additional_special_tokens': list(LANG_TOKEN_MAPPING.values())}\n",
        "tokenizer.add_special_tokens(special_tokens_dict) # Add the dedicated tokens.\n",
        "\n",
        "print('Vocabulary size after adding = ', len(tokenizer.vocab))\n",
        "\n",
        "model.resize_token_embeddings(len(tokenizer))     # Resize the model's embedding matrix to match."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Qg_IHvvLLoqP",
        "outputId": "e9bbf43b-6a30-4a98-eaf6-b9e04bdea24a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "token_ids =  tensor([[250100,   1494,    339,   2978,    260,      1,      0,      0,      0,\n",
            "              0,      0,      0,      0,      0,      0,      0,      0,      0,\n",
            "              0,      0]], device='cuda:0')\n",
            "tokens =  ['<jp>', '▁This', '▁is', '▁test', '.', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']\n"
          ]
        }
      ],
      "source": [
        "# Verify the behavior after adding the tokens.\n",
        "# Confirm that <jp> is now handled as a single token.\n",
        "\n",
        "example_input_str = '<jp> This is test.'\n",
        "\n",
        "token_ids = tokenizer.encode(\n",
        "    example_input_str, return_tensors='pt', padding='max_length',\n",
        "    truncation=True, max_length=max_seq_len).cuda()\n",
        "print('token_ids = ', token_ids)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(token_ids[0])\n",
        "print('tokens = ', tokens)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O5qI5AcqLoqQ"
      },
      "source": [
        "## Preparing the dataset\n",
        "Next, prepare a dataset for fine-tuning. Here we use [alt](https://huggingface.co/datasets/alt), which provides parallel text in 13 languages."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HVBJ-XSULoqQ"
      },
      "source": [
        "### Dataset overview\n",
        "Separate train and test splits are provided, and each example in them is stored as a dict.\n",
        "- ``train_dataset[0]['url']``: source URL\n",
        "- ``train_dataset[0]['translation']``: parallel text\n",
        "  - ``train_dataset[0]['translation']['ja']``: Japanese text\n",
        "  - ``train_dataset[0]['translation']['en']``: English text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102,
          "referenced_widgets": [
            "ae1d9c9ade2a4053877dd8888c69549e",
            "b79abb2b6da04ee5be620d30b4839855",
            "0b6f0f0607f644c8881f9e6618514a97",
            "a47dcc9316c6416a8562d309d3a1a00b",
            "c7c7c787cf614421b13e6f5b8878acf1",
            "f14bc6eea84c4d97ab266b30038557cb",
            "b58848c35ba147fcbcd769a8c9fe3f49",
            "31c3181f5d2c4eb3b4fb623596b31db1",
            "4604b40f5bb546c69f98194dacdff29e",
            "f50367f8355b4e289f46374598e58381",
            "6ca0fe6eee0643bd9798fb324633a2ad"
          ]
        },
        "id": "dtSQ-6zZLoqQ",
        "outputId": "af630c9b-3dff-488e-da81-4ada46da72cf"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "No config specified, defaulting to: alt/alt-parallel\n",
            "Reusing dataset alt (/root/.cache/huggingface/datasets/alt/alt-parallel/1.0.0/e784a3f2a9f6bdf277940de6cc9d700eab852896cd94aad4233caf26008da9ed)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/3 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "ae1d9c9ade2a4053877dd8888c69549e"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Source: https://huggingface.co/datasets/alt\n",
        "dataset = load_dataset('alt')\n",
        "\n",
        "train_dataset = dataset['train']\n",
        "test_dataset = dataset['test']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WdWZhOf5LoqQ",
        "outputId": "92214360-eff2-4219-9f70-3a07f0d20afb"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'SNT.URLID': '80188',\n",
              " 'SNT.URLID.SNTID': '1',\n",
              " 'translation': {'bg': 'ফ্রান্সের প্যারিসের পার্ক দি প্রিন্সেস-এ হওয়া ২০০৭-এর রাগবি বিশ্বকাপের পুল সি-তে ইটালি পর্তুগালকে ৩১-৫ গোলে হারিয়েছে।',\n",
              "  'en': 'Italy have defeated Portugal 31-5 in Pool C of the 2007 Rugby World Cup at Parc des Princes, Paris, France.',\n",
              "  'en_tok': 'Italy have defeated Portugal 31-5 in Pool C of the 2007 Rugby World Cup at Parc des Princes , Paris , France .',\n",
              "  'fil': 'Natalo ng Italya ang Portugal sa puntos na 31-5 sa Grupong C noong 2007 sa Pandaigdigang laro ng Ragbi sa Parc des Princes, Paris, France.',\n",
              "  'hi': '2007 में फ़्रांस, पेरिस के पार्क डेस प्रिंसेस में हुए रग्बी विश्व कप के पूल C में इटली ने पुर्तगाल को 31-5 से हराया।',\n",
              "  'id': 'Italia berhasil mengalahkan Portugal 31-5 di grup C dalam Piala Dunia Rugby 2007 di Parc des Princes, Paris, Perancis.',\n",
              "  'ja': 'フランスのパリ、パルク・デ・プランスで行われた2007年ラグビーワールドカップのプールCで、イタリアは31対5でポルトガルを下した。',\n",
              "  'khm': 'អ៊ីតាលីបានឈ្នះលើព័រទុយហ្គាល់ 31-5 ក្នុងប៉ូលCនៃពីធីប្រកួតពានរង្វាន់ពិភពលោកនៃកីឡាបាល់ឱបឆ្នាំ2007ដែលប្រព្រឹត្តនៅប៉ាសឌេសប្រីន ក្រុងប៉ារីស បារាំង។',\n",
              "  'lo': 'ອິຕາລີໄດ້ເສຍໃຫ້ປ໊ອກຕຸຍການ 31 ຕໍ່ 5 ໃນພູລ C ຂອງ ການແຂ່ງຂັນຣັກບີ້ລະດັບໂລກປີ 2007 ທີ່ ປາກເດແພຣັງ ປາຣີ ປະເທດຝຣັ່ງ.',\n",
              "  'ms': 'Itali telah mengalahkan Portugal 31-5 dalam Pool C pada Piala Dunia Ragbi 2007 di Parc des Princes, Paris, Perancis.',\n",
              "  'my': 'ပြင်သစ်နိုင်ငံ ပါရီမြို့ ပါ့ဒက်စ် ပရင့်စက် ၌ ၂၀၀၇ခုနှစ် ရပ်ဘီ ကမ္ဘာ့ ဖလား တွင် အီတလီ သည် ပေါ်တူဂီ ကို ၃၁-၅ ဂိုး ဖြင့် ရေကူးကန် စီ တွင် ရှုံးနိမ့်သွားပါသည် ။',\n",
              "  'th': 'อิตาลีได้เอาชนะโปรตุเกสด้วยคะแนน31ต่อ5 ในกลุ่มc ของการแข่งขันรักบี้เวิลด์คัพปี2007 ที่สนามปาร์กเดแพร็งส์ ที่กรุงปารีส ประเทศฝรั่งเศส',\n",
              "  'vi': 'Ý đã đánh bại Bồ Đào Nha với tỉ số 31-5 ở Bảng C Giải vô địch Rugby thế giới 2007 tại Parc des Princes, Pari, Pháp.',\n",
              "  'zh': '意大利在法国巴黎王子公园体育场举办的2007年橄榄球世界杯C组以31-5击败葡萄牙。'},\n",
              " 'url': 'http://en.wikinews.org/wiki/2007_Rugby_World_Cup:_Italy_31_-_5_Portugal'}"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ],
      "source": [
        "# Inspect the first record.\n",
        "\n",
        "train_dataset[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iT8OWxWcLoqQ",
        "outputId": "931adb92-bb52-4462-87ea-4abdbbfb65c2"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "train_dataset[0]['translation']['ja'] =  フランスのパリ、パルク・デ・プランスで行われた2007年ラグビーワールドカップのプールCで、イタリアは31対5でポルトガルを下した。\n",
            "train_dataset[0]['translation']['en'] =  Italy have defeated Portugal 31-5 in Pool C of the 2007 Rugby World Cup at Parc des Princes, Paris, France.\n"
          ]
        }
      ],
      "source": [
        "print(\"train_dataset[0]['translation']['ja'] = \", train_dataset[0]['translation']['ja'])\n",
        "print(\"train_dataset[0]['translation']['en'] = \", train_dataset[0]['translation']['en'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LNT_dK4PLoqR"
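      },
      "source": [
        "## Fine-tuning approach\n",
        "- Prepare the input and output data for mT5 as parallel sentence pairs.\n",
        "- The translation we actually want is only Japanese-English or English-Japanese, but fine-tuning trains on pairs of any languages listed in ``LANG_TOKEN_MAPPING``.\n",
        "- The original parallel sentences contain only the text, so when building the dataset we prepend the special token for the target language from ``LANG_TOKEN_MAPPING`` (for example ``<en>``; the recorded output of the preprocessing check shows ``<jp>`` for Japanese)."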
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1qP3zoPQLoqR"
      },
      "source": [
        "### Dataset preprocessing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OJNx4Mw7nlRr"
      },
      "outputs": [],
      "source": [
        "# Function that encodes the model input.\n",
        "# Prepends the special token for the target language (target_lang) to the text, then encodes it with the tokenizer.\n",
        "# Returns the sequence of token IDs.\n",
        "def encode_input_str(text, target_lang, tokenizer, seq_len,\n",
        "                     lang_token_map=LANG_TOKEN_MAPPING):\n",
        "  target_lang_token = lang_token_map[target_lang]\n",
        "\n",
        "  # Tokenize and add special tokens\n",
        "  input_ids = tokenizer.encode(\n",
        "      text = target_lang_token + text,\n",
        "      return_tensors = 'pt',\n",
        "      padding = 'max_length',\n",
        "      truncation = True,\n",
        "      max_length = seq_len).cuda()\n",
        "\n",
        "  return input_ids[0]\n",
        "\n",
        "# Encodes the target (translated) text. No special token is added here.\n",
        "def encode_target_str(text, tokenizer, seq_len,\n",
        "                      lang_token_map=LANG_TOKEN_MAPPING):\n",
        "  token_ids = tokenizer.encode(\n",
        "      text = text,\n",
        "      return_tensors = 'pt',\n",
        "      padding = 'max_length',\n",
        "      truncation = True,\n",
        "      max_length = seq_len).cuda()\n",
        "  \n",
        "  return token_ids[0]\n",
        "\n",
        "# Using the functions above, build the token ID sequences for\n",
        "# the source text (input_text) and its translation (target_text).\n",
        "# Each call picks two distinct languages at random from lang_token_map.\n",
        "def format_translation_data(translations, lang_token_map,\n",
        "                            tokenizer, seq_len=128):\n",
        "  # Choose two random languages for input/output\n",
        "  langs = list(lang_token_map.keys())\n",
        "  input_lang, target_lang = np.random.choice(langs, size=2, replace=False)\n",
        "\n",
        "  # Get the texts for the chosen language pair\n",
        "  input_text = translations[input_lang]\n",
        "  target_text = translations[target_lang]\n",
        "\n",
        "  if input_text is None or target_text is None:\n",
        "    return None\n",
        "\n",
        "  input_token_ids = encode_input_str(\n",
        "      input_text, target_lang, tokenizer, seq_len, lang_token_map)\n",
        "  \n",
        "  target_token_ids = encode_target_str(\n",
        "      target_text, tokenizer, seq_len, lang_token_map)\n",
        "\n",
        "  return input_token_ids, target_token_ids\n",
        "\n",
        "# Build a batch using format_translation_data.\n",
        "def transform_batch(batch, lang_token_map, tokenizer):\n",
        "  inputs = []\n",
        "  targets = []\n",
        "  for translation_set in batch['translation']:\n",
        "    formatted_data = format_translation_data(\n",
        "        translation_set, lang_token_map, tokenizer, max_seq_len)\n",
        "    \n",
        "    if formatted_data is None:\n",
        "      continue\n",
        "    \n",
        "    input_ids, target_ids = formatted_data\n",
        "    inputs.append(input_ids.unsqueeze(0))\n",
        "    targets.append(target_ids.unsqueeze(0))\n",
        "    \n",
        "  batch_input_ids = torch.cat(inputs).cuda()   # To run on CPU, remove .cuda().\n",
        "  batch_target_ids = torch.cat(targets).cuda()\n",
        "  #batch_input_ids = torch.cat(inputs)         # CPU example.\n",
        "  #batch_target_ids = torch.cat(targets)\n",
        "\n",
        "  return batch_input_ids, batch_target_ids\n",
        "\n",
        "# Yield batches lazily so that each one is built only when it is needed.\n",
        "def get_data_generator(dataset, lang_token_map, tokenizer, batch_size=32):\n",
        "  dataset = dataset.shuffle()\n",
        "  for i in range(0, len(dataset), batch_size):\n",
        "    raw_batch = dataset[i:i+batch_size]\n",
        "    yield transform_batch(raw_batch, lang_token_map, tokenizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o5RUscl9LoqS"
      },
      "source": [
        "### Checking the preprocessing\n",
        "The example uses ``train_dataset[0]``, but the language pair taken from it is chosen at random, so the result changes on every run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "M59-4uB8LoqS",
        "outputId": "26abe3e7-9bc7-47cc-ebcc-99227c3022e9"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/alt/alt-parallel/1.0.0/e784a3f2a9f6bdf277940de6cc9d700eab852896cd94aad4233caf26008da9ed/cache-09ad8735b58de120.arrow\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "<jp> ▁Italy ▁have ▁de feat ed ▁Portugal ▁3 1-5 ▁in ▁Pool ▁C ▁of ▁the ▁2007 ▁ Rugby ▁World ▁Cup ▁at ▁Parc ▁des ▁Princes , ▁Paris , ▁France . </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n",
            "▁ フランス の パリ 、 パル ク ・ デ ・ プラン ス で 行われた 2007 年 ラグビー ワールド カップ の プール C で 、 イタリア は 31 対 5 で ポル ト ガル を下 した 。 </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n",
            "Input shape: torch.Size([8, 20])\n",
            "Output shape: torch.Size([8, 20])\n"
          ]
        }
      ],
      "source": [
        "# Testing `format_translation_data`\n",
        "in_ids, out_ids = format_translation_data(\n",
        "    train_dataset[0]['translation'], LANG_TOKEN_MAPPING, tokenizer)\n",
        "\n",
        "print(' '.join(tokenizer.convert_ids_to_tokens(in_ids)))\n",
        "print(' '.join(tokenizer.convert_ids_to_tokens(out_ids)))\n",
        "\n",
        "# Testing data generator\n",
        "data_gen = get_data_generator(train_dataset, LANG_TOKEN_MAPPING, tokenizer, 8)\n",
        "data_batch = next(data_gen)\n",
        "print('Input shape:', data_batch[0].shape)\n",
        "print('Output shape:', data_batch[1].shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RfXTKdsqLoqS"
      },
      "source": [
        "## Fine-tuning steps"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wiEVqJKVLoqS"
      },
      "source": [
        "### Preparing parameters and the model evaluation function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-4uv5u_FnE2F",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "08cf31af-1026-4606-e7df-ce52e50d5c2c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
            "  FutureWarning,\n"
          ]
        }
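      ],
      "source": [
        "# Constants\n",
        "n_epochs = 5    # number of epochs\n",
        "batch_size = 16 # batch size\n",
        "print_freq = 50 # how often to print progress (every 50 batches)\n",
        "checkpoint_freq = 1000 # how often to save the model (every 1000 batches, overwriting the previous file)\n",
        "lr = 5e-4       # peak learning rate\n",
        "n_batches = int(np.ceil(len(train_dataset) / batch_size))\n",
        "total_steps = n_epochs * n_batches\n",
        "# Warm-up length: the learning rate rises linearly for the first 1% of steps, then decays linearly to 0.\n",
        "n_warmup_steps = int(total_steps * 0.01)\n",
        "\n",
        "# Optimizer\n",
        "# Note: transformers' AdamW is deprecated; the FutureWarning in this cell's output recommends torch.optim.AdamW.\n",
        "optimizer = AdamW(model.parameters(), lr=lr)\n",
        "scheduler = get_linear_schedule_with_warmup(\n",
        "    optimizer, n_warmup_steps, total_steps)\n",
        "\n",
        "# List that stores the training loss history\n",
        "losses = []"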
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rKU9rtbHWkeB"
      },
      "outputs": [],
      "source": [
        "# Function that evaluates the model.\n",
        "# Takes a dataset (e.g. test_dataset) and returns the mean loss over at most max_iters batches.\n",
        "def eval_model(model, gdataset, max_iters=8):\n",
        "  test_generator = get_data_generator(gdataset, LANG_TOKEN_MAPPING,\n",
        "                                      tokenizer, batch_size)\n",
        "  eval_losses = []\n",
        "  # Gradients are not needed for evaluation, so skip building the autograd graph.\n",
        "  with torch.no_grad():\n",
        "    for i, (input_batch, label_batch) in enumerate(test_generator):\n",
        "      if i >= max_iters:\n",
        "        break\n",
        "\n",
        "      model_out = model(\n",
        "          input_ids = input_batch,\n",
        "          labels = label_batch)\n",
        "      eval_losses.append(model_out.loss.item())\n",
        "\n",
        "  return np.mean(eval_losses)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zv1xRw6xLoqT"
      },
      "source": [
        "### Fine-tuning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "po1XJsYkLoqT",
        "outputId": "9d206e5f-757d-4f40-b735-d9dff2c60501"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Thu Jun 23 01:54:04 UTC 2022\n"
          ]
        }
      ],
      "source": [
        "!date"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "e04f0fca781f45839f95e8bc21200655",
            "abb535522d5e4f21b7f1696d52193380",
            "a6cc1b1009c44b5e84472f7179ee84ae",
            "ec2f238aa4b44ff89f72d804e8c404e8",
            "c744994c21c8456b879691b1d4526eb3",
            "eecb8a6c0b4744fa81ce16d9395fdcf5",
            "c824e28e74f24cb8a8930ce5cb3d79b3",
            "f68ed3a53f13496fa65e3f6aba0e23bb",
            "3f88d4bbc6464fe68bf0a0aacc499ac7",
            "50de5e38862a430eb5d3addd6689a9ee",
            "0aad4afa452d4950948a94b3240a45dc",
            "61c9dac7f0b74f98a6222af18eaf003f",
            "d8ab212b837b4240a7a935c7c5004bf6",
            "24ee6dd47f3e42b887ec3d60fde18bda",
            "d5349a95a2944f9b83b4d65564d55273",
            "a62202bfc3f0482a8525d544e4db03b8",
            "65fe942df2b4434fae8c9c89b4b3ce4f",
            "5b8473a2bacb45c1aa5afa1e489620a8",
            "4df69aceeada44c6b1e3d1a40c6823cd",
            "05b953a1f35d4764a5baf3f625553705",
            "5ad493c2a8fb44bc8eed372de74b3038",
            "122f08881993415ba8591c4cdbc42ac9",
            "c4719eef452140388f9bbd2304ed73fd",
            "c2af9c05612f47b6851e9a1767f4dfc5",
            "0256259cfde94fa181a50804dd4673de",
            "442bbeae8bc744b595204799d4a5771d",
            "069396bcffbc49c68f88d8cfbabb8230",
            "a28e828cca6f4fc085ff0d045edf38ba",
            "5dd270895f5743818dea17dcc458f821",
            "75e738e5ab9a482b82b82df92882056a",
            "8d014fe3a4384fa9b74cc97b3fa06b2d",
            "f3cc98a0132f4b0094a885f699a22a6e",
            "4999315197a44a85a2256c82f0bc8970",
            "5d625d6b14684c2a9c8446f4c23e7214",
            "45e094304f494da69f00552a6adb9b16",
            "bb7d6fa764824d4d8d7f7602883e1518",
            "9f5a8768c56d46ce8d6a03306af4b6be",
            "2f5ad541d85b40ba8c5778a257d79dc8",
            "1ed2fd81696c49c39f77f7e90a1a9692",
            "a8a795d2014f40c895d33e5da3692be5",
            "1b6696dc2af44f81a73633432859c5ed",
            "ecb2daf92f384873b66e333cf787c6b9",
            "03c9ec0b08a04dc0acf67868b68ebbd4",
            "088804ed423d4432bfe6b7ab42e75b97",
            "81d2ca5f9cdd4ca68c25949d56b26604",
            "27f9086e2ce84ad69eb9d42160a19fd0",
            "387d12b5c5b34ef6aa847e43f3a94a98",
            "a16db321a4de4762b9a5f22ed94d0b3f",
            "779a63761a084439b7579ca7d23feed3",
            "3badf9971e0c41978ad1ad087799de22",
            "55da7364afb94beea30bc6b99041946d",
            "798f01b56a114ffba7b9a8af705a171f",
            "8cecd9f904084c8abf77d97ae0efb44c",
            "6421527d79224bfd8b48e1048560cb0e",
            "d8500a8c47374c5682271300aedc7a48"
          ]
        },
        "id": "Kv8a0jwDnEzK",
        "outputId": "1963bd0d-b317-49be-e82c-902be4362278"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:6: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
            "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
            "  \n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/1131 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e04f0fca781f45839f95e8bc21200655"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1 | Step: 50 | Avg. loss: 12.114 | lr: 0.00044642857142857147\n",
            "Epoch: 1 | Step: 100 | Avg. loss: 5.509 | lr: 0.0004960707269155207\n",
            "Epoch: 1 | Step: 150 | Avg. loss: 4.604 | lr: 0.0004916056438649759\n",
            "Epoch: 1 | Step: 200 | Avg. loss: 4.273 | lr: 0.00048714056081443114\n",
            "Epoch: 1 | Step: 250 | Avg. loss: 3.990 | lr: 0.0004826754777638864\n",
            "Epoch: 1 | Step: 300 | Avg. loss: 3.813 | lr: 0.0004782103947133417\n",
            "Epoch: 1 | Step: 350 | Avg. loss: 3.695 | lr: 0.00047374531166279694\n",
            "Epoch: 1 | Step: 400 | Avg. loss: 3.614 | lr: 0.0004692802286122522\n",
            "Epoch: 1 | Step: 450 | Avg. loss: 3.558 | lr: 0.0004648151455617075\n",
            "Epoch: 1 | Step: 500 | Avg. loss: 3.488 | lr: 0.00046035006251116275\n",
            "Epoch: 1 | Step: 550 | Avg. loss: 3.353 | lr: 0.00045588497946061796\n",
            "Epoch: 1 | Step: 600 | Avg. loss: 3.444 | lr: 0.00045141989641007323\n",
            "Epoch: 1 | Step: 650 | Avg. loss: 3.339 | lr: 0.0004469548133595285\n",
            "Epoch: 1 | Step: 700 | Avg. loss: 3.293 | lr: 0.00044248973030898376\n",
            "Epoch: 1 | Step: 750 | Avg. loss: 3.236 | lr: 0.000438024647258439\n",
            "Epoch: 1 | Step: 800 | Avg. loss: 3.169 | lr: 0.0004335595642078943\n",
            "Epoch: 1 | Step: 850 | Avg. loss: 3.172 | lr: 0.00042909448115734957\n",
            "Epoch: 1 | Step: 900 | Avg. loss: 3.106 | lr: 0.0004246293981068048\n",
            "Epoch: 1 | Step: 950 | Avg. loss: 3.086 | lr: 0.00042016431505626005\n",
            "Epoch: 1 | Step: 1000 | Avg. loss: 3.030 | lr: 0.0004156992320057153\n",
            "Saving model with test loss of 3.238\n",
            "Epoch: 1 | Step: 1050 | Avg. loss: 2.993 | lr: 0.0004112341489551706\n",
            "Epoch: 1 | Step: 1100 | Avg. loss: 3.027 | lr: 0.0004067690659046258\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/1131 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "61c9dac7f0b74f98a6222af18eaf003f"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 2 | Step: 50 | Avg. loss: 2.793 | lr: 0.00039953563136274335\n",
            "Epoch: 2 | Step: 100 | Avg. loss: 2.830 | lr: 0.0003950705483121986\n",
            "Epoch: 2 | Step: 150 | Avg. loss: 2.841 | lr: 0.00039060546526165383\n",
            "Epoch: 2 | Step: 200 | Avg. loss: 2.805 | lr: 0.00038614038221110916\n",
            "Epoch: 2 | Step: 250 | Avg. loss: 2.833 | lr: 0.00038167529916056437\n",
            "Epoch: 2 | Step: 300 | Avg. loss: 2.753 | lr: 0.00037721021611001964\n",
            "Epoch: 2 | Step: 350 | Avg. loss: 2.841 | lr: 0.00037274513305947496\n",
            "Epoch: 2 | Step: 400 | Avg. loss: 2.804 | lr: 0.0003682800500089302\n",
            "Epoch: 2 | Step: 450 | Avg. loss: 2.737 | lr: 0.00036381496695838544\n",
            "Epoch: 2 | Step: 500 | Avg. loss: 2.722 | lr: 0.00035934988390784066\n",
            "Epoch: 2 | Step: 550 | Avg. loss: 2.758 | lr: 0.000354884800857296\n",
            "Epoch: 2 | Step: 600 | Avg. loss: 2.777 | lr: 0.0003504197178067512\n",
            "Epoch: 2 | Step: 650 | Avg. loss: 2.766 | lr: 0.00034595463475620646\n",
            "Epoch: 2 | Step: 700 | Avg. loss: 2.671 | lr: 0.00034148955170566173\n",
            "Epoch: 2 | Step: 750 | Avg. loss: 2.723 | lr: 0.000337024468655117\n",
            "Epoch: 2 | Step: 800 | Avg. loss: 2.715 | lr: 0.00033255938560457226\n",
            "Epoch: 2 | Step: 850 | Avg. loss: 2.779 | lr: 0.0003280943025540275\n",
            "Epoch: 2 | Step: 900 | Avg. loss: 2.690 | lr: 0.0003236292195034828\n",
            "Epoch: 2 | Step: 950 | Avg. loss: 2.731 | lr: 0.000319164136452938\n",
            "Epoch: 2 | Step: 1000 | Avg. loss: 2.715 | lr: 0.0003146990534023933\n",
            "Saving model with test loss of 2.839\n",
            "Epoch: 2 | Step: 1050 | Avg. loss: 2.688 | lr: 0.00031023397035184855\n",
            "Epoch: 2 | Step: 1100 | Avg. loss: 2.738 | lr: 0.0003057688873013038\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/1131 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "c4719eef452140388f9bbd2304ed73fd"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 3 | Step: 50 | Avg. loss: 2.425 | lr: 0.0002985354527594213\n",
            "Epoch: 3 | Step: 100 | Avg. loss: 2.475 | lr: 0.0002940703697088766\n",
            "Epoch: 3 | Step: 150 | Avg. loss: 2.477 | lr: 0.00028960528665833185\n",
            "Epoch: 3 | Step: 200 | Avg. loss: 2.508 | lr: 0.00028514020360778707\n",
            "Epoch: 3 | Step: 250 | Avg. loss: 2.466 | lr: 0.0002806751205572424\n",
            "Epoch: 3 | Step: 300 | Avg. loss: 2.468 | lr: 0.00027621003750669766\n",
            "Epoch: 3 | Step: 350 | Avg. loss: 2.477 | lr: 0.00027174495445615287\n",
            "Epoch: 3 | Step: 400 | Avg. loss: 2.448 | lr: 0.00026727987140560814\n",
            "Epoch: 3 | Step: 450 | Avg. loss: 2.454 | lr: 0.0002628147883550634\n",
            "Epoch: 3 | Step: 500 | Avg. loss: 2.456 | lr: 0.0002583497053045187\n",
            "Epoch: 3 | Step: 550 | Avg. loss: 2.405 | lr: 0.0002538846222539739\n",
            "Epoch: 3 | Step: 600 | Avg. loss: 2.508 | lr: 0.0002494195392034292\n",
            "Epoch: 3 | Step: 650 | Avg. loss: 2.424 | lr: 0.0002449544561528844\n",
            "Epoch: 3 | Step: 700 | Avg. loss: 2.450 | lr: 0.0002404893731023397\n",
            "Epoch: 3 | Step: 750 | Avg. loss: 2.431 | lr: 0.00023602429005179496\n",
            "Epoch: 3 | Step: 800 | Avg. loss: 2.452 | lr: 0.0002315592070012502\n",
            "Epoch: 3 | Step: 850 | Avg. loss: 2.389 | lr: 0.0002270941239507055\n",
            "Epoch: 3 | Step: 900 | Avg. loss: 2.424 | lr: 0.00022262904090016077\n",
            "Epoch: 3 | Step: 950 | Avg. loss: 2.425 | lr: 0.000218163957849616\n",
            "Epoch: 3 | Step: 1000 | Avg. loss: 2.444 | lr: 0.00021369887479907128\n",
            "Saving model with test loss of 2.908\n",
            "Epoch: 3 | Step: 1050 | Avg. loss: 2.395 | lr: 0.00020923379174852652\n",
            "Epoch: 3 | Step: 1100 | Avg. loss: 2.441 | lr: 0.00020476870869798178\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/1131 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "5d625d6b14684c2a9c8446f4c23e7214"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 4 | Step: 50 | Avg. loss: 2.222 | lr: 0.0001975352741560993\n",
            "Epoch: 4 | Step: 100 | Avg. loss: 2.216 | lr: 0.00019307019110555458\n",
            "Epoch: 4 | Step: 150 | Avg. loss: 2.266 | lr: 0.00018860510805500982\n",
            "Epoch: 4 | Step: 200 | Avg. loss: 2.276 | lr: 0.0001841400250044651\n",
            "Epoch: 4 | Step: 250 | Avg. loss: 2.199 | lr: 0.00017967494195392033\n",
            "Epoch: 4 | Step: 300 | Avg. loss: 2.262 | lr: 0.0001752098589033756\n",
            "Epoch: 4 | Step: 350 | Avg. loss: 2.207 | lr: 0.00017074477585283086\n",
            "Epoch: 4 | Step: 400 | Avg. loss: 2.209 | lr: 0.00016627969280228613\n",
            "Epoch: 4 | Step: 450 | Avg. loss: 2.229 | lr: 0.0001618146097517414\n",
            "Epoch: 4 | Step: 500 | Avg. loss: 2.220 | lr: 0.00015734952670119664\n",
            "Epoch: 4 | Step: 550 | Avg. loss: 2.210 | lr: 0.0001528844436506519\n",
            "Epoch: 4 | Step: 600 | Avg. loss: 2.267 | lr: 0.00014841936060010715\n",
            "Epoch: 4 | Step: 650 | Avg. loss: 2.240 | lr: 0.00014395427754956242\n",
            "Epoch: 4 | Step: 700 | Avg. loss: 2.178 | lr: 0.0001394891944990177\n",
            "Epoch: 4 | Step: 750 | Avg. loss: 2.253 | lr: 0.00013502411144847295\n",
            "Epoch: 4 | Step: 800 | Avg. loss: 2.292 | lr: 0.00013055902839792822\n",
            "Epoch: 4 | Step: 850 | Avg. loss: 2.198 | lr: 0.00012609394534738346\n",
            "Epoch: 4 | Step: 900 | Avg. loss: 2.272 | lr: 0.00012162886229683873\n",
            "Epoch: 4 | Step: 950 | Avg. loss: 2.218 | lr: 0.00011716377924629399\n",
            "Epoch: 4 | Step: 1000 | Avg. loss: 2.179 | lr: 0.00011269869619574924\n",
            "Saving model with test loss of 2.972\n",
            "Epoch: 4 | Step: 1050 | Avg. loss: 2.225 | lr: 0.00010823361314520451\n",
            "Epoch: 4 | Step: 1100 | Avg. loss: 2.181 | lr: 0.00010376853009465976\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "  0%|          | 0/1131 [00:00<?, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "81d2ca5f9cdd4ca68c25949d56b26604"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 5 | Step: 50 | Avg. loss: 2.059 | lr: 9.653509555277729e-05\n",
            "Epoch: 5 | Step: 100 | Avg. loss: 2.038 | lr: 9.207001250223254e-05\n",
            "Epoch: 5 | Step: 150 | Avg. loss: 2.040 | lr: 8.76049294516878e-05\n",
            "Epoch: 5 | Step: 200 | Avg. loss: 2.016 | lr: 8.313984640114307e-05\n",
            "Epoch: 5 | Step: 250 | Avg. loss: 1.998 | lr: 7.867476335059832e-05\n",
            "Epoch: 5 | Step: 300 | Avg. loss: 2.020 | lr: 7.420968030005358e-05\n",
            "Epoch: 5 | Step: 350 | Avg. loss: 2.082 | lr: 6.974459724950884e-05\n",
            "Epoch: 5 | Step: 400 | Avg. loss: 1.987 | lr: 6.527951419896411e-05\n",
            "Epoch: 5 | Step: 450 | Avg. loss: 2.088 | lr: 6.0814431148419366e-05\n",
            "Epoch: 5 | Step: 500 | Avg. loss: 2.027 | lr: 5.634934809787462e-05\n",
            "Epoch: 5 | Step: 550 | Avg. loss: 2.057 | lr: 5.188426504732988e-05\n",
            "Epoch: 5 | Step: 600 | Avg. loss: 1.990 | lr: 4.7419181996785136e-05\n",
            "Epoch: 5 | Step: 650 | Avg. loss: 2.036 | lr: 4.2954098946240404e-05\n",
            "Epoch: 5 | Step: 700 | Avg. loss: 1.997 | lr: 3.848901589569566e-05\n",
            "Epoch: 5 | Step: 750 | Avg. loss: 2.068 | lr: 3.402393284515092e-05\n",
            "Epoch: 5 | Step: 800 | Avg. loss: 2.006 | lr: 2.955884979460618e-05\n",
            "Epoch: 5 | Step: 850 | Avg. loss: 2.048 | lr: 2.5093766744061443e-05\n",
            "Epoch: 5 | Step: 900 | Avg. loss: 2.070 | lr: 2.06286836935167e-05\n",
            "Epoch: 5 | Step: 950 | Avg. loss: 1.994 | lr: 1.6163600642971962e-05\n",
            "Epoch: 5 | Step: 1000 | Avg. loss: 2.039 | lr: 1.1698517592427218e-05\n",
            "Saving model with test loss of 2.833\n",
            "Epoch: 5 | Step: 1050 | Avg. loss: 2.031 | lr: 7.2334345418824795e-06\n",
            "Epoch: 5 | Step: 1100 | Avg. loss: 2.074 | lr: 2.7683514913377387e-06\n"
          ]
        }
      ],
      "source": [
        "for epoch_idx in range(n_epochs):\n",
        "  # Randomize data order\n",
        "  data_generator = get_data_generator(train_dataset, LANG_TOKEN_MAPPING,\n",
        "                                      tokenizer, batch_size)\n",
        "                \n",
        "  for batch_idx, (input_batch, label_batch) \\\n",
        "      in tqdm_notebook(enumerate(data_generator), total=n_batches):\n",
        "    optimizer.zero_grad()\n",
        "\n",
        "    # Forward pass\n",
        "    model_out = model(\n",
        "        input_ids = input_batch,\n",
        "        labels = label_batch)\n",
        "\n",
        "    # Calculate loss and update weights\n",
        "    loss = model_out.loss\n",
        "    losses.append(loss.item())\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    scheduler.step()\n",
        "\n",
        "    # Print training update info\n",
        "    if (batch_idx + 1) % print_freq == 0:\n",
        "      avg_loss = np.mean(losses[-print_freq:])\n",
        "      print('Epoch: {} | Step: {} | Avg. loss: {:.3f} | lr: {}'.format(\n",
        "          epoch_idx+1, batch_idx+1, avg_loss, scheduler.get_last_lr()[0]))\n",
        "      \n",
        "    if (batch_idx + 1) % checkpoint_freq == 0:\n",
        "      test_loss = eval_model(model, test_dataset)\n",
        "      print('Saving model with test loss of {:.3f}'.format(test_loss))\n",
        "      torch.save(model.state_dict(), model_path)\n",
        "\n",
        "torch.save(model.state_dict(), model_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BKRuoLx5LoqU",
        "outputId": "0d4c3d1b-a3be-4ef9-a223-ea086d5ef9e3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Thu Jun 23 02:18:53 UTC 2022\n"
          ]
        }
      ],
      "source": [
        "!date"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FuPOXVHlLoqU"
      },
      "source": [
        "## Loss during training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 283
        },
        "id": "tgaCaATQNr9b",
        "outputId": "92f2caff-99f0-4810-febb-1145697627a0"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[<matplotlib.lines.Line2D at 0x7f83f5ae2f50>]"
            ]
          },
          "metadata": {},
          "execution_count": 23
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD7CAYAAABgzo9kAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd2BT5foH8G/SvdPd0gJllVUQZMlQoUxFWrz+FGXcqxSKFkVUFBQVwUURURRkg9crFxfLK8gFbgHLkCGzjAKFFrr33sn5/ZHmtCFpm5a2SU6/n39MznlP8rykPjl5z3ueVyYIggAiIpIcubEDICKi5sEET0QkUUzwREQSxQRPRCRRTPBERBLFBE9EJFFM8EREEmVp7ABqyskpgkrV8Gn57u6OyMoqbIaITIOU+8e+mScp9w0wn/7J5TK4ujrUut+kErxKJTQqwWuOlTIp9499M09S7hsgjf5xiIaISKKY4ImIJIoJnohIopjgiYgkigmeiEiimOCJiCTK7BP8xbgszPn8ECqVKmOHQkRkUsw+wZ+7kYHbyfk4ez3D2KEQEZkUs0/wuQVlAICrCTlGjoSIyLSYfYJ/8pGOAIAu/i5GjoSIyLSYfYJ3cbQBAJSWK40cCRGRaWlQgl+1ahW6du2K69ev6+xbsGABHnnkEYSGhiI0NBRr1qxpsiDrYmttAYAJnojoXgYXG7t8+TLOnz8PPz+/WtuEh4dj6tSpTRKYoawt5ZDLZSgpq2zR9yUiMnUGncGXl5djyZIl+OCDD5o5nIaTyWSws7FEaRnP4ImIajIowa9cuRIhISHw9/evs92WLVswYcIEREREIC4urkkCNERRSQX+dzaxxd6PiMgc1DtEc+7cOcTExGDevHl1tnvttdfg6ekJuVyOXbt2YcaMGTh48CAsLCwMDsbd3dHgtvp4ejrd1/GmjH0zT+yb+ZJC/2SCINRZ1X79+vX47rvvYG1tDQBITU2Fu7s7Pv30UwwbNqzW4wYNGoQdO3bUOWZ/r6yswkYV2d+yLxbX7+Tg0/CHGnysOfD0dEJGRoGxw2gW7Jt5knLfAPPpn1wuq/PEuN4z+PDwcISHh4vPg4ODsXbtWgQGBmq1S0tLg7e3NwAgOjoacrlcfN7cnB2sUVhc3iLvRURkLu5ryb7Q0FCsX78e3t7emD9/PrKysiCTyeDo6Ig1a9bA0rJlVgS0tbZAWQVr0RAR1dTgDBwVFSU+3r17t/j422+/bZKAGiP6QjIqlSoUllTA0c7KaHEQEZkSs7+TFQDaeKhXFc/OLzVyJEREpkMSCf6pEZ0B8G5WIqKaJJHgnR3U9WgKSyqMHAkRkemQRIJ3sldP4byZmGfkSIiITIckEry7iy0AQFX3lH4iolZFEgleLpcBAC7EZRk5EiIi0yGJBK+Rll1s7BCIiEyGZBL8qP7qQmj1VF4gImo1JJPgXRzUF1rzi1iygIgIkFCCT0hVFwb6IeqmkSMhIjINkknwTz3aCQCQkllk5EiIiEyDZBK8l6sdAOBOeqGRIyEiMg2SSfAymczYIRARmRTJJHgAGPGgenER3vBERCSxBH/obBIA4My1dCNHQkRkfJJK8FNGq1eZustxeCIiaSX4Ad28AADX7uQYORIiIuOTVIJ3rrrZiUPwREQSS/Aat5LzjR0CEZHRSS7B9whwhZM912UlIpJcgvfzcERBcQVUKo7TEFHrJrkEn56jLhn8/f5YI0dCRGRckkvwQR3dAQA3kzgOT0Stm+QS/Mh+/vBytYO3m52xQyEiMirJJXgA8HCxRW5BmbHDICIyKkkmeCd7axQUVxg7DCIio5JmgrezQn4xV3YiotZNmgne3gql5UpUVKqMHQoRkdE0KMGvWrUKXbt2xfXr13X2lZSUYO7cuRg9ejTGjRuHQ4cONVmQDeVUVbKggGfxRNSKWRra8PLlyzh//jz8/Pz07t+0aRMcHR
1x4MABxMfHY8qUKdi/fz8cHByaLFhDOdlpEnwF3JxtW/z9iYhMgUFn8OXl5ViyZAk++OCDWtv8/vvvmDRpEgAgICAAQUFB+OOPP5okyIbSlCooKOEZPBG1XgYl+JUrVyIkJAT+/v61tklOTtY6u/f19UVqaur9R9gImgQfcyvbKO9PRGQK6h2iOXfuHGJiYjBv3rxmD8bd3bHRx3p6OomPnRX2AIC84gqt7eZMKv3Qh30zT1LuGyCN/tWb4E+fPo24uDiMHDkSAJCamoqwsDB8+umnGDZsmNiuTZs2SEpKgpubGwAgJSUFgwYNalAwWVmFjSoS5unphIyMAp3tJy+n6t1ubmrrnxSwb+ZJyn0DzKd/crmszhPjeodowsPDcfToUURFRSEqKgo+Pj7YtGmTVnIHgHHjxuHHH38EAMTHx+PSpUt4+OGH7zN8IiJqrPuaBx8aGoq0tDQAQFhYGPLz8zF69GjMmjULS5YsgaNj44dc7teYAW0BANn5pUaLgYjImAyeJqkRFRUlPt69e7f42N7eHl999VXTRNUENAtvf739Eha9MMDI0RARtTxJ3skKAG3c1fPvU7KLjBwJEZFxSDbBPxPcGQDg72m8YSIiImOSbIK3spRjSJAPcgtZNpiIWifJJngAcHWyQV5hOVQC12clotZH0gle4WgDpUpAQRFLFhBR6yPpBO/mZAMASM0uNnIkREQtT9IJ3sVRneAj/33OyJEQEbU8SSd4TwVLBRNR6yXpBO9kby0+TsnifHgial0kneABYFZITwDAL4fjjBwJEVHLknyC79/NEwBw7kYmBE6XJKJWRPIJ3kJe3cW0nBIjRkJE1LIkn+ABYOaEHgCAX4/eNnIkREQtp1UkeE3hsT+vpPGuViJqNVpFgm/rXV1w7M1vjhsxEiKiltMqErxcJsOn4Q8BAHIKWHyMiFqHVpHgAcDbzV58rFSpjBgJEVHLaDUJvqa9f94xdghERM2uVSX4OU/1BgDcTs43ciRERM2vVSX4Pl08AAAymZEDISJqAa0qwWucu5Fp7BCIiJpdq0zwAHAjMdfYIRARNatWl+DtbCwBAJ9+f5YVJolI0lpdgv/q1WHi44UbTuJaQo4RoyEiaj6tLsFbyOX46tWHxefLtp1jlUkikqRWl+ABwNHOCg/18Bafx97heDwRSU+rTPAAEB7SE84O6hWflm3jmq1EJD2tNsEDwPKIIeLj6UujcO56hhGjISJqWq06wVtaaHf/6x2XUFahNFI0RERNy6AEHxERgZCQEEycOBGTJ0/G1atXddp8/fXXGDx4MEJDQxEaGorFixc3ebDN4fnHumk9f+nzI0aKhIioaVka0igyMhJOTk4AgIMHD+Kdd97Bzp07ddpNnDgR8+fPb9oIm9kjD7TBkCAfhH92WNxWWl4JW2uD/mmIiEyWQWfwmuQOAIWFhZBJrJiLpYUcU0YHis+TMnkDFBGZP4NPUxcuXIhjx45BEARs3LhRb5s9e/bg6NGj8PT0xCuvvIK+ffs2WaDNbWQ/f7T1csTSrWdx+mo6OrVxMXZIRET3RSY08C6fXbt2Yc+ePdiwYYPW9oyMDCgUClhZWeHYsWOYN28e9u7dC1dX1yYNuDnlFZZh6qJ9AIAdkRNgZdmqr0ETkZlr8EDzxIkT8f777yMnJ0creXt6eoqPhw4dCl9fX9y4cQMDBw40+LWzsgqhUjX8rlJPTydkZBQ0+Li6JCTmwKVqnryxNUf/TAX7Zp6k3DfAfPonl8vg7u5Y+/76XqCoqAgpKSni86ioKLi4uEChUGi1S0tLEx9fvXoVSUlJ6NChQ2NiNio/TwcAEKdLCoKAguJyxCXnGTMsIqIGq/cMvqSkBK+++ipKSkogl8vh4uKCtWvXQiaTYebMmZgzZw569eqFFStW4PLly5DL5bCyssKyZcu0zurNxZj+bbHl92s4dSUNTwwJwD/3xeKPC8kAgK/nPgwHWysjR0hEZJh6E7yHhwd++uknvftqjs
NHRkY2XVRG1LWd+pfJr8duY/zg9mJyB4CTV9IQ/KC/sUIjImoQTva+h5erPaws5egZ4IakDO3pkokZnD5JROaD00T0COrghrScYly9o64V38HXGQBw+FwSV4IiIrPBBK+Hh4sdsvPLcPhcEgBgSJCPuO/T788aKywiogZhgtdD4WiNsgolUrKKAQDBD/ph5hM9xP2VSpWxQiMiMhgTvB5FpZVaz2UyGQZrncX/1aj5+kRELYkJXo8JQwL0bu/T2QMAcDulAL+fTGjBiIiIGo4JXg8bawuM7KeeDvnEkPbi9pf/1kt8vP3ILWw/EtfisRERGYrTJGsxZXQgRvf3h4fCTtwml2tX0dxzIgFPPdqppUMjIjIIz+Dr4OVqD/k9pZE3zR+ByaO6iM95wZWITBUTfAPJZDIM7+snPv8p6qYRoyEiqh0TfCNYWsixau7DAICDfyUaORoiIv2Y4BvJvkbRsQaW1CciahFM8PdhZFXhsaOXUnAzkeWEici0MMHfhzvp6gUBtuy9hk++/wvpuSVGjoiIqBoT/H148zntNWcv3Mw0UiRERLqY4O+DpYUcYeO7i8+v38nleDwRmQwm+Ps0tJevWE74r+sZWLT5NJIyCqFUcX48ERkXE3wTWPj3fuLjxIxCvLfpFGYuO4zE9EJsPXAd+cXlRoyOiForJvgmIJfJ0CPAVWf7+5tP4X9/JeK1r48aISoiau2Y4JvIvGf74o1n++jdx2F5IjIGFhtrQj0D3LDhreEoKVMiPacEH313RtyXmVcCDxe7Oo4mImpaPINvYhZyORztrNCxjTOWTB8obn9rzQkUllQYMTIiam2Y4JuRv5cjFj0/QHx+7kYGACCvsIwXXomo2THBN7P2Pk4In6Bez3XfyTsAgNdWHcPcr3jhlYiaFxN8C3iop3o915SsYkxfGiVuV/HqKxE1IyZ4I/rix/NcvJuImg0TvBFdjs/B2t0xxg6DiCSKCb6FaEoL21hbYHDVkA0A3EximWEiah4GzYOPiIhAYmIi5HI57O3t8d5776F79+5abZRKJT766CNER0dDJpMhPDwcTz/9dLMEbY4mjeyMkf394eNmDwCYNjYQESv+QG4hZ9MQUfMwKMFHRkbCyckJAHDw4EG888472Llzp1ab//znP7hz5w7279+P3NxcTJw4EYMHD4a/v3/TR22GLC3kYnIHAFtrS3i72SMtuxjFpRVaK0R9/N0ZxCXn49E+bfCPcd2MES4RSYBBQzSa5A4AhYWFkMlkOm327t2Lp59+GnK5HG5ubhg1ahT27dvXdJFKkKa08Bc/X0BadjFSsoqQklWEuOR8AMCR88nGDI+IzJzBpQoWLlyIY8eOQRAEbNy4UWd/SkoK2rRpIz739fVFampq00QpUbOf7IVFm08hLikfb6//U2+bCzczMcrTSe8+IqK6GJzgP/74YwDArl27sGzZMmzYsKHJg3F3d2z0sZ5mmAQNifmf+2IxanAHs+yfodg38yTlvgHS6F+Di41NnDgR77//PnJycuDqWl0i19fXF8nJyejduzcA3TN6Q2RlFTZqXrinpxMyMgoafJwpWPT8ACz+9rTO9rlPP4Avf74AQP3vYa79q485f3b1Yd/Ml7n0Ty6X1XliXG+CLyoqQn5+Pnx9fQEAUVFRcHFxgUKh0Go3btw4/PzzzxgzZgxyc3Nx8OBBbN269T7Dl742Hg4AgOF9/eDqaI1R/dtCEAB7W/VHw1k2RNRY9Sb4kpISvPrqqygpKYFcLoeLiwvWrl0LmUyGmTNnYs6cOejVqxdCQ0Nx4cIFjBkzBgAwe/ZstG3bttk7YO6sLOVY/+ZwWMhlei9eA0BpeWULR0VEUlBvgvfw8MBPP/2kd1/NcXgLCwssXry46SJrRSwt9E9m6tpWgdi7ucjOK4WV3hZERLXjnawmbMxA9S+gX6NvGTkSIjJHTPAmzKrqzH7PsdtGjoSIzBETvAnr0cENAODn6WDkSIjIHDHBmzC5TIZHHv
BFUQkvshJRwzHBmzhvN3vkFpahuFQ3yQuCgOlLo/Cv/8aK205dTUP0Rd0SB3fTC5s1TiIyPUzwJk7hYAMA4hquF+MyEXsnBwCw/IfzAIBD55LE9mt3X8aWvde0XuP8zUws2nwKh2u0IyLpa/CdrNSynBzUEyTzi8rh42aPL3++CAAY0M0LVxNyxHaJGYWwtbIQnytVKljI1d/fGTklAIDv/huL4X39Wip0IjIyJngT52xvDQAoKC5HXmGZuP30tXStdu9vOqX1POZWNjwVdmjj4YCi0gpx+/kbmXBxtIa7sy2cHaybMXIiMjYmeBPnaKc+g1+9MwbPjOiss39Yb18cvZiis33lL+oz/Q1vDcelW1ni9q+2q7dbyGXY8NaIOt+7vEIJpUqAnQ3/TIjMEf/PNXEKJxvx8U+HbmrtW/HyUCgcbfQmeI2Zyw7r3a5UCVAJAuS1lEcAgPnrTiCvsByb5o+otYwCEZkuXmQ1cXUlYIWjTa377qXvVVZtv1Rr+7yicuRVFTo7eTXN4PfRiE/NR6VS1eDjVIIgLoRCRPeHZ/BmYNtHj+O5d/cCAF79v97wcbPXGj9f8fJQnL2eAQCIvpiCXh3d8NvxBK3XWPXaI7CxskBxWSXmrIwGoJ5dU1NSRiGUKgHtvJ3w+Q/nxO3rf70CB1sr3E7Ox/gh7cWLt7XJKSjDkm/PYHBPb8yc0LNBfZ0ReQj9Aj0x+2+9GnQcEeligjcDjnZW2LwguNb9CkcbBD+oXvtW89++XTzx4T/PAFDXnNeMozvaWeGrVx8WkzwA/Hz4JgL9FeK4/ZTRgUjMKNJ6jy9+ugAAsLG2wNiB7eqMNz5VveTgictpeLSPHwLbKupsr6GqOnP/q+rLiojuD4doJMqlxhl+O2/tBQEc7awwfnB7AMDKny/g9z/viMkdALYeuA4ACGyrwKN9tBdtScnSTvyVShWuxGdjy96rmL40Cv/7KxFf1xj6Wbr1rPj47PUMTF8ahdTsYr0x5xZUzxJqzMIvRKSNZ/AS5WBbXWBY3wVSK0v1d/uFuCydfRpvPdcXcrkMlZUqHItRr6+bV2MBkm92xeDMPdM1NV8ONR06l4SOvs5YtUOd+N9Z/yeeerQjxg8OENtcic8Wb9wCgJtJeQaf+RORfjyDlyhrK/VHO6yXr979A7p51Xn8iL5+kMvVXwxhT/TA5gXBsLGywIW4LJSVKwFAJ7nX1L+rp/j4X/+N1VmWcPuRWygpU5dfUAmCVnIHtM/8iahxeAYvUTKZrM5xe193B0RMDMI3u2IAAM+N6oI7aQUIG98Dmbkl8FDY6RxTVqFO7B9sOYUXQ4NqfW0vVztEPNkLJ2JSseG3K7W2m/3FH2jr7Yi7aayTQ9QceAbfij3Q2UN8PLp/W4SN7wEAepM7AEwe1QUAkJZTIp6Rd2unHkbp7O8itnv37/0BAIODfLSO/7/hnbD0xcFa2+5N7oN7+uCpRzsCgHiGT0SNwwTfillZyrHhreHYNL/uO1o1gvv562wb2a8tPp45CPMm9YG7s3pevubuWwAIHdYBABAyNACPP9QeXgo7vFLLFMgFUx7E38d2hbuLLQAgu8ZFVyJqOA7RtHL1zWmvSS6T4cFAT3HOPQA80NldXFP2s4ihOsdMGBqAxx9qByvL6kJofQM9tdr07OCGGeO7w6Xqxi03J3WCz8kvhZ8HFzshaiwmeGqQDr5OYoJf9PyAWhcM15DLZJDXSO4ai54fgItxmRg9uANs73kJj6oz+PTckqYJmqiV4hANNUjNm5za+zg1+nXa+zhhwtAOaOut+xqa+jvf79edcklEhuMZPDWIpYUcyyOGwNKy+c4NNPV3WN6M6P4wwVODuTnbNvt79O7krnVTFRE1HIdoyCQ52FppLVRCRA3HBE8mycHOEoUlTPBE94MJnkySo50VSsuVjaopT0RqTPBkkjQ3SxWV8m5WosZigieT5FS12Hgu72YlarR6Z9Hk5OTgrbfewp07d2BtbY327d
tjyZIlcHNz02q3YMECHD9+HK6urgCAcePG4aWXXmqeqEnyNDc75RSWoT1qn29fUalCQmqBVi0cIlKrN8HLZDLMmDEDgwYNAgBERkZi+fLl+OSTT3TahoeHY+rUqU0fJbU6ttbqu19Layk4lpBagMXfnkbXtgrE3s3FshcH11okraxcCQsLWb133dZUVFqByK1nMef/esPDRf/rEpm6ev/iFQqFmNwBoE+fPkhOTm7WoIhsrNQJfmf0LQDAnhPx+Cu2ugaOpppl7N1cAOoKl7V5acURhH92uEHv/8qX0UjMKMJba0406DgiU9KgMXiVSoVt27YhOFh/nfEtW7ZgwoQJiIiIQFxcXJMESK2Ti6N6DN7K0gIVlSpsP3ILq3dWLwXoYKv94/NueiHKypXYsvcqSsurz/rzihp+sxRn7pBUyARBMHjxy8WLFyMtLQ2rVq2C/J4qhGlpafD09IRcLseuXbuwcuVKHDx4EBYWuoWmiAwx4Y3dOtvCQoJw+koqLt7MrPW4xwYHIOL/HkBOfin+vvi/4vapj3XDpFFdAQAXb2agrZcTyitVcHG0hq119RfGwVMJWPmjeoUpBzsr/PDR4wCAwpIKyGVApVKAnY2luOwhkakyOMFHRkYiNjYWa9euhbW1db3tBw0ahB07dsDPz8/gYLKyChu12LKnpxMyMgoafJy5kHL/6urbG6uPIaeOWTSB/i6YFRqEN1YfM/j9po4J1FvErObqVzMiD0ElCOjWToFrd3Ix5/96o09nD0xfGiW28XCxxbKXhtT5XjX7JggCwiIPAQA+njkIPm72WmvlpuUUw9rSAlcTsvFQTx+xHo+pkvLfJGA+/ZPLZXB3d6x9vyEvsmLFCsTExGD16tW1Jve0tDTxcXR0NORyOby9vRsYLlG1Jx9Wr+xkZSlHZz/dWTKTRwfC1ckGGw1csASovULlip+q14RVVZ3z9Oygnin21S8Xddpn5pUCAC7dysKcldGYvjQKN5Pyan3f+NTqZLFww0kcPl99HSunoAxvr/sTb6w+ho2/XcWhs0kG94eoLvUm+Bs3bmDdunVIT0/Hs88+i9DQUMyePRsAEBoaKib2+fPnY8KECQgJCcGaNWuwZs0aWFqylhk1nr+XerGPikoVrCzlmFi1OhQAjB/cHu2qSg3LZTKsf3M4endyx/C+ur8YN741Ah+8MKDO94q5lY3pS6NwOT5b3Pb4Q+3FxzXP3mtu++KnC2JJhU/+9Zf2a8ZlYvrSKCz791l8+M8zWvvO3ai+YHw7JV9rn2btW6L7VW8G7tKlC2JjY/Xu2727eoz022+/bbKgiACgjXv1ak5uTjaYMDQAV+KzcT0xD8P7aCdySws55j79AABg6uhALFh3Apl5pfgwbCDkchnaeTshqIMbYm6rE/ibz/XFmdh0nbPl/51JBACEje8OmUyGIUE+OB6TanDMgiBAJQiYueywuO3anVyddjG3qr9IVu24pLWvoJhVNKlp8CoRmSxrq+oL9H8f1xUymQwLpvbD5gXB4rqt+sjlMiyePhAfvDAAfp7V45Ntqpb/69ZOge7tXTFtTFd89erDmBTcWWxzvuribd8u6mUFn3q0k9Zr9++qvdygZtuU0YEA1MMtH357RqeNxqq5j4iPVVVfBhpjB7aFl6tdndcdiBqCYyhk0lbOGQY7G8sG3aQEAHY2luIQjsbjg9sjt7AM08Z2Fbc52llh7MB28HCxxeqdMeJ2+6ppmK5ONoiYGIRvdsVg2phAjHjQHxWVSsxafkRsO6SXL+xt1O3nfXO81pieGBIAe1tLjBnQFvtP38Uvh+LEL6opowMxsp8/ElILuNg4NRkmeDJpmpo0TcHZ3hovhgbp3devqxd6BLjiSnwOurVTaO3r381La5aNlaUFFk7rh4pKFbYeuI5AfwX0TXr5z+eh4kwMpUolLnDuX/WrYt+pO2Jbza8LVycbXL+rO6RD1BhM8ERVXp/UB2eupWNg9/pnf3WqmtXz4Yzqu7
wHdvfCqavpANQXdmuyqHHfSO9O7jqv177q14bCyQa5heVQCYLJT5Uk08cET1RFLpMZlNxr82JoEF4Mrb+ds4M1vF3ttMoraIaEHGytoFQJyMwrhVcttXWIDMUET2QEn84aDAA4fD4Jro424nZvV3VSv5tWyARP940JnsiI7p3u2cHXGQCnSlLT4DRJIhOiKbKWnFlk5EhICpjgiUyI5mLswb8S62x39GIKzlxLb4mQ6lRYUoGKSlbfNFVM8EQmqrY6gIIgYPPeq/hmV0ytbVpCek4x5qyMxqzlh5HPISWTxARPZGKeHdkFAJBfXKF3f82FyK8k5GjtuxiXiT+vGF5aoTblFUp8+v1fOH9DtyyzUqXC9qgbWFmjCNsbq6orembllWL60ijMX3vcqF9AxIusRCZHUTUOv/vobRy9mIzXn+mDbu1dUalUYceRW1o3SH3+w3msef1R2FQtcfjlz+qkm55TgglDArRKEhvqk3/9JVbGvJF4Ee/+vT86tnEW90dfTMF3+7TrUylVArLySmFnY4k316jv5s3ILcWJy6kYEuTb4BioaTDBE5mYTm3UN1EdPqcuhLZs27k62x+9lIKR/fy1zpZ3Rd9GWnYxZk7oqdNeJQj494HrCH7QH7F3c+HjaofuAW7i/nvLHn/03Rk8N6oLRvdvCwA6yV1Dk9hrik8t0EnwJWWVyC8uh7erfZ39ovvHIRoiE+PuYquzJKE+D/VU35SlKZCWmKE98+bE5TSdY77fH4udf9xC1NkkvLvxJP7131h89kN1Lfy76YV632vbwRu1xjGwu1et+w6eSURZeXX5Y0EQ8MXPF/D2uj+RlFFYo91dXIzLqvV1qHGY4IlMUM1x9ns99lA7rH7tEcx8ogcA4HJVCeRFm0/ptC2rUEIQBFy+ra53H3U2CXtOJOi0KyguR15hmd7XqOneuviPPOCLF0ODMG5gO63tYwa0FR+v2a0u4pZbWIawyEO4maj+hfBD1E0AwInLqfj3wRv48ucLdb43NRyHaIhMWFBHN7F2vI2VBcoqlBg7sB3sbLT/102scea97KXBuJteiK+3X8JLnx+BIV796qjW880LgiEI6nH1X47E4dTVdJ0yxl4KOzw3Ul0muZNf9Rj9vGf7oEeAG/afvp1Ai3QAAA3BSURBVAsA4pl57D118XMLypCUWYQN/7kibispq9TpGzUe/yWJTNB7/+iPoxdTMGVMIG4l5eNuegGG9/VDfnEFnGtU2Gzn5Yg76YV4v+rM+5kRneHhYqdV3Kw2PTu4iWf/NbXzVle7lMlk8FDYiQuW1Fz7tp2PEz54vnqVLNsaSdnB1goAsG7eo2JZ5V+P3kbePVMpkzKL8N7Gk1rb0nNK0N5Hu8wzNR4TPJEJ6uDrLJYt6Ozvgs7+6guvLg7a5ZNf/lsvvLX2hPg8qKP6YqlmJk5N3q52CH7QHz07uEEQBLg42uCXw3H440KyVrsPXhio9dzd2Rb5RdXJ+fnHuuFvIwORmVn9q8HasvoLxauqno6VpQV83OyRml2MXUdva72mv6cjEjN0x/szcpngmxLH4InMmIfCDgE1EqKmWJlMJsPmBcHYvCAYTz7cAc+O7IJPZw3G6AFt0cbDAX6ejnC0s8I/xnXVu6B5TS+Fas/EGRLkozP9sou/Ak8P74SX/9ZLa4hl8XTtLwsAeOGxbvBUaK/INeJBdU2eb3bFaK1yRfeHZ/BEZu7tqQ/iakIOcgvLYWVpobN/wtAOeo5Sk8lkeGdaP+QXlWPeN8fx+jMP6LTxUNhhSdhAvL/pFALbKmpdXeuxGouUa1hZymFtJUd5hbqcwfzJfdG1nSsy8kpxrsZNVM8GdxHXx83JL6tzSUYynEwwoVvNsrIKoVI1PBxPTydx5RwpknL/2Dfz1JC+KVUq7Dt5B6P6tRVvyCqvUOL0tXQM7ukDuVz9a0AzQ6dfV0/MfrJX8wRuIHP57ORyGdzdHWvf34KxEFErZCGXY/zgADG5A+oF1Yf28h
WTOwAsjxgCAPUOGZHhmOCJyCQonGwgl8lQWKK/Bg81HBM8EZkEuUwGJ3urFl3sJCuvVNIXdZngichk5BWV448LKff9OtEXkjF9aRTSc4rFbSVllSgqrf51kJlXgjfXHMe2A7WXYahp0eZTOHI+SWd7Q78gkjIKcfZ6RoOOaSwmeCIyOfcO0xSWVOjcSVvXsVt+vwYAWLDuT3H7m98cxytfRiPmtvrOWk3tnv+drXtxFQCIS87D3fRC/HNfrFZsJ6+kYUbkIVy/q74ZrKxCqfUlUlNieiGWbj2L9zadwqodl1BYUoFKZfMulsIET0QmI2x8dwDqZKyhEgTMWRmtdSdtXeasjNZ6rlSpoBIEFJep6/us+FFd86ZmsbPU7GLUJbfGl8v+09Xlmtf9ehkAxLr5L31+BK98GY09J+J1XuP9zafELwJNnOGfHYZS1XxJngmeiEyGu7N6/ntZhRJXqxYz2X/qrrh/+tIovYuQaOib9X01PkfnF8Hh80nYfuSW+Pyd9X/ee5iWX2q0/e14AhJSC7Tea9+pOzh6sXpoafuRW1p3/9asqHmvmcsON9savEzwRGQy2npXz+k+fC4Ji7ecxk+Hbmq1+Wq7elGTsnKl3qEcjUVVtXKOx6Ri+T019WuraX/2egamL43ChDd24/1N1XVy0u45w88uKEV8qvY8+c17r2o9/3qHOs7s/FK8tKLuom/31uBvKvUm+JycHMycORNjx47FhAkT8PLLLyM7W7dAUUlJCebOnYvRo0dj3LhxOHToULMETETS5WBrBTsb9Xz509fSkZCm/2aj6Uuj8NKKI5izMhqJGYVQqlQ4cTlVvBt29pNBaO/jBGd7K1hbWYjj7eMHa99tOyukJ/p09gAAlJZXYtWOS+K+xIwiZOWVarXX3ICVmVuKk1fU9fYD7qmd88akPgCAuKR8AMC8b3QXQrnXw72bZ9WreksVyGQyzJgxA4MGDQIAREZGYvny5fjkk0+02m3atAmOjo44cOAA4uPjMWXKFOzfvx8ODg7NEjgRSdPq1x7VqTs/tJcPCoor9C4K8v4m3Rr2blVDPc4O1lrF1Ib38dOqhz+ohzcqlSqcv5mJiBV/6LzO8h/P44FO7gCAQH8X9K56vO1/1TNvZoX0xNs1hnh6dqheHes/x+O1Xq+TnzPikvKxcf4IyGUy3EkrQHmFqlFLKxqi3jN4hUIhJncA6NOnD5KTk3Xa/f7775g0aRIAICAgAEFBQfjjD91/MCIiQ3Xxd8HDvX3xwmPdMffpB7Bq7iMGHefqZAMAqKisvoDpYGsJN2cb8fm0sV0BQO9MlufGqPelZReLde3DQ3rCylI3ZXq62uHLOcPg5+GAb17Xjm/nH9Vj96/8rRcWTuuPzQuCIa9K6O28ncRKoc2hQcXGVCoVtm3bhuDgYJ19ycnJ8PPzE5/7+voiNbVhq7vXVVOhPp6e0i4xKuX+sW/mqTn7tnXJY1CqVHB10i06FhbSE4521ujV2QMzPj6g9/guHdTDLmk5JeK2p4K7wMuremGSZ8Z0AwA8NqwT/lljTH5maBBGDWyHbfu1x+m7dvIEAIwe2A4HqhY+//b9MXB3sYM3gLVvjxLbTh7bDf/+7zWt48cM7Vhvv5tagxL8hx9+CHt7e0ydOrVZgmGxMf2k3D/2zTy1VN8y9MwpH9pDvRYtlEqsm/coABlmLT8MB1tLhD3RAwE+1bG98WwffP7DeVjIZRjawwsZGQVwcbSGt6u9VvxfvDwU89edQHmFCoO6esLe1gqb5o9AWKT6WuKDgZ5i+zsp+eJxqvJKvf8Oo/q2wYGT8cjIVY/hv/B4t2b596qv2JjBCT4yMhIJCQlYu3Yt5HpWi2nTpg2SkpLg5qYef0pJSdEa2iEiamqa8sibF+iOKgBAzwA3nX1fvDxMp52Low3WvjFca5tMJoOvuz1SsooxrMZF0EE9vBF7NxePDWqHusyf/CC+3n4Js58MgofCzpDuNDmDEvyKFSsQExOD9e
vXw9pad6UYABg3bhx+/PFH9OrVC/Hx8bh06RI+//zzJg2WiKglfTzzIZ1tw/v6YWB3r3rXjnVztsWiFwbU2aa51XuR9caNG1i3bh3S09Px7LPPIjQ0FLNnzwYAhIaGIi1NPVUoLCwM+fn5GD16NGbNmoUlS5bA0bHxY+pERKbK3taq2Wa+NCUu+GEGpNw/9s08SblvgPn0jwt+EBG1UkzwREQSxQRPRCRRTPBERBLFBE9EJFENupO1udVcYb0ljzUHUu4f+2aepNw3wDz6V1+MJjVNkoiImg6HaIiIJIoJnohIopjgiYgkigmeiEiimOCJiCSKCZ6ISKKY4ImIJIoJnohIopjgiYgkyuwT/O3btzFp0iSMHTsWkyZNQnx8vLFDqlNkZCSCg4PRtWtXXL9+XdxeVz8au6+l5eTkYObMmRg7diwmTJiAl19+GdnZ2QCA8+fPIyQkBGPHjsX06dORlZUlHtfYfS0tIiICISEhmDhxIiZPnoyrV68CkMZnp7Fq1Sqtv00pfG4AEBwcjHHjxiE0NBShoaGIjo6uN05z6l+tBDM3bdo0YdeuXYIgCMKuXbuEadOmGTmiup0+fVpITk4WRowYIcTGxorb6+pHY/e1tJycHOHPP/8Uny9dulR4++23BaVSKYwaNUo4ffq0IAiCsHr1amHBggWCIAiN3mcM+fn54uMDBw4IEydOFARBGp+dIAhCTEyMEBYWJv5tSuVzEwRB5/83QWh8H0yxf7Ux6wSfmZkp9OvXT6isrBQEQRAqKyuFfv36CVlZWUaOrH41/+Dq6kdj95mCffv2Cf/4xz+ECxcuCOPHjxe3Z2VlCX369BEEQWj0PmPbuXOn8OSTT0rmsysrKxOeeeYZ4e7du+LfppQ+N30JXkr9q41JVZNsqJSUFHh7e8PCwgIAYGFhAS8vL6SkpMDNzc3I0Rmurn4IgtCofcbuv0qlwrZt2xAcHIyUlBS0adNG3Ofm5gaVSoXc3NxG71MoFC3aH42FCxfi2LFjEAQBGzdulMxnt3LlSoSEhMDf31/cJqXPDQDmzZsHQRDQr18/vP7665Lrnz5mPwZPpunDDz+Evb09pk6dauxQmtTHH3+Mw4cP47XXXsOyZcuMHU6TOHfuHGJiYjB58mRjh9Jstm7dil9//RXbt2+HIAhYsmSJsUNqEWad4H19fZGWlgalUgkAUCqVSE9Ph6+vr5Eja5i6+tHYfcYUGRmJhIQEfPnll5DL5fD19UVycrK4Pzs7G3K5HAqFotH7jG3ixIk4efIkfHx8zP6zO336NOLi4jBy5EgEBwcjNTUVYWFhSEhIkMznpvl3tba2xuTJk3H27FlJ/l3ey6wTvLu7O7p3747ffvsNAPDbb7+he/fuRh+eaKi6+tHYfcayYsUKxMTEYPXq1bC2tgYABAUFobS0FGfOnAEA/PDDDxg3btx97WtpRUVFSElJEZ9HRUXBxcVFEp9deHg4jh49iqioKERFRcHHxwebNm3CjBkzzP5zA4Di4mIUFBQAAARBwN69e9G9e3dJ/F3Wx+wX/IiLi8OCBQuQn58PZ2dnREZGomPHjsYOq1YfffQR9u/fj8zMTLi6ukKhUGDPnj119qOx+1rajRs38MQTTyAgIAC2trYAAH9/f6xevRpnz57FokWLUFZWBj8/P3z22Wfw8PAAgEbva0mZmZmIiIhASUkJ5HI5XFxcMH/+fPTs2VMSn11NwcHBWLt2LQIDA83+cwOAu3fv4pVXXoFSqYRKpUKnTp3w7rvvwsvLSxL9q4vZJ3giItLPrIdoiIiodkzwREQSxQRPRCRRTPBERBLFBE9EJFFM8EREEsUET0QkUUzwREQS9f/fu1LA+W/7ewAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ],
      "source": [
        "# Graph the loss\n",
        "\n",
        "window_size = 50\n",
        "smoothed_losses = []\n",
        "for i in range(len(losses)-window_size):\n",
        "  smoothed_losses.append(np.mean(losses[i:i+window_size]))\n",
        "\n",
        "plt.plot(smoothed_losses[100:])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "77CgIkYjnEXk"
      },
      "source": [
        "## ファインチューニングしたモデルで翻訳してみる"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XoZhIC8U7_GV",
        "outputId": "ff95a9b3-dbdd-4e8f-d6f5-d0de10fd1033"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Raw input text: It has been confirmed that eight thoroughbred race horses at Randwick Racecourse in Sydney have been infected with equine influenza.\n",
            "Truncated input text: <jp> It has been confirmed that eight thoroughbred race horses at Randwick Racecourse</s>\n"
          ]
        }
      ],
      "source": [
        "test_sentence = test_dataset[0]['translation']['en']\n",
        "print('Raw input text:', test_sentence)\n",
        "\n",
        "input_ids = encode_input_str(\n",
        "    text = test_sentence,\n",
        "    target_lang = 'ja',\n",
        "    tokenizer = tokenizer,\n",
        "    seq_len = model.config.max_length,\n",
        "    lang_token_map = LANG_TOKEN_MAPPING)\n",
        "input_ids = input_ids.unsqueeze(0)\n",
        "\n",
        "print('Truncated input text:', tokenizer.convert_tokens_to_string(\n",
        "    tokenizer.convert_ids_to_tokens(input_ids[0])))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "44BvtXAOHnWL",
        "outputId": "31eed1ad-1f43-444b-beba-f3c49723427a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ニュージャージー州のRandwickRacecourseで8人の競輪のレース選手が\n",
            "ニュージャージー州のRandwickRacecourseで8人の競輪のレース選手は\n",
            "ニュージャージー州のRandwickRacecourseで8人の競輪の競輪の\n"
          ]
        }
      ],
      "source": [
        "output_tokens = model.generate(input_ids, num_beams=10, num_return_sequences=3)\n",
        "# print(output_tokens)\n",
        "for token_set in output_tokens:\n",
        "  print(tokenizer.decode(token_set, skip_special_tokens=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yMvWuP11LoqV"
      },
      "source": [
        "## Google Colab でフォームを用いた実行例"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gfN3QjHmPZJ8",
        "outputId": "e664db55-b1cc-4034-8cb9-ee141f96e926"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "This is test  ->  これがテストの結果です。\n"
          ]
        }
      ],
      "source": [
        "#@title Slick Blue Translate\n",
        "input_text = 'This is test' #@param {type:\"string\"}\n",
        "output_language = 'ja' #@param [\"en\", \"ja\", \"zh\"]\n",
        "\n",
        "input_ids = encode_input_str(\n",
        "    text = input_text,\n",
        "    target_lang = output_language,\n",
        "    tokenizer = tokenizer,\n",
        "    seq_len = model.config.max_length,\n",
        "    lang_token_map = LANG_TOKEN_MAPPING)\n",
        "input_ids = input_ids.unsqueeze(0)\n",
        "\n",
        "output_tokens = model.generate(input_ids, num_beams=20, length_penalty=0.2)\n",
        "print(input_text + '  ->  ' + \\\n",
        "      tokenizer.decode(output_tokens[0], skip_special_tokens=True))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vIMaG0pU5lLA"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "machine_shape": "hm",
      "name": "machine_translation.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.8.9 64-bit",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.9"
    },
    "vscode": {
      "interpreter": {
        "hash": "880b2a8c90f9e6beae80b56829e3f671fedd58b6d14887184ddce26124cedfbd"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
