{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "recommendation_nn.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "fw3fbJCt5124" }, "source": [ "# Recommender system code example\n", "- Reference: [Collaborative Filtering for Movie Recommendations](https://keras.io/examples/structured_data/collaborative_filtering_movielens/) (Keras code example)\n", "- Overall workflow\n", "  - Prepare the dataset\n", "  - Split it into training and validation data\n", "  - Build the model\n", "  - Train the model\n", "  - Observe the training process\n", "  - Make top-N recommendations" ] }, { "cell_type": "markdown", "metadata": { "id": "9rymVz_vo_jv" }, "source": [ "## Environment setup" ] }, { "cell_type": "code", "metadata": { "id": "L4ebGTz3ky_D" }, "source": [ "import pandas as pd\n", "import numpy as np\n", "from zipfile import ZipFile\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "from pathlib import Path\n", "import matplotlib.pyplot as plt" ], "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "yhnKe2YNpDhx" }, "source": [ "## Preparing the dataset\n", "- Download the small [MovieLens](https://grouplens.org/datasets/movielens/) dataset (ml-latest-small).\n", "- Load ratings.csv into a DataFrame with `pd.read_csv`." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "v2XQTRSQk0bC", "outputId": "5a8aba45-d01f-475a-8333-f93b1759c810" }, "source": [ "# Download the actual data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip\n", "# Use the ratings.csv file\n", "movielens_data_file_url = (\n", "    \"http://files.grouplens.org/datasets/movielens/ml-latest-small.zip\"\n", ")\n", "movielens_zipped_file = keras.utils.get_file(\n", "    \"ml-latest-small.zip\", movielens_data_file_url, extract=False\n", ")\n", "keras_datasets_path = Path(movielens_zipped_file).parents[0]\n", "movielens_dir = keras_datasets_path / \"ml-latest-small\"\n", "\n", "# Only extract the data the first time the script is run.\n", "if not movielens_dir.exists():\n", "    with ZipFile(movielens_zipped_file, \"r\") as zip_file:\n", "        # Extract files\n", "        print(\"Extracting all the files now...\")\n", "        zip_file.extractall(path=keras_datasets_path)\n", "        print(\"Done!\")\n", "\n", "ratings_file = movielens_dir / \"ratings.csv\"\n", "df = pd.read_csv(ratings_file)\n", "df" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "Downloading data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip\n", "983040/978202 [==============================] - 0s 0us/step\n", "Extracting all the files now...\n", "Done!\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/html": [ "
|  | userId | movieId | rating | timestamp |
|---|---|---|---|---|
| 0 | 1 | 1 | 4.0 | 964982703 |
| 1 | 1 | 3 | 4.0 | 964981247 |
| 2 | 1 | 6 | 4.0 | 964982224 |
| 3 | 1 | 47 | 5.0 | 964983815 |
| 4 | 1 | 50 | 5.0 | 964982931 |
| ... | ... | ... | ... | ... |
| 100831 | 610 | 166534 | 4.0 | 1493848402 |
| 100832 | 610 | 168248 | 5.0 | 1493850091 |
| 100833 | 610 | 168250 | 5.0 | 1494273047 |
| 100834 | 610 | 168252 | 5.0 | 1493846352 |
| 100835 | 610 | 170875 | 3.0 | 1493846415 |

100836 rows × 4 columns
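`pd.read_csv` infers each column's type from the text, so the IDs and timestamps come back as integers and the ratings as floats. A minimal sketch of that step, assuming only pandas: it reads the header and the first five rows shown in the table above from an in-memory buffer (copied from the output, not new data), so it runs without the download. The name `sample` is hypothetical. Distinct counts like these are what an embedding-based model, such as the one in the referenced Keras example, would use to size its user and movie embedding tables.

```python
import io

import pandas as pd

# Header and first five rows of ratings.csv, copied from the output above.
# An in-memory buffer stands in for the real file (hypothetical stand-in).
sample_csv = io.StringIO(
    'userId,movieId,rating,timestamp\n'
    '1,1,4.0,964982703\n'
    '1,3,4.0,964981247\n'
    '1,6,4.0,964982224\n'
    '1,47,5.0,964983815\n'
    '1,50,5.0,964982931\n'
)

sample = pd.read_csv(sample_csv)

# read_csv infers integer ID/timestamp columns and a float rating column.
print(sample.dtypes)

# Distinct IDs and the rating range: quick sanity checks after loading.
print('users:', sample['userId'].nunique(), 'movies:', sample['movieId'].nunique())
print('rating range:', sample['rating'].min(), 'to', sample['rating'].max())
```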
|  | userId | movieId | rating | timestamp | user | movie |
|---|---|---|---|---|---|---|
| 0 | 1 | 1 | 4.0 | 964982703 | 0 | 0 |
| 1 | 1 | 3 | 4.0 | 964981247 | 0 | 1 |
| 2 | 1 | 6 | 4.0 | 964982224 | 0 | 2 |
| 3 | 1 | 47 | 5.0 | 964983815 | 0 | 3 |
| 4 | 1 | 50 | 5.0 | 964982931 | 0 | 4 |
| ... | ... | ... | ... | ... | ... | ... |
| 100831 | 610 | 166534 | 4.0 | 1493848402 | 609 | 3120 |
| 100832 | 610 | 168248 | 5.0 | 1493850091 | 609 | 2035 |
| 100833 | 610 | 168250 | 5.0 | 1494273047 | 609 | 3121 |
| 100834 | 610 | 168252 | 5.0 | 1493846352 | 609 | 1392 |
| 100835 | 610 | 170875 | 3.0 | 1493846415 | 609 | 2873 |

100836 rows × 6 columns
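The cell that adds the `user` and `movie` columns is not part of this excerpt. The values shown are consistent with numbering each distinct `userId` and `movieId` in order of first appearance (userId 1 → user 0; movieIds 1, 3, 6, 47, 50 → movies 0 to 4), which is how the referenced Keras example builds contiguous 0-based indices for its embedding layers. A minimal sketch of that approach, assuming pandas and run on a small hypothetical frame `toy` (values invented for illustration) so it does not overwrite the notebook's `df`:

```python
import pandas as pd

# Hypothetical toy ratings; movieId 3 appears twice to show that a repeated
# raw ID always maps to the same index.
toy = pd.DataFrame(
    {
        'userId': [1, 1, 1, 610, 610],
        'movieId': [1, 3, 6, 170875, 3],
        'rating': [4.0, 4.0, 4.0, 3.0, 5.0],
    }
)

# unique() keeps the order of first appearance, so the first ID seen gets 0.
user_ids = toy['userId'].unique().tolist()
user2user_encoded = {x: i for i, x in enumerate(user_ids)}
userencoded2user = {i: x for i, x in enumerate(user_ids)}

movie_ids = toy['movieId'].unique().tolist()
movie2movie_encoded = {x: i for i, x in enumerate(movie_ids)}
movie_encoded2movie = {i: x for i, x in enumerate(movie_ids)}

# Contiguous 0-based indices can be used directly as embedding-row lookups.
toy['user'] = toy['userId'].map(user2user_encoded)
toy['movie'] = toy['movieId'].map(movie2movie_encoded)

num_users = len(user2user_encoded)
num_movies = len(movie2movie_encoded)
print(toy)
print('num_users:', num_users, 'num_movies:', num_movies)
```

The reverse dictionaries (`userencoded2user`, `movie_encoded2movie`) turn model indices back into MovieLens IDs, which a top-N recommendation step needs. `pd.factorize(toy['movieId'])` would produce the same first-seen codes in one call; the dictionary form is kept here because it mirrors the referenced Keras example.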
\n", "