準備

必要なパッケージを読み込む。

pacman::p_load(tidyverse,
               titanic,
               rsample,
               gmodels,
               e1071,
               tm,
               SnowballC)


単純ベイズ分類

データの準備

タイタニック号のデータを使って実習を行う。データは、次の方法で読み込める。

data(titanic_train)

データの中身を確認する。

glimpse(titanic_train)
## Rows: 891
## Columns: 12
## $ PassengerId <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,…
## $ Survived    <int> 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1…
## $ Pclass      <int> 3, 1, 3, 1, 3, 3, 1, 3, 3, 2, 3, 1, 3, 3, 3, 2, 3, 2, 3, 3…
## $ Name        <chr> "Braund, Mr. Owen Harris", "Cumings, Mrs. John Bradley (Fl…
## $ Sex         <chr> "male", "female", "female", "female", "male", "male", "mal…
## $ Age         <dbl> 22, 38, 26, 35, 35, NA, 54, 2, 27, 14, 4, 58, 20, 39, 14, …
## $ SibSp       <int> 1, 1, 0, 1, 0, 0, 0, 3, 0, 1, 1, 0, 0, 1, 0, 0, 4, 0, 1, 0…
## $ Parch       <int> 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 0, 0, 5, 0, 0, 1, 0, 0, 0…
## $ Ticket      <chr> "A/5 21171", "PC 17599", "STON/O2. 3101282", "113803", "37…
## $ Fare        <dbl> 7.2500, 71.2833, 7.9250, 53.1000, 8.0500, 8.4583, 51.8625,…
## $ Cabin       <chr> "", "C85", "", "C123", "", "", "E46", "", "", "", "G6", "C…
## $ Embarked    <chr> "S", "C", "S", "S", "S", "Q", "S", "S", "S", "C", "S", "S"…

これらの変数のうち、Survived が乗客が生存したことを示すダミー変数である。 この変数を応答変数として、生死を分類する確率的学習を実行しよう。

データフレームの名前 (titanic_train) からわかるとおり、このデータは訓練用であり、検証用データは別に用意されている。 しかし、ここでは後で分類性能を確かめたいので、このデータをさらに訓練用と検証用に分割しよう。PassengerIdNameTicketは分類に使わないので除外する。

set.seed(2021-10-18)
titanic_split <- titanic_train %>% 
  dplyr::select(!c(PassengerId, Name, Ticket)) %>% 
  initial_split(prop = 0.8)
T_train <- training(titanic_split) # 訓練(学習)用
T_test <- testing(titanic_split)   # 検証(テスト)用

1つの説明変数によって分類する

まず、1つの変数で分類してみよう。説明変数として性別を使う。 生死と性別のクロス表を作る。

with(T_train, table(Survived, Sex)) %>% 
  addmargins()
##         Sex
## Survived female male Sum
##      0       65  383 448
##      1      177   87 264
##      Sum    242  470 712

乗客が女性だったときに生存する確率は、 \[ p(Y = 1 \mid \mathrm{female}) = \frac{p(\mathrm{female} \mid Y = 1) p(Y = 1)}{p(\mathrm{female})} \] なので、

((177 / 264) * (264 / 712)) / (242 / 712)
## [1] 0.731405

である。よって、女性の乗客は生存する (\(\hat{y} = 1\)) と分類される。

同様に、乗客が男性だったときに生存する確率は、 \[ p(Y = 1 \mid \mathrm{male}) = \frac{p(\mathrm{male} \mid Y = 1) p(Y = 1)}{p(\mathrm{male})} \] なので、

((87 / 264) * (264 / 712)) / (470 / 712)
## [1] 0.1851064

である。よって、男性の乗客は死亡する (\(\hat{y} = 0\)) と分類される。

検証用データを使って、分類性能を確かめてみよう。 まずは、性別から生死を予測する

y_pred <- ifelse(T_test$Sex == "female", 1, 0)

実際の生死と比較する。

table(T_test$Survived, y_pred) %>% 
  addmargins()
##      y_pred
##         0   1 Sum
##   0    85  16 101
##   1    22  56  78
##   Sum 107  72 179

gmodels::CrossTable()で確認する。

CrossTable(T_test$Survived, y_pred,
           prop.chisq = FALSE,  prop.c = FALSE,  prop.r =  FALSE,
           dnn = c("事実", "単純ベイズ分類"))
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  179 
## 
##  
##              | 単純ベイズ分類 
##         事実 |         0 |         1 | Row Total | 
## -------------|-----------|-----------|-----------|
##            0 |        85 |        16 |       101 | 
##              |     0.475 |     0.089 |           | 
## -------------|-----------|-----------|-----------|
##            1 |        22 |        56 |        78 | 
##              |     0.123 |     0.313 |           | 
## -------------|-----------|-----------|-----------|
## Column Total |       107 |        72 |       179 | 
## -------------|-----------|-----------|-----------|
## 
## 

分類の的中率は

mean(T_test$Survived == y_pred)
## [1] 0.7877095

である。

全員死亡すると分類した場合の的中率は

mean(T_test$Survived == 0)
## [1] 0.5642458

なので、1変数でも分類の精度が上がることがわかる。

パッケージを利用して分類する

e1071パッケージの naiveBayes() を使って分類する。

nb <- naiveBayes(Survived ~ Sex, 
                 data = T_train)

この結果を使って分類を行う。

nb_pred <- predict(nb, newdata = T_test)

上でパッケージを使わずに求めた結果と比較する。

all(y_pred == nb_pred)
## [1] TRUE

すべての分類結果が一致しており、先ほどと同じ結果であることがわかる。

複数の説明変数を利用した分類

説明変数を2つ以上利用したベイズ分類も簡単に行うことができる。. で応答変数以外にデータフレームに含まれるすべての変数を説明変数として使うことを指定する。

nb2 <- naiveBayes(Survived ~ ., data = T_train)

この結果を使って検証用データの分類を行う。

nb2_pred <- predict(nb2, newdata = T_test)

実際の生死と比較する。

CrossTable(T_test$Survived, nb2_pred,
           prop.chisq = FALSE,  prop.c = FALSE,  prop.r =  FALSE,
           dnn = c("事実", "単純ベイズ分類"))
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  179 
## 
##  
##              | 単純ベイズ分類 
##         事実 |         0 |         1 | Row Total | 
## -------------|-----------|-----------|-----------|
##            0 |        88 |        13 |       101 | 
##              |     0.492 |     0.073 |           | 
## -------------|-----------|-----------|-----------|
##            1 |        37 |        41 |        78 | 
##              |     0.207 |     0.229 |           | 
## -------------|-----------|-----------|-----------|
## Column Total |       125 |        54 |       179 | 
## -------------|-----------|-----------|-----------|
## 
## 

分類の的中率は

mean(T_test$Survived == nb2_pred)
## [1] 0.7206704

である。性別だけを使った場合よりも分類性能が悪くなっている。

これはなぜだと思う?

スパムメールの判定

データの準備

UCI Machine Learning Repository にある SMS Spam Collection Data Setを使ってスパム判定を実行してみよう。

まず、データ (ZIP ファイル) を入手する(download.file() がうまくいかないときは手動でダウンロードする)。

#dir.create("data")
download.file(
  url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip",
  dest = "data/smsspamcollection.zip"
)

ZIP ファイルを展開する。

unzip(zipfile = "data/smsspamcollection.zip",
      exdir = "data")

dataフォルダの中に、SMSSpamCollection というファイルができているはずだ。次のコマンドで確認できる(結果は省略)。

dir("data")

このファイルは TSV(tab seprated value; タブ区切り)形式なので、read_tsv() で読み込む。列名(変数名)は付いていないので注意。

myd <- read_tsv("data/SMSSpamCollection",
                col_names = FALSE)

データの中身を見てみよう。

glimpse(myd)
## Rows: 4,837
## Columns: 2
## $ X1 <chr> "ham", "ham", "spam", "ham", "ham", "spam", "ham", "ham", "spam", "…
## $ X2 <chr> "Go until jurong point, crazy.. Available only in bugis n great wor…

第1列がラベル(spam と hamの区別)で、第2列にSMSの本文が含まれている。 列名を付けておこう。

SMS <- myd %>% 
  rename(type = X1,
         text = X2)

データの前処理

ラベルは文字列で表されているが、この後の分析で扱いやすくするためにfactor 型に変換しておこう。

SMS <- SMS %>% 
  mutate(type = factor(type, levels = c("ham", "spam")))

次に、SMSの本文を分析できるように処理する。 文書データを分析する1つの方法は、コーパス (corpus) と呼ばれる文書の集合を利用することである。SMSのコーパスを作ってみよう。tmパッケージのVectorSource()VCorpus() を利用してコーパスを作る。

SMS_corpus <- SMS %>% 
  pull(text) %>% 
  VectorSource() %>% 
  VCorpus()

できたコーパスを確認する。

SMS_corpus
## <<VCorpus>>
## Metadata:  corpus specific: 0, document level (indexed): 0
## Content:  documents: 4837

4837個のドキュメント(SMS)が含まれていることがわかる。

中身を確認するには、inspect()を使う。最初の3つだけ見てみよう。

inspect(SMS_corpus[1:3])
## <<VCorpus>>
## Metadata:  corpus specific: 0, document level (indexed): 0
## Content:  documents: 3
## 
## [[1]]
## <<PlainTextDocument>>
## Metadata:  7
## Content:  chars: 111
## 
## [[2]]
## <<PlainTextDocument>>
## Metadata:  7
## Content:  chars: 29
## 
## [[3]]
## <<PlainTextDocument>>
## Metadata:  7
## Content:  chars: 155

1つ目ののSMSのには111文字、2つ目には29文字、3つ目には155文字のメッセージが含まれていることがわかる。

メッセージ自体を表示したいときは、特定のメッセージに対して as.character()を使う必要がある。例えば、10個目のメッセージの中身は、

SMS_corpus[[10]] %>% as.character()
## [1] "Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030"

である。カッコ[[]]は二重にする必要があることに注意されたい。

このコーパスを、分析のためにクリーニングする。具体的には、以下の処理を施す。

  • 文書を単語に分割する (tokenize)
  • すべての文字を小文字にする
  • 数字を取り除く
  • ストップワードを取り除く
  • 句読点を取り除く
  • ステミングを行う

ここで、ストップワード (stop word) とは、出現頻度が高く、学習にはあまり役立たない単語のことで、 たとえば、“and”, “but”, “or” などである。 また、ステミングとは、簡単にいうと単語の意味が同じで形が異なるものをまとめる作業である。たえば、“win”, “winning”, “wins” などがすべて “win” としてカウントされるようにする。

これらのの処理は、tmパッケージのDocumentTermMatrix() で行うことができる。 この関数を使うと、文書-単語行列 (document-term matrix) ができる。

SMS_dtm <- DocumentTermMatrix(
  SMS_corpus,
  control = list(
    tolower = TRUE,
    removeNumbers = TRUE,
    removePunctuation = TRUE,
    stemming = TRUE
  )
)

できた行列の中身を確認しよう

SMS_dtm
## <<DocumentTermMatrix (documents: 4837, terms: 7042)>>
## Non-/sparse entries: 53076/34009078
## Sparsity           : 100%
## Maximal term length: 40
## Weighting          : term frequency (tf)

文書 (document) の数が4,837、単語 (term) の数が7,042 であることがわかる。 この行列の次元を確認すると、

dim(SMS_dtm)
## [1] 4837 7042

であることがわかる。つまり、各文書が行に、各単語が列に配置されていることがわかる。 これが、文書-単語行列と呼ばれる理由である。

この行列を、訓練データと検証データに分割しよう。

set.seed(2021-10-18)
N <- nrow(SMS_dtm)
train_num <- sample(1:N, ceiling(0.75 * N), replace = FALSE)
SMS_train <- SMS_dtm[train_num, ]
SMS_test <- SMS_dtm[-train_num, ]

ラベルも訓練用と検証用に分ける。

SMS_train_label <- SMS$type[train_num]
SMS_test_label <- SMS$type[-train_num]

訓練用データには、7042個の特徴量がある。 各特徴量は単語 (term) である。 数少ないメッセージにしか登場しない単語は分類にあまり役立たないと思われるので、5つ以上のメッセージに登場する単語だけを抜き出そう。

freq_words <- findFreqTerms(SMS_train, 5)

5つ以上のメッセージに登場する単語の数は、

length(freq_words)
## [1] 1273

である。

これらの単語の列だけを、データから抜き出そう。

SMS_train_freq <- SMS_train[, freq_words]
SMS_test_freq <- SMS_test[, freq_words]

文書-単語行列の各要素は、特定の文書に特定の単語が何回登場するかを示している。単純ベイズ分類では、文書中に単語が登場するか否かという情報を扱いたいので、登場回数が0回の場合は “N”、1回以上の場合は “Y” になるように値を置き換える。

S_train <- apply(SMS_train_freq, 2, 
                 function(x) ifelse(x == 0, "N", "Y"))
S_test <- apply(SMS_test_freq, 2, 
                 function(x) ifelse(x == 0, "N", "Y"))

訓練

単純ベイズ分類器で学習する。

spam_nb <- naiveBayes(x = S_train,
                      y = SMS_train_label)

分類

検証データの文書を分類する。

spam_pred <- predict(spam_nb, newdata = S_test)

性能評価

実際のラベルと分類結果を比較する。

CrossTable(SMS_test_label, spam_pred,
           prop.chisq = FALSE,  prop.c = FALSE,  prop.r =  FALSE,
           dnn = c("事実", "単純ベイズ分類"))
## 
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |         N / Table Total |
## |-------------------------|
## 
##  
## Total Observations in Table:  1209 
## 
##  
##              | 単純ベイズ分類 
##         事実 |       ham |      spam | Row Total | 
## -------------|-----------|-----------|-----------|
##          ham |      1048 |         5 |      1053 | 
##              |     0.867 |     0.004 |           | 
## -------------|-----------|-----------|-----------|
##         spam |        20 |       136 |       156 | 
##              |     0.017 |     0.112 |           | 
## -------------|-----------|-----------|-----------|
## Column Total |      1068 |       141 |      1209 | 
## -------------|-----------|-----------|-----------|
## 
## 

予測の的中率は、

mean(SMS_test_label == spam_pred)
## [1] 0.9793218

であり、かなり正確に分類できている。

ただし、実際は ham なのに spam に分類されてるメールが5つある。大事なメールがスパムに分類されることもあるので、たまにはスパムフォルダも確認する必要があるかもしれない。



授業の内容に戻る