ステージ2-5: 半教師あり学習の練習 (label_propagation.LabelSpreading) (情報工学実験 3 : データマイニング班)
目次- 想定環境
- データセットの用意
二重の同心円上にデータ(サンプル)を並べ、内側と外側の円上で最短距離にあるサンプル2点のみに教師信号を設定する。
- case 1: 2点にのみ教師データを用意し、残りを「ラベルが無いデータ」として半教師あり学習を適用してみる
- case 2: (平等な条件ではないが)教師あり学習(svm.SVC)でも試してみる
- case 3: (平等な条件ではないが)教師無し学習(cluster.KMeans)でも試してみる
想定環境
- OS: Mac OS X 10.8.x (10.7.x以降であれば同じ方法で問題無いはず)
- Python: 2.7.x
- Mercurial: 2.2
- Scikit-learn: 0.13.1 (sklearn.__version__)
- matplotlib: 1.3.x (matplotlib.__version__)
- numpy: 1.8.0 (numpy.version)
- scipy: 0.13.0 (scipy.version)
データセットの用意
- make_circles(): 2次元空間に二重円上にサンプルを配置。ラベルは内側の円(label=-1)と外側の円(label=0)の2種類が設定される。
import numpy as np
from sklearn.datasets import make_circles
n_samples = 200
data, target0 = make_circles(n_samples=n_samples, shuffle=False)
# make_circles() がどういうデータを生成したのかを確認してみる。
# dataは2次元特徴ベクトル:
# targetは教師データ: 前半=0, 後半=1
import pylab as pl
pl.figure()
X = data[:,0]
Y = data[:,1]
labels = target0
pl.title("make_circles")
pl.scatter(X, Y, c=labels)
#pl.show()
pl.savefig("fig0_make_circles.svg")
# 半教師あり学習用にラベルを設定し直す。
# 教師データ2種:
# 内側と外側の円上で最短距離にあるサンプル2点のみに教師信号を与える。
# 外側の円(outer): label=0
# 内側の円(inner): label=1
# それ以外のデータは未知ラベルとして-1を設定
target1 = - np.ones(n_samples)
unlabeled = -1
outer = 1
inner = 2
target1[0] = outer
target1[-1] = inner
pl.figure(1)
pl.title("before learning (law data)")
pl.scatter(X[target1==outer], Y[target1==outer], c='b', label='outer')
pl.scatter(X[target1==inner], Y[target1==inner], c='r', label='inner')
pl.scatter(X[target1==unlabeled], Y[target1==unlabeled], c='g', label='unlabeled')
pl.legend()
#pl.show()
pl.savefig("fig1_lawdata.svg")
2点にのみ教師データを用意し、残りを「ラベルが無いデータ」として半教師あり学習を適用してみる
- label_propagation.LabelSpreading: 半教師あり学習の一種。
# 半教師あり学習(label_propagetion.LabelSpreading)を実行。
# unlabeled がどうなったかを確認。
from sklearn.semi_supervised import label_propagation
label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=1.0)
label_spread.fit(data, target1)
output_labels = label_spread.transduction_
pl.figure(2)
pl.title("LabelSpreading")
pl.scatter(X[output_labels==outer], Y[output_labels==outer], c='b', label='outer')
pl.scatter(X[output_labels==inner], Y[output_labels==inner], c='r', label='inner')
pl.scatter(X[output_labels==unlabeled], Y[output_labels==unlabeled], c='g', label='unlabeled')
pl.legend()
#pl.show()
pl.savefig("fig2_LabelSpreading.svg")
(平等な条件ではないが)教師あり学習(svm.SVC)でも試してみる
# 類似条件(2点のみ教師データあり)で教師あり学習(svm.SVC)を試してみる。
# 教師あり学習の場合には「教師データが無いデータ」は学習時に提示できない。
# その状況で学習したモデルを使って、unlabeled がどうなるかを確認。
def filter(data, target, label):
u'''labelで指定されたtarget値を持つdataのみを抽出。
ここでは教師データを付与したtarget[0], target[-1]に該当するdata[0], data[1]
のみを切り出せば十分なので手動でやっても良いが、汎用性のある関数として記述した例。
'''
filtered_list = []
samples = data[target==label].tolist()
for sample in samples:
filtered_list.append(sample)
return filtered_list
# 教師データを付与した2点のみで学習用データセット(data2, target2)を構築。
# 残りの未ラベルデータでテスト用データセット(data2_test)を構築。
data2 = filter(data, target1, inner) + filter(data, target1, outer)
target2 = filter(target1, target1, inner) + filter(target1, target1, outer)
data2_test = filter(data, target1, unlabeled)
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100.)
clf.fit(data2, target2)
output_labels2 = clf.predict(data2_test)
pl.figure(3)
pl.title("svm.SVC")
pl.scatter(X[output_labels2==outer], Y[output_labels2==outer], c='b', label='outer')
pl.scatter(X[output_labels2==inner], Y[output_labels2==inner], c='r', label='inner')
pl.scatter(X[output_labels2==unlabeled], Y[output_labels2==unlabeled], c='g', label='unlabeled')
pl.legend()
#pl.show()
pl.savefig("fig3_SVC.svg")
(平等な条件ではないが)教師無し学習(cluster.KMeans)でも試してみる
# 教師なし学習(cluster.KMeans)を試してみる。
from sklearn import cluster
k_means = cluster.KMeans(n_clusters=2)
k_means.fit(data)
output_labels3 = k_means.labels_
pl.figure(4)
pl.title("cluster.KMeans")
pl.scatter(X[output_labels3==0], Y[output_labels3==0], c='b', label='label = 0')
pl.scatter(X[output_labels3==1], Y[output_labels3==1], c='r', label='label = 1')
pl.scatter(X[output_labels3==unlabeled], Y[output_labels3==unlabeled], c='g', label='unlabeled')
pl.legend()
#pl.show()
pl.savefig("fig4_KMeans.svg")