{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "recommendation_nn.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [
{ "cell_type": "markdown", "metadata": { "id": "fw3fbJCt5124" }, "source": [ "# Recommender System Code Example\n", "- Reference: [Collaborative Filtering for Movie Recommendations](https://keras.io/examples/structured_data/collaborative_filtering_movielens/) (Keras code examples)\n", "- Overall workflow\n", "    - Prepare the dataset\n", "    - Split it into training and validation data\n", "    - Build the model\n", "    - Train\n", "    - Inspect the training history\n", "    - Top-N recommendation" ] },
{ "cell_type": "markdown", "metadata": { "id": "9rymVz_vo_jv" }, "source": [ "## Environment setup" ] },
{ "cell_type": "code", "metadata": { "id": "L4ebGTz3ky_D" }, "source": [ "import pandas as pd\n", "import numpy as np\n", "from zipfile import ZipFile\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "from pathlib import Path\n", "import matplotlib.pyplot as plt" ], "execution_count": 1, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "yhnKe2YNpDhx" }, "source": [ "## Preparing the dataset\n", "- Download the small [MovieLens](https://grouplens.org/datasets/movielens/) dataset (`ml-latest-small`).\n", "- Load `ratings.csv` into a DataFrame with `pd.read_csv`." ] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "v2XQTRSQk0bC", "outputId": "5a8aba45-d01f-475a-8333-f93b1759c810" }, "source": [ "# Download the actual data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip\n", "# Use the ratings.csv file\n", "movielens_data_file_url = (\n", "    \"http://files.grouplens.org/datasets/movielens/ml-latest-small.zip\"\n", ")\n", "movielens_zipped_file = keras.utils.get_file(\n", "    \"ml-latest-small.zip\", movielens_data_file_url, extract=False\n", ")\n", "keras_datasets_path = Path(movielens_zipped_file).parents[0]\n", "movielens_dir = keras_datasets_path / \"ml-latest-small\"\n", "\n", "# Only extract the data the first time the script is run.\n", "if not movielens_dir.exists():\n", "    with ZipFile(movielens_zipped_file, \"r\") as zf:\n", "        # Extract files\n", "        print(\"Extracting all the files now...\")\n", "        zf.extractall(path=keras_datasets_path)\n", "        print(\"Done!\")\n", "\n", "ratings_file = movielens_dir / \"ratings.csv\"\n", "df = pd.read_csv(ratings_file)\n", "df" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "Downloading data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip\n", "983040/978202 [==============================] - 0s 0us/step\n", "Extracting all the files now...\n", "Done!\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "        userId  movieId  rating   timestamp\n", "0            1        1     4.0   964982703\n", "1            1        3     4.0   964981247\n", "2            1        6     4.0   964982224\n", "3            1       47     5.0   964983815\n", "4            1       50     5.0   964982931\n", "...        ...      ...     ...         ...\n", "100831     610   166534     4.0  1493848402\n", "100832     610   168248     5.0  1493850091\n", "100833     610   168250     5.0  1494273047\n", "100834     610   168252     5.0  1493846352\n", "100835     610   170875     3.0  1493846415\n", "\n", "[100836 rows x 4 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 2 } ] },
{ "cell_type": "markdown", "metadata": { "id": "il7FHiD9pd5O" }, "source": [ "## Preprocessing 1: re-indexing the IDs\n", "`userId` and `movieId` are integer labels, but the numbering has gaps. Because an `Embedding` layer looks up rows by integer index (0 to N-1), both IDs are remapped to consecutive integers starting at 0 and stored in new `user` and `movie` columns." ] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pZ4_ulSyk6-B", "outputId": "5d695208-668b-41ba-8586-716eb55ea130" }, "source": [ "user_ids = df[\"userId\"].unique().tolist()\n", "user2user_encoded = {x: i for i, x in enumerate(user_ids)}\n", "userencoded2user = {i: x for i, x in enumerate(user_ids)}\n", "movie_ids = df[\"movieId\"].unique().tolist()\n", "movie2movie_encoded = {x: i for i, x in enumerate(movie_ids)}\n", "movie_encoded2movie = {i: x for i, x in enumerate(movie_ids)}\n", "df[\"user\"] = df[\"userId\"].map(user2user_encoded)\n", "df[\"movie\"] = df[\"movieId\"].map(movie2movie_encoded)\n", "\n", "num_users = len(user2user_encoded)\n", "num_movies = len(movie_encoded2movie)\n", "df[\"rating\"] = df[\"rating\"].values.astype(np.float32)\n", "# min and max ratings will be used to normalize the ratings later\n", "min_rating = min(df[\"rating\"])\n", "max_rating = max(df[\"rating\"])\n", "\n", "print(\n", "    \"Number of users: {}, Number of Movies: {}, Min rating: {}, Max rating: {}\".format(\n", "        num_users, num_movies, min_rating, max_rating\n", "    )\n", ")" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "text": [ "Number of users: 610, Number of Movies: 9724, Min rating: 0.5, Max rating: 5.0\n" ], "name": "stdout" } ] },
{ "cell_type": 
"code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 419 }, "id": "yl_aCO4kpPRX", "outputId": "89fc5fc5-8aa5-49a6-f835-2404e4ade2b2" }, "source": [ "df" ], "execution_count": 4, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "        userId  movieId  rating   timestamp  user  movie\n", "0            1        1     4.0   964982703     0      0\n", "1            1        3     4.0   964981247     0      1\n", "2            1        6     4.0   964982224     0      2\n", "3            1       47     5.0   964983815     0      3\n", "4            1       50     5.0   964982931     0      4\n", "...        ...      ...     ...         ...   ...    ...\n", "100831     610   166534     4.0  1493848402   609   3120\n", "100832     610   168248     5.0  1493850091   609   2035\n", "100833     610   168250     5.0  1494273047   609   3121\n", "100834     610   168252     5.0  1493846352   609   1392\n", "100835     610   170875     3.0  1493846415   609   2873\n", "\n", "[100836 rows x 6 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 4 } ] },
{ "cell_type": "markdown", "metadata": { "id": "btnjSrFrp7ss" }, "source": [ "## Preprocessing 2: normalizing the ratings\n", "The raw ratings range from 0.5 to 5.0. To make training easier, they are min-max scaled to the range 0 to 1. The rows are shuffled first so that the train/validation split below is random." ] },
{ "cell_type": "code", "metadata": { "id": "-Xv4nB8_k-LD" }, "source": [ "df = df.sample(frac=1, random_state=42)\n", "x = df[[\"user\", \"movie\"]].values\n", "# Normalize the targets between 0 and 1. Makes it easy to train.\n", "y = df[\"rating\"].apply(lambda x: (x - min_rating) / (max_rating - min_rating)).values\n" ], "execution_count": 5, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "Ux0iuJQYqlpv" }, "source": [ "Split the data into a training set and a validation set." ] },
{ "cell_type": "code", "metadata": { "id": "cl5r0qN4qhUj" }, "source": [ "# Assuming training on 90% of the data and validating on 10%.\n", "train_indices = int(0.9 * df.shape[0])\n", "x_train, x_val, y_train, y_val = (\n", "    x[:train_indices],\n", "    x[train_indices:],\n", "    y[:train_indices],\n", "    y[train_indices:],\n", ")" ], "execution_count": 6, "outputs": [] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0QFDc3_2pybG", "outputId": "34393f37-bf4f-4ba2-94e0-4ba0a6df45b4" }, "source": [ "print(x_train[0].shape)\n", "print(x_train[0])\n", "print(y_train[0])" ], "execution_count": 7, "outputs": [ { "output_type": "stream", "text": [ "(2,)\n", "[ 431 4730]\n", "0.8888888888888888\n" ], "name": "stdout" } ] },
{ "cell_type": "markdown", 
"metadata": { "id": "7DjuIeucqpYZ" }, "source": [ "## Building the model\n", "The model is built from embedding layers, the same kind of layer used in word2vec. It is a simple model: each user and each movie gets an embedding vector, the prediction is the per-example dot product of the user vector and the movie vector plus a per-user bias and a per-movie bias, and a sigmoid maps the result into the 0 to 1 range of the normalized ratings." ] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "53wKsC8hlBvv", "outputId": "d0370738-938c-4749-8abb-e90a87196fab" }, "source": [ "EMBEDDING_SIZE = 50\n", "\n", "\n", "class RecommenderNet(keras.Model):\n", "    def __init__(self, num_users, num_movies, embedding_size, **kwargs):\n", "        super(RecommenderNet, self).__init__(**kwargs)\n", "        self.num_users = num_users\n", "        self.num_movies = num_movies\n", "        self.embedding_size = embedding_size\n", "        self.user_embedding = layers.Embedding(\n", "            num_users,\n", "            embedding_size,\n", "            embeddings_initializer=\"he_normal\",\n", "            embeddings_regularizer=keras.regularizers.l2(1e-6),\n", "        )\n", "        self.user_bias = layers.Embedding(num_users, 1)\n", "        self.movie_embedding = layers.Embedding(\n", "            num_movies,\n", "            embedding_size,\n", "            embeddings_initializer=\"he_normal\",\n", "            embeddings_regularizer=keras.regularizers.l2(1e-6),\n", "        )\n", "        self.movie_bias = layers.Embedding(num_movies, 1)\n", "\n", "    def call(self, inputs):\n", "        user_vector = self.user_embedding(inputs[:, 0])\n", "        user_bias = self.user_bias(inputs[:, 0])\n", "        movie_vector = self.movie_embedding(inputs[:, 1])\n", "        movie_bias = self.movie_bias(inputs[:, 1])\n", "        # Per-example dot product of the user and movie vectors: shape (batch, 1)\n", "        dot_user_movie = tf.reduce_sum(user_vector * movie_vector, axis=1, keepdims=True)\n", "        # Add all the components (including bias)\n", "        x = dot_user_movie + user_bias + movie_bias\n", "        # The sigmoid activation keeps the predicted rating between 0 and 1\n", "        return tf.nn.sigmoid(x)\n", "\n", "\n", "model = RecommenderNet(num_users, num_movies, EMBEDDING_SIZE)\n", "model.compile(\n", "    loss=tf.keras.losses.BinaryCrossentropy(),\n", "    optimizer=keras.optimizers.Adam(learning_rate=0.001),\n", ")" ], "execution_count": 8, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "M5zABGrU6Jy6" }, "source": [ "## Training" ] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "htHvjYItlEmE", "outputId": "83d57ca8-93b6-4223-95fb-cc4af2d04c87" }, "source": [ "history = model.fit(\n", "    x=x_train,\n", "    y=y_train,\n", "    batch_size=64,\n", "    epochs=5,\n", "    verbose=1,\n", "    validation_data=(x_val, y_val),\n", ")" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "text": [ "Epoch 1/5\n", "1418/1418 [==============================] - 12s 7ms/step - loss: 0.6370 - val_loss: 0.6206\n", "Epoch 2/5\n", "1418/1418 [==============================] - 10s 7ms/step - loss: 0.6135 - val_loss: 0.6168\n", "Epoch 3/5\n", "1418/1418 [==============================] - 10s 7ms/step - loss: 0.6082 - val_loss: 0.6126\n", "Epoch 4/5\n", "1418/1418 [==============================] - 11s 7ms/step - loss: 0.6071 - val_loss: 0.6150\n", "Epoch 5/5\n", "1418/1418 [==============================] - 10s 7ms/step - loss: 0.6078 - val_loss: 0.6123\n" ], "name": "stdout" } ] },
{ "cell_type": "markdown", "metadata": { "id": "bSzUU_z66N-U" }, "source": [ "## Plotting the training history" ] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 295 }, "id": "5_gPhufGlGEE", "outputId": "fea6c81f-7cb6-4653-9dc3-01a0dcdce306" }, "source": [ "plt.plot(history.history[\"loss\"])\n", "plt.plot(history.history[\"val_loss\"])\n", "plt.title(\"model loss\")\n", "plt.ylabel(\"loss\")\n", "plt.xlabel(\"epoch\")\n", "plt.legend([\"train\", \"validation\"], loc=\"upper left\")\n", "plt.show()" ], "execution_count": 10, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "7oHeirss6TFL" }, "source": [ "## Top-N recommendations\n" ] },
{ "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5XpK2qVMlRAc", "outputId": "df1a2b84-fd7c-4da3-ce89-bebe6fd902c2" }, "source": [ "movie_df = pd.read_csv(movielens_dir / \"movies.csv\")\n", "\n", "# Let us get a user and see the top recommendations.\n", "user_id = df.userId.sample(1).iloc[0]\n", "\n", "# Movies the user has already rated (watched).\n", "movies_watched_by_user = df[df.userId == user_id]\n", "\n", "# Movies the user has not watched yet.\n", "# Negate the isin() mask, deduplicate with a set, keep only movies that have an\n", "# index in movie2movie_encoded (built during preprocessing), then encode them.\n", "movies_not_watched = movie_df[\n", "    ~movie_df[\"movieId\"].isin(movies_watched_by_user.movieId.values)\n", "][\"movieId\"]\n", "movies_not_watched = list(\n", "    set(movies_not_watched).intersection(set(movie2movie_encoded.keys()))\n", ")\n", "movies_not_watched = [[movie2movie_encoded.get(x)] for x in movies_not_watched]\n", "\n", "# Build the (user, movie) input pairs for the model.\n", "user_encoder = user2user_encoded.get(user_id)\n", "user_movie_array = np.hstack(\n", "    ([[user_encoder]] * len(movies_not_watched), movies_not_watched)\n", ")\n", "\n", "# Predict with the trained model and take the IDs of the 10 highest-scoring movies.\n", "ratings = model.predict(user_movie_array).flatten()\n", "top_ratings_indices = ratings.argsort()[-10:][::-1]\n", "recommended_movie_ids = [\n", "    movie_encoded2movie.get(movies_not_watched[x][0]) for x in top_ratings_indices\n", "]\n", "\n", "# Print the 5 watched movies the user rated highest.\n", "print(\"Showing recommendations for user: {}\".format(user_id))\n", "print(\"====\" * 9)\n", "print(\"Movies with high ratings from user\")\n", "print(\"----\" * 8)\n", "top_movies_user = (\n", "    movies_watched_by_user.sort_values(by=\"rating\", ascending=False)\n", "    .head(5)\n", "    .movieId.values\n", ")\n", "movie_df_rows = movie_df[movie_df[\"movieId\"].isin(top_movies_user)]\n", "for row in movie_df_rows.itertuples():\n", "    print(row.title, \":\", row.genres)\n", "\n", "# Print the top 10 recommended movies.\n", 
"print(\"----\" * 8)\n", "print(\"Top 10 movie recommendations\")\n", "print(\"----\" * 8)\n", "recommended_movies = movie_df[movie_df[\"movieId\"].isin(recommended_movie_ids)]\n", "for row in recommended_movies.itertuples():\n", "    print(row.title, \":\", row.genres)" ], "execution_count": 11, "outputs": [ { "output_type": "stream", "text": [ "Showing recommendations for user: 174\n", "====================================\n", "Movies with high ratings from user\n", "--------------------------------\n", "French Kiss (1995) : Action|Comedy|Romance\n", "Ace Ventura: Pet Detective (1994) : Comedy\n", "Jurassic Park (1993) : Action|Adventure|Sci-Fi|Thriller\n", "Tombstone (1993) : Action|Drama|Western\n", "Batman (1989) : Action|Crime|Thriller\n", "--------------------------------\n", "Top 10 movie recommendations\n", "--------------------------------\n", "Braveheart (1995) : Action|Drama|War\n", "Taxi Driver (1976) : Crime|Drama|Thriller\n", "Godfather, The (1972) : Crime|Drama\n", "Reservoir Dogs (1992) : Crime|Mystery|Thriller\n", "Star Wars: Episode V - The Empire Strikes Back (1980) : Action|Adventure|Sci-Fi\n", "Princess Bride, The (1987) : Action|Adventure|Comedy|Fantasy|Romance\n", "Raiders of the Lost Ark (Indiana Jones and the Raiders of the Lost Ark) (1981) : Action|Adventure\n", "Lawrence of Arabia (1962) : Adventure|Drama|War\n", "Apocalypse Now (1979) : Action|Drama|War\n", "Goodfellas (1990) : Crime|Drama\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "bfeU2egI8oUO" }, "source": [ "" ], "execution_count": 11, "outputs": [] } ] }