準備

必要なパッケージを読み込む。

pacman::p_load(tidyverse,
               rsample,
               class,
               gmodels)


距離

ミンコフスキー距離を測る関数を作る。(dist() という関数もあるが、計算内容を理解するためにあえて作る。)

minkowski <- function(x, y = NULL, p = 2) {
  # y が与えられないときは、原点とxの距離を測る
  if (is.null(y)) y <- rep(0, length(x))
  (sum(abs(x - y)^p))^(1/p)
}

2次元ベクトルをランダムにいくつか生成し、原点からのミンコフスキー距離(ノルム)を測ってみよう。 1回目。

beta <- rnorm(2, mean = 0, sd = 3)
beta
## [1] -3.926167 -2.443105
minkowski(beta, p = 2) # L2ノルム(ユークリッド距離)
## [1] 4.624235
minkowski(beta, p = 1) # L1ノルム(マンハッタン距離)
## [1] 6.369272

2回目。

beta <- rnorm(2, mean = 0, sd = 3)
beta
## [1]  0.04625993 -3.20774908
minkowski(beta, p = 2) # L2ノルム(ユークリッド距離)
## [1] 3.208083
minkowski(beta, p = 1) # L1ノルム(マンハッタン距離)
## [1] 3.254009

3回目。

beta <- rnorm(2, mean = 0, sd = 3)
beta
## [1] -5.593775 -1.169007
minkowski(beta, p = 2) # L2ノルム(ユークリッド距離)
## [1] 5.714621
minkowski(beta, p = 1) # L1ノルム(マンハッタン距離)
## [1] 6.762782

このように、距離を測ることができる。(当たり前だが)\(L_2\)ノルムよりも\(L_1\)ノルムのほうが大きい。

高次元のベクトルについても、原点からの距離を測ることができる。 20次元の場合。

beta <- rnorm(20, mean = 0, sd = 3)
beta
##  [1]  0.9146441 -4.5505172  0.9599669  3.9842991 -3.2336485  4.3306049
##  [7] -1.8077692  4.3012802  6.0072563 -1.0088395  1.7218507  2.0538056
## [13] -4.4672737 -9.5612610 -4.4332970 -2.2749843 -1.9943872  4.0057640
## [19]  4.6617256 -1.0632469
minkowski(beta, p = 2)
## [1] 17.68325
minkowski(beta, p = 1)
## [1] 67.33642

同じ次元の2つのベクトル間の距離も測れる。 20次元の場合。

x <- rnorm(20, mean = 0, sd = 3)
y <- rnorm(20, mean = 0, sd = 3)
cbind(x, y)
##                  x           y
##  [1,] -0.848191423  6.58940462
##  [2,] -2.169053687  1.14978266
##  [3,] -3.635219570 -1.44077379
##  [4,]  3.154534496 -1.92476911
##  [5,] -3.051425208  1.80255388
##  [6,] -1.636872241  1.19653233
##  [7,] -1.282081829  0.45640632
##  [8,]  0.509077708  2.02177432
##  [9,]  1.388155530  5.85392395
## [10,] -1.307016705 -6.00296898
## [11,] -3.737655930  0.06502178
## [12,]  3.420520354  5.73235875
## [13,]  0.512057509 -3.13812750
## [14,] -0.637968117 -0.69257612
## [15,] -2.278908101 -2.82291105
## [16,]  3.185022713  0.98435345
## [17,] -1.253209201 -2.26845500
## [18,] -0.279059957  6.38361902
## [19,]  0.008734508  0.61503575
## [20,] -0.716383777 -1.19281993
minkowski(x, y, p = 2) 
## [1] 16.12854
minkowski(x, y, p = 1)
## [1] 59.45511

\(k\)-NN法による分類

データの準備

UCI Machine Learning Repository にある Congressional Voting Records Data Set(アメリカ合衆国議会における議員の投票データ)を例に、\(k\)-NN法で分類を行ってみよう。

まず、データ を入手する(download.file() がうまくいかないときは手動でダウンロードする)。

dir.create("data")
download.file(
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/voting-records/house-votes-84.data",
  dest = "data/house-votes-84.csv"
)

データを読み込む。変数名が付いていないので注意。

Vote <- read_csv("data/house-votes-84.csv",
                 col_names = FALSE)

データの中身を確認する。

glimpse(Vote)
## Rows: 435
## Columns: 17
## $ X1  <chr> "republican", "republican", "democrat", "democrat", "democrat", "d…
## $ X2  <chr> "n", "n", "?", "n", "y", "n", "n", "n", "n", "y", "n", "n", "n", "…
## $ X3  <chr> "y", "y", "y", "y", "y", "y", "y", "y", "y", "y", "y", "y", "y", "…
## $ X4  <chr> "n", "n", "y", "y", "y", "y", "n", "n", "n", "y", "n", "n", "y", "…
## $ X5  <chr> "y", "y", "?", "n", "n", "n", "y", "y", "y", "n", "y", "y", "n", "…
## $ X6  <chr> "y", "y", "y", "?", "y", "y", "y", "y", "y", "n", "y", "y", "n", "…
## $ X7  <chr> "y", "y", "y", "y", "y", "y", "y", "y", "y", "n", "n", "y", "n", "…
## $ X8  <chr> "n", "n", "n", "n", "n", "n", "n", "n", "n", "y", "n", "n", "y", "…
## $ X9  <chr> "n", "n", "n", "n", "n", "n", "n", "n", "n", "y", "n", "n", "y", "…
## $ X10 <chr> "n", "n", "n", "n", "n", "n", "n", "n", "n", "y", "n", "n", "y", "…
## $ X11 <chr> "y", "n", "n", "n", "n", "n", "n", "n", "n", "n", "n", "n", "n", "…
## $ X12 <chr> "?", "n", "y", "y", "y", "n", "n", "n", "n", "n", "?", "y", "n", "…
## $ X13 <chr> "y", "y", "n", "n", "?", "n", "n", "n", "y", "n", "?", "?", "n", "…
## $ X14 <chr> "y", "y", "y", "y", "y", "y", "?", "y", "y", "n", "y", "y", "y", "…
## $ X15 <chr> "y", "y", "y", "n", "y", "y", "y", "y", "y", "n", "y", "y", "n", "…
## $ X16 <chr> "n", "n", "n", "n", "y", "y", "y", "?", "n", "?", "n", "?", "?", "…
## $ X17 <chr> "y", "?", "n", "y", "y", "y", "y", "y", "y", "?", "n", "?", "?", "…

このデータの各行は議員である。 第1列(X1)がラベルで、その議員の所属政党を表す。

table(Vote$X1)
## 
##   democrat republican 
##        267        168

各議員は、“democrat” (民主党員)か “republican”(共和党員)かのいずれかである。

第2列 (X2) から 第17列 (X17) までは、ある法案に賛成した (y) か反対した (n) かを示している。

table(Vote$X2)
## 
##   ?   n   y 
##  12 236 187

賛成と反対だけでなく、? が含まれていることがわかる。

必ずしも必要ではないが、ラベルを\(y\)、各説明変数を\(x1\)から\(x16\)として変数名を付けなおそう。

names(Vote) <- c("y", paste0("x", 1:16))

次に、法案に対する賛否を、数値に置き換えよう。

myd <- Vote %>% 
  mutate(across(x1:x16, function(v) case_when(
    v == "y" ~  1,
    v == "n" ~ -1,
    TRUE     ~  0
  )))

このデータを、訓練用と検証用に分割しよう。

set.seed(2021-10-17)
Vote_split <- initial_split(myd, prop = 0.8)
Vote_train <- training(Vote_split) # 訓練(学習)用
Vote_test <- testing(Vote_split)   # 検証(テスト)用

さらに、ラベル(応答変数)と説明変数に分ける。

y_train <- Vote_train[, 1] %>% 
  as.matrix()
X_train <- Vote_train[, -1] %>% 
  as.matrix()
y_test <- Vote_test[, 1] %>% 
  as.matrix()
X_test <- Vote_test[, -1] %>% 
  as.matrix()

学習(訓練)

\(k\)-NN法は遅延学習(怠惰学習)を行うので、特に訓練は必要ない。 利用可能な状態でデータを保存(暗記)すればそれが学習なので、上で訓練済みである。

予測

検証用データの各点について、以下を実行する。

  • すべての点との距離を測る
  • 最近傍の\(k\)点を見つける
  • 最近傍の\(k\)点で、所属政党の多数決をとる
  • 多数派の政党に所属していると予測する
  • 予測が当たっているかどうか確認する

\(k=5\)として、ユークリッド距離を用いて上の手続きを実行してみよう。

まず、検証用データの各点について、訓練用データのすべての点との距離を測る。 検証用データの観測点の数は、

(J <- nrow(X_test))
## [1] 87

である。訓練用データの観測点の数は、

(N <- nrow(X_train))
## [1] 348

である。

距離を記録するための行列を用意する。

d_euclid <- matrix(NA, nrow = N, ncol = J)

距離を測る。

for (j in 1:J) {
  for (n in 1:N) {
    d_euclid[n, j] <- minkowski(X_test[j, ], X_train[n, ], p = 2)
  }
}

検証用データの各点について、最近傍の5つを見つける。

nearest5 <- apply(d_euclid, 2, order)[1:5, ]

最近傍にある5点の所属政党を調べる。

party <- matrix(NA, nrow = 5, ncol = J)
for (j in 1:J) {
  party[, j] <- y_train[nearest5[, j], 1]
}

多数決をとる。

y_pred <- apply(party, 2, function(x) table(x) %>% which.max() %>% names())

予測が当たっているかどうか調べる。

mean(y_pred == y_test)
## [1] 0.8850575

的中率は約88.5%である。

全員民主党員であると予測した場合の的中率は

mean(y_test == "democrat")
## [1] 0.5747126

であり、全員共和党員だと予測した場合の的中率は

mean(y_test == "republican")
## [1] 0.4252874

である。\(k\)-NN法の予測はそこそこ当たっていると言えるだろう。

パッケージを使う

実際に分類を行う際には、class パッケージのknn() を使うことができる。 上と同じ分類を、class::knn() を用いて実行してみよう。

out <- knn(train = X_train,
           test = X_test,
           cl = as.factor(y_train),
           k = 5)

結果は、gmodels パッケージの CrossTable() で確認する。

CrossTable(x = as.factor(y_test), 
           y = out,
           prop.chisq = FALSE)
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  87 
## 
##  
##                   | out 
## as.factor(y_test) |   democrat | republican |  Row Total | 
## ------------------|------------|------------|------------|
##          democrat |         44 |          6 |         50 | 
##                   |      0.880 |      0.120 |      0.575 | 
##                   |      0.917 |      0.154 |            | 
##                   |      0.506 |      0.069 |            | 
## ------------------|------------|------------|------------|
##        republican |          4 |         33 |         37 | 
##                   |      0.108 |      0.892 |      0.425 | 
##                   |      0.083 |      0.846 |            | 
##                   |      0.046 |      0.379 |            | 
## ------------------|------------|------------|------------|
##      Column Total |         48 |         39 |         87 | 
##                   |      0.552 |      0.448 |            | 
## ------------------|------------|------------|------------|
## 
## 

この表の、主対角線(左上から右下)上にあるのが正しく予測された観測点の数(44 と 33)である。 実際には民主党員 (democrat) である50人のうち44人が民主党員であると正しく予測され、6人が共和党員 (republican) であると誤って予測されている。 同様に、実際には共和党員である37人のうち33人が共和党員であると正しく予測され、4人が民主党員であると誤って予測されている。 予測の的中率は、

(44 + 33) / 87
## [1] 0.8850575

であり、パッケージを使わずに分類した結果に一致することが確認できる。

実習課題



授業の内容に戻る