5回目は飛ばして6回目。講義はex8まであるけども、このシリーズはこれで終わりにする。
今回はSVMを使った分類問題。
対象のデータは画像のような二次元特徴ベクトル + ラベル
線形分離不可能なので、SVMのカーネルトリックを使う。
今回のように線形分離できないときは、特徴ベクトルを非線形写像して線形分離可能な問題に変換する。
このあたりははじパタを読んで復習した。
実装に当たっては、非線形写像する関数でなく、カーネル関数と呼ばれる関数を定義してやれば十分。
Andrew先生の講義ではガウシアンカーネルを使ったが、今回はRBFカーネルを使う。(使うライブラリのデフォルトがRBFだからであって深い理由はない)
SVMの実装は面倒だったのでe1071というRのライブラリを使う。
- SVMの設定(cost)
costはニューラルネットワークでいう正則化項のように、過学習を抑制するパラメータ。
ただしニューラルネットワークとは逆で、costを大きくするほど学習データに対して識別性能が良くなる代わりに過学習しがちになる。つまり、識別境界に複雑な形を許すようになる。
costを小さくすると、パラメータの上限が抑えられ、識別境界がシンプルな形になる。
- SVMの設定(gamma)
こいつも大きいほど複雑な識別境界を描けるようになる。
R
※ドルマークが正しく表示されないらしい
ex6<- function()
{
require("e1071")
require(R.matlab)
datamat <- readMat('ex6data2.mat')
testdata <- data.frame(cbind(datamat$y, datamat$X))
names(testdata) <- c("y", "x1","x2")
testdata$y <- as.factor(testdata$y)
plotData(datamat$X, datamat$y)
par(new = T)
result <- svm(y ~ . , data = testdata, cost = 10, gamma = 1)
summary(result)
print(result)
px <- seq(0.9 * min(testdata$x1), 1.1 * max(testdata$x1),length= 100)
py <- seq(0.9 * min(testdata$x2), 1.1 * max(testdata$x2),length= 100)
pgrid <- expand.grid(px,py)
names(pgrid)<-c("x1","x2")
result.plot <- predict(result, pgrid, type="vector")
contour(px, py, array(result.plot, dim=c(length(px),length(py))),col = "red", lwd=3, drawlabels=F, c(0.9 * min(testdata$x1), 1.1 * max(testdata$x1)), ylim = c(0.9 * min(testdata$x2), 1.1 * max(testdata$x2)))
y_trained <-predict(result, cbind(testdata$x1, testdata$x2), type="vector")
cat("precision", mean(datamat$y == y_trained) * 100, "\n")
}
plotData <- function(X, y)
{
# ----- plot -----
plot(0, 0, type = "n", xlim = c(0.9 * min(X[, 1]), 1.1 * max(X[, 1])), ylim = c(0.9 * min(X[, 2]), 1.1 * max(X[, 2])),xlab = "x1", ylab = "x2")
# points(X, col = ifelse(y == 1, "black", "yellow"), pch = ifelse(y == 1, 3, 16))
points(X, col = c("black", "orange")[y+1], pch = c(3, 16)[y+1])
legend("topleft", legend = c("y = 1", "y = 0"), pch = c(3, 16), col = c("black", "orange"))
}
結果
デフォルト設定でSVM
> ex6()
Call:
svm(formula = y ~ ., data = testdata)
Parameters:
SVM-Type: C-classification
SVM-Kernel: radial
cost: 1
gamma: 0.5
Number of Support Vectors: 406
precision 90.38239
次にcost = 10, gamma = 1で
> ex6()
Call:
svm(formula = y ~ ., data = testdata, cost = 10, gamma = 1)
Parameters:
SVM-Type: C-classification
SVM-Kernel: radial
cost: 10
gamma: 1
Number of Support Vectors: 151
precision 98.84125



0 件のコメント:
コメントを投稿