2015年12月23日水曜日

[Rの]Andrew先生の機械学習2-1[練習]





2回目のお題はロジスティック回帰。
簡単な関数をいくつか作って、後はRの最適化関数(optim)を使うだけ。

# Fit a two-feature logistic regression by minimising costFunction with
# optim(), then plot the data together with the fitted decision boundary.
#
# path: CSV file without header; columns 1-2 are the features, column 3 the
#       0/1 label (defaults to the course's ex2data1.txt).
# Returns (invisibly) the list returned by optim(); $par holds theta in the
# order (intercept, x1, x2). Warns if optim reports non-convergence.
ex2 <- function(path = "ex2data1.txt")
{
    # ----- load data -----
    data <- read.csv(path, header = FALSE)
    X <- data[, 1:2]
    y <- data[, 3]

    # ----- design matrix with bias column -----
    # The leading column of ones makes theta[1] the intercept.
    X_mat <- cbind(1, as.matrix(X))
    initial_theta <- rep(0, ncol(X_mat))

    # ----- optimize -----
    result <- optim(par = initial_theta, fn = costFunction, X = X_mat, y = y)
    if (result$convergence != 0) {
        # e.g. code 1 = iteration limit reached; theta may be inaccurate.
        warning("optim did not converge (code ", result$convergence, ")",
                call. = FALSE)
    }

    # plotDecisionBoundary draws the scatter itself, so no separate data plot.
    theta <- result$par
    plotDecisionBoundary(theta, X, y)
    invisible(result)
}

# Logistic (sigmoid) function 1 / (1 + e^(-x)), applied element-wise.
# Accepts scalars, vectors or matrices; any dim attribute is kept.
sigmoid <- function(x)
{
    denominator <- 1 + exp(-x)
    1 / denominator
}

# Logistic-regression cost (mean cross-entropy) for parameters theta.
#
# theta: numeric parameter vector, length ncol(X).
# X:     design matrix (one row per example, bias column included by caller).
# y:     labels, presumably 0/1, length nrow(X).
# Returns the scalar J(theta) = -mean(y*log(h) + (1-y)*log(1-h)), h = sigmoid(X %*% theta).
costFunction <- function(theta, X, y)
{
    m <- length(y)
    z <- X %*% theta

    # log(sigmoid(z)) and log(1 - sigmoid(z)) = log(sigmoid(-z)) are evaluated
    # directly in log space. Going through sigmoid() then log() underflows to
    # log(0) = -Inf once |z| is large, and 0 * -Inf gives NaN.
    log_p <- plogis(z, log.p = TRUE)
    log_q <- plogis(-z, log.p = TRUE)

    J <- -sum(y * log_p + (1 - y) * log_q) / m
    J
}

# Scatter-plot two-feature data and overlay the linear decision boundary
# theta[1] + theta[2] * x1 + theta[3] * x2 = 0.
#
# theta: length-3 parameter vector (intercept, x1 weight, x2 weight); names
#        such as those from coef(glm) are ignored.
# X:     data frame or matrix whose first two columns are x1 and x2.
# y:     labels; y == 1 is drawn as black "+", everything else as yellow dots.
# Draws on the current device and returns NULL invisibly.
plotDecisionBoundary <- function(theta, X, y)
{
    x1 <- X[, 1]
    x2 <- X[, 2]
    positive <- y == 1

    # ----- plot data -----
    # extendrange() pads each axis by 10% of the data span; unlike scaling
    # min/max by 0.9/1.1 this also works for negative or zero-centred data.
    plot(x1, x2, type = "n",
         xlim = extendrange(x1, f = 0.1), ylim = extendrange(x2, f = 0.1),
         xlab = "x1", ylab = "x2")
    points(x1, x2, col = ifelse(positive, "black", "yellow"),
           pch = ifelse(positive, 3, 16))

    # ----- decision boundary -----
    t0 <- theta[[1]]
    t1 <- theta[[2]]
    t2 <- theta[[3]]
    if (t2 != 0) {
        # Solve for x2: x2 = -(t0 + t1 * x1) / t2; abline clips to the plot.
        abline(a = -t0 / t2, b = -t1 / t2)
    } else if (t1 != 0) {
        # Boundary does not depend on x2: vertical line x1 = -t0 / t1.
        abline(v = -t0 / t1)
    }
    # If both weights are zero there is no boundary to draw.
    invisible(NULL)
}


















----- 追記 -----
ちなみにglm関数を使うと、最適化を自前で書かなくても次のように書ける。
# Fit the same logistic regression as ex2() with glm(family = binomial),
# plot the decision boundary and print the model summary.
#
# path: CSV file without header; columns 1-2 are the features, column 3 the
#       0/1 label (defaults to the course's ex2data1.txt).
# Returns the fitted glm object.
ex2_glm <- function(path = "ex2data1.txt")
{
    # ----- load data -----
    # header = FALSE makes read.csv name the columns V1, V2, V3.
    data <- read.csv(path, header = FALSE)
    X <- data[, 1:2]
    y <- data[, 3]

    # ----- GLM -----
    # The response is taken from `data` itself rather than a free variable,
    # so the model is fully described by the formula and data frame.
    fit <- glm(V3 ~ V1 + V2, family = binomial, data = data)
    # coef() instead of fit$coef, which relied on $ partial matching.
    theta <- coef(fit)
    plotDecisionBoundary(theta, X, y)
    print(summary(fit))
    fit
}


0 件のコメント:

コメントを投稿