Matplotlibを使ったグラフ描画入門 with Pandas
<目次>
- Step 1: 描画方法の選定。
- Step 2: データの準備。
- Step 3: 描画方法に合わせてプロット。
- 機械学習入門で良く使われるデータセット「Iris Data Set」を視覚的に眺めてみる。
- Iris Data Set: Iris (アイリス。アヤメ科植物の総称) の花150サンプル分のデータ。
- データの中身
- 1サンプル毎に「sepal length, sepal width, petal length, petal width」の4属性に関する数値データ(cm)と、そのサンプルの詳細分類名(Iris Setosa, Iris Versicolour, Iris Virginica)が記録されている。
- 「4属性の数値」をそのままプロットできればそれに越したことはないが、現実的には2次元か3次元が限界。
- 今回は以下の2通りでプロットしてみる。
- case 1) 「sepal length(1番目の数値), petal length(3番目の数値)」で2次元プロット。
- case 2) 「sepal width(2番目), petal length(3番目), petal width(4番目)」で3次元プロット
- (1) iris.dataをダウンロード。(ターミナル上で実行)
# 適当な作業用ディレクトリに移動してから、iris.dataをダウンロード。
curl https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data -o iris.data
# データの中身を扱いやすいように修正。
# 文字列で書かれている分類名を0,1,2に置換。
cat iris.data | sed -e 's/Iris-setosa/0/g' | sed -e 's/Iris-versicolor/1/g' | sed -e 's/Iris-virginica/2/g' > iris2.data
- (2) pandasモジュールをインストール。(CSV形式データを簡単に操作できるライブラリ)
- ターミナル上でインストール。
pip3 install pandas
```
### 具体例: 散布図を描く(2次元)
# 慣例としてpd, pltというエイリアスを付けることが多い。
import pandas as pd
import matplotlib.pyplot as plt
# データセットの読み込み
filename = 'iris2.data'
df = pd.read_csv(filename, header=None)
# sepal length(1番目の数値)をx座標として準備。
# petal length(3番目の数値)をy座標として準備。
x = df[0]
y = df[2]
labels = df[4]
# 散布図描画
plt.scatter(x, y, c=labels) #実際に描画データを指定している箇所
plt.xlabel('sepal length') #x軸の説明
plt.ylabel('petal length') #y軸の説明
plt.show()
- 事前にデータセットを読み込み終えてる状態で実行すること。
### 具体例: 3次元空間に散布図を描く
# sepal width(2番目)をx座標,
# petal length(3番目)をy座標,
# petal width(4番目)をz座標として準備。
x = df[1]
y = df[2]
z = df[3]
labels = df[4]
# 散布図描画
# 3次元描画の場合はもう一つモジュールを読み込む。
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x, y, z, c=labels)
ax.set_xlabel('sepal width')
ax.set_ylabel('petal length')
ax.set_zlabel('petal width')
plt.show()
plt.show()
で描画し、その後に手動で保存することも可能。これを手間に感じなければこれでOK。
- 保存するところまでコードに書くなら次のように書こう。
- plt.show()の代わりに
plt.savefig()
を使う。
### 具体例: 2次元空間に散布図を書き、保存する。
import pandas as pd
import matplotlib.pyplot as plt
filename = 'iris2.data'
df = pd.read_csv(filename, header=None)
x = df[0]
y = df[2]
labels = df[4]
# 散布図描画
plt.scatter(x, y, c=labels)
plt.xlabel('sepal length')
plt.ylabel('petal length')
# plt.show()の代わりに、plt.savefig()を使う。
plt.savefig('graph.pdf')