AIを脱Black Box! XAI(Explainable Artificial Intelligence)を勉強する〜Permutation Importance〜

広告

 予測モデルは精度が命ということで、内部を複雑化させることで予測精度の向上を図ってきました。内部を非線形関数で複雑にしまくっているディープラーニングのように。しかしながら、内部を複雑にすることで精度が上がった一方、予測モデルの中身がBlack Boxになってしまいました。良い予測をしてくれる魔法の道具は手に入りましたが、なんでその予測結果になったのかがわからなくなってしまいました。ドラえもんの四次元ポッケトから適切に道具が出てきても、ドラえもんがいないがために誰も使い方がわからにような状態です。寂しいですね。


 そこで、予測モデルの中身がBlack Boxにならないようにしようというのが、Explainable Artificial Intelligence(XAI)です。このときに大事なのは、予測モデルの複雑度をなるべく落とさないで、モデルの中身を説明できるようにすることです。単回帰モデルにすればモデルの説明は簡単ですが、それでは予測精度が落ちてしまいます。またディープラーニングを使えば予測精度は上がりやすいですが予測結果に対する説明は神頼みでしょう。


 今回の話は学習済みの予測モデルの中身を説明する話であり、以下の図の右下に位置します。

f:id:dskomei:20190829215458p:plain:w600


 どうやって複雑なモデルの中身を説明するかですが、出来上がった予測モデルをカエルの解剖のように分解して一つ一つ見ていくのではなく、出来上がった予測モデルの反応を見ながらどの変数が重要かを調べていきます。奥さんがかまをかけたら案の定動揺してしまった旦那さんのように、予測モデルの動揺具合が大事ということです。


 最終的には以下のグラフのようにモデルにとって大事な変数と必要のない変数と区別することができます。

f:id:dskomei:20190829221018p:plain:w600


 今回のコードはこちらのGit Hubに置いてあります。
github.com


Permutation Importanceとは

 それぞれの変数の重要度を測定する方法として、変数の中身を変化させたときの予測モデルの反応を見ながら行うものがあります。その中の一つであるPermutation Importanceを今回はやってみます。


 この方法はいたってシンプルです。精度の高い予測モデルは、データからパターンを勝手に見つけてくれるから精度が高いわけです。逆に意味のわからないデータを与えると、とたんにおかしくなってしまいます。まさしく好きなアニメの話以外は急にしどろもどろになってしまう自分のようです・・・


 さて本題のアルゴリズムの話ですが、以下の手順で計算をします。

  1. ある変数の値だけをランダムに入れ替える
  2. 1.のデータを使って予測したときの誤差から入れ替える前のデータを使って予測した誤差を引き、それを入れ替え前の予測誤差で割る


\(\displaystyle \frac{ランダムデータでの予測誤差 - 入れ替前データでの予測誤差}{入れ替前での予測誤差}\)


 3 1.2.をすべての変数で行う


 肝は2の処理です。この処理について補足すると、重要な変数は予測精度を上げるのに必要な他の変数とのパターンのキーになっているため、その変数の値が変わるとパターンがくずれ、予測精度が落ちます。つまり、変数の値を適当に変えたときに予測誤差が大きくなるものほど予測モデルにとっては欠かせないということです。サッカーにおいて、主力選手を徹底マークすることでチームが機能しなくなり、とにかくドカンと前に蹴り出しますよね。

f:id:dskomei:20190829223734p:plain:w600


 更に2.の式を見ると、負になる変数もできる場合があります。そうです、ランダムに値を入れ替えた結果、予測誤差が改善してしまったパターンです。それは、データが良くなったから予測誤差が減少したのではなく、その変数が予測モデルを作る上でノイズとなって予測精度を下げてしまっていたからです。ランダム化したのにノイズが薄まってしまうというかわいそうなノイズなのです。

Permutation Importanceのいいところ

 この方法の良いところは、出来上がった予測モデルを使って、データの入れ替えと予測モデルの反応だけで変数の重要度を求めている点です。つまり、どの予測モデルにも使えます。なので、予測モデルを作る際に変数の重要度も出して説明をすることが求められる場合でも、予測モデルのアルゴリズムとして何を選んでも良く、決定木やXgboostに縛られなくなります。まぁ、でもXGBoostは使いますけど。

実装

 Permutation Importanceを実装しますが、まずはデータを準備します。今回使用するデータはscikit-learnに標準で組み込まれている「breast_cancer」です。このデータは、乳がんかどうかを予測する分類用のデータであり、データ数は569件、カラム数は30個です。

from pathlib import Path
import pandas as pd
from dfply import *
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.datasets import load_breast_cancer

result_dir_path = Path(".").joinpath("result")
if not result_dir_path.exists():
    result_dir_path.mkdir(parents=True)

## データをロードし、データフレームにする
cancer_data = load_breast_cancer()

data_y = cancer_data.target
data_x = pd.DataFrame(
    cancer_data.data,
    columns=cancer_data.feature_names
)

train_x, test_x, train_y, test_y = train_test_split(
    data_x,
    data_y,
    test_size=0.3
)

 上記の処理でデータを読み込み、学習用データと訓練用データに分割しています。
 続いて、モデルを学習します。

model = XGBClassifier()
model.fit(train_x, train_y)
print("正答率 : {:.0f}%".format(model.score(test_x, test_y)*100))

正答率 : 96%


 学習済みの予測モデルができましたので、そのモデルを使ってPermutation Importanceを行ってみます。
 そのために、今回は『eli5』というパッケージを使います。こちらはpipで簡単にインストールできます。

from eli5.sklearn import PermutationImportance
perm = PermutationImportance(model, random_state=1).fit(test_x, test_y)

 結果を保存するためにデータフレームにし、CSVファイルとして保存します。

perm_weights = pd.DataFrame({
    "column" : data_x.columns.tolist(),
    "weight": perm.feature_importances_,
    "std" : perm.feature_importances_std_
    }
) >> mutate(weight=X.weight.astype(float)) >> arrange(X.weight, ascending=False)
perm_weights.to_csv(result_dir_path.joinpath("feature_permutations.csv"), index=False)


 この結果をグラフ化すると冒頭でも表示した棒グラフになります。今回の場合は、最も重要な変数は「worst concave points」になりました。また予測精度に対して悪い変数は「worst perimeter」です。

f:id:dskomei:20190829221018p:plain:w600


 上記の結果で負になる変数がいくつかありました。これらの変数はモデルの精度に悪影響を及ぼす変数なので、取り除いて再学習してみます。つまり、変数選択します。

model = XGBClassifier()

target_perm_werights = perm_weights >> filter_by(X.weight > 0)
model.fit(train_x >> select(target_perm_werights["column"].tolist()), train_y)
print("正答率 : {:.0f}%".format(model.score(test_x >> select(target_perm_werights["column"].tolist()), test_y)*100))

正答率 : 96%


 正答率は変数削減する前と変わりませんでしたが(本当はランダム性があるので複数回実行するなり、交差検証して判断ですが本筋ではないのでご了承を)、変数削減した分、簡素なモデルになっているので良しとします。

終わりに

 今回はどの予測モデルに対しても変数の重要度を算出できるPermutation Importanceを試してみました。変数の重要度がわかることで、予測対象に対して何が重要なのかが語れるようになります。
 しかし、これは訓練データ全体を使って算出したもので、一つ一つの予測結果に対してではありません。そこで、次回はひとつひとつ予測結果に対して各変数がどのように影響しているかがわかる方法を行ってみようと思います。