ディープラーニングを使った因果推定 〜SAMのアルゴリズムを理解する〜

スポンサーリンク

 
 近年、機械学習のアルゴリズムは目覚ましい発展を遂げ、機械学習を使ったサービスが広まっています。そして、今後も機械学習は注目され、さらなる成長が起こるでしょう。しかし、それと共に機械学習のモデルは大きくなっており、予測結果がブラックボックスになりがちです。その結果ビジネスにおいては、良い予測をするという点では良いことですが、予測結果の要因がわからなくなっているため、類似の事象で行動ができないといったことが起きています。


 この問題が起きるのは、既存の機械学習では因果関係がわからないからです。その点、因果推定の文脈だと、反実仮想の考え方をベースにしたアプローチが主に研究されており、ここに機械学習のアルゴリズムが使われ始めています。

 
 機械学習を使った因果推定では、反実仮想の仮想の結果を機械学習で予測する方法がメインでしたが、ディープラーニングのGAN(日本名:敵対的生成ネットワーク)を使って因果グラフを作ることで因果推定を行うという面白い手法が見いだされました。それが今回ご紹介するStructural Agnostic Modeling(SAM)です。元論文はこちら


f:id:dskomei:20210318170108p:plain:w600




今回の記事を書く理由


 「相関だけでははだめで、因果がわからないとだめ」というのは、耳にタコができるぐらいよく聞きます。しかし、手元にあるデータから因果を推定しようと思い、傾向スコアベースの方法を試しても納得いかない結果になることが多いです。機械学習の予測が的外れでなんとも言えない感じがします。更に、扱うデータの介入変数が0 or 1ではない場合もあります。


 なので、観測データを使っての因果推定の新しいアルゴリズムを待ち焦がれており、ついに、GANを使った因果推定というセンセーショナルなやり方に出会いました。それはStructural Agnostic Modeling(SAM)と呼ばれるものです。SAMの実装方法に関しては既に書かれている記事がありますが、アルゴリズムに関して書かれたものはあまり見られません。GAN自体面白い発想であり、それを使って因果推定を行うのは秀逸です。論文を読んで全体観を勉強したので、備忘録ながらこの記事にまとめていきます。


SAMモデルの全体観


 GANはディープラーニングのモデルの中でも特殊な部類だと思うので、それぞれのニューラルネットワークが何をしたいかを抑えないと、なかなか理解ができません。まず、ネットワークの全体観を理解することが大事です。


 GANは、偽のデータを作るGeneratorと真のデータと偽のデータを見破るDiscriminatorが同時に精度が良くなるように学習していき、その結果Generatorが真のデータに類似したデータを作れるようになるという方法です。今回は、与えられた観測データを再現することがGeneratorの役割になります。


f:id:dskomei:20210318170108p:plain:w600


 全体観は上の図のようになります。Generatorは、実際のデータがインプットとして与えられ、それを元に実際のデータに類似するように出力します。それを受けて、Discriminatorは、真のデータか偽のデータかを判定します。ここでポイントとなるのが、「Generatorのニューラルネットワークへのインプットは実際のデータに因果行列を掛けたもの」であるということです。また、この因果行列はStructual Gatesと呼ばれています。更に、Genenratorのニューラルネットワークのアウトプットにも行列が掛けられています。これらの行列の要素は、最小値0、最大値1です。


 Generatorは、ニューラルネットワークを挟むような形で、インプットとアウトプットそれぞれに対して行列が掛けられています。この2つの行列の値も学習対象のパラメータになっており、それぞれの行列の合計値(\( \mathrm{L}_0 \) ノルム)が小さくなるようにパラメータを更新していきます。


 Structural Gatesは合計値が小さくなるようになっていき、Generatorのニューラルネットワークへのインプットとしては、実際のデータにそのStructural Gatesを掛けたものです。ただ、Generatorはその合計値が小さくなるようにされた行列(全体的に0になっている)がかけらたデータ使って実際のデータに似たものを生成しなければいけません。そうすると、Structural Gatesの合計値は小さくても、データを作り出す上で必要な行列の値は大きく(最大値は1)しないといけません。その結果、Structural Gatesの行列は必要な係数が1に近くなり、因果グラフとなります。この部分が今回のアルゴリズムを理解する上での勘所なので、次の章で更に詳しく見ていきます。


Generatorにおける因果グラフの役割


 今回の方法では、実際のデータに因果グラフを表す因果行列をかけて、それをGeneratorのニューラルネットワークの入力としています。なぜわざわざこのような処理をしているのかに関して見ていきます。


 入力データを \( \mathrm{X} \) とし、各変数 \( j \) のベクトルは \( \mathrm{X}_j \) とします。そうすると、因果グラフができた場合は、各変数は他の変数によって説明できるため、\( \mathrm{X}_j \) は \( \mathrm{X}_j \) 以外の変数で表すことができます。今回の方法では更に正規分布に従うノイズ \(E\) も含まれています。具体的には以下のグラフを御覧ください。


f:id:dskomei:20210318210345p:plain:w500


 上のグラフでは、 \( \mathrm{X}_1 \) は \( \mathrm{E}_1 \) だけが原因となっており、\( \mathrm{X}_2 \) は \( \mathrm{X}_1 \) と \( \mathrm{E}_2 \) が原因です。それを数式で表すならば何かしらの関数を \( f \) として、\( \mathrm{X}_2 = f( \mathrm{X}_1, \mathrm{E}_2 ) \) となります。そして、この \( f \) がGeneratorなわけです。


 では、これをどう実現するかですが、それがStructural Gatesです。実際のデータにStructural Gatesの行列を掛けることで、\( \mathrm{X}_j \) を \( \mathrm{X}_j \) 以外の変数で表せます。これをGeneratorのニューラルネットワークの入力として、偽のデータを生成します。ここで大事なのが、Structural Gatesの行列の総和が小さくなるようにGeneratorが学習されることです。Structural Gatesの要素が0が多くなる(Sparsityになる)ようにしつつも、Generatorは真のデータに似るようにしなければいけないことで、Strucrtural Gatesの重要な要素の値だけ大きくなります(最大1)。これにより因果行列が出来上がります。ただ因果グラフを作るためには、有向非巡回グラフ(DAG)でなければいけません。これを担保するために損失関数が工夫されています。次の章でこの損失関数を見ていきます。


損失関数


 ここまでの話で、SAMの構造やGANを使ってどのように因果行列を作るのかがわかってきたかと思います。SAMの構造が組み立てられれば、データを使ってパラメータを学習させることで、因果行列ができます。ただ今回の方法では、パラメータを学習させるための損失関数は工夫されており、一筋縄では理解できないので、じっくり見ていきます。


 GAN自体の損失関数は、ニューラルネットワークが2つあることから一般的なディープラーニングの損失関数とは異なり、Generator用とDiscriminator用の2つの損失関数があります。これを交互に計算して、バックプロパゲーションによりパラメータを更新していきます。交互にパラメータを更新することで、Generatorは最強の偽データを作り、Discriminatorは最強の仕分け人を目指します。ライバルがいることでお互いが強くなるアレです。


 交互にニューラルネットワークを学習する際に、Generatorの損失関数には、因果行列の合計値(\( \mathrm{L}_0 \)ノルム)も加えます。こうすることで、因果行列が疎行列になるようにしています。ただ、これだけでは因果行列は不十分であり、因果グラフになるためのDAG(有向非巡回グラフ)を満たしていません。そこで、Generatorの損失関数に途中からDAGを満たすようにするためのコストを加えます。それが以下の式です。なぜこの式かは
こちらを御覧ください。



\[ \sum_{k=1}^d \frac {\mathrm{tr} A^k} {k!} = 0 \]


 文字だけで説明すると分かりづらいかなと思うので、以下の図にまとめてみました。


f:id:dskomei:20210319162501p:plain:w600


 上の図を見てもらえばわかると思いますが、DiscriminatorとGeneratorが交互にパラメータの更新が行われる中で、Generatorは途中から損失関数にDAG化のコストが追加されます。


 以上でアルゴリズムの説明は終わりです。


終わりに


 今回はアルゴリズムの説明だけになりましたが、次回はしっかり実装したいと思います。ただ、実装してある記事を読むと、凝りに凝った作りになっている気がするので、簡易な形で再構築したいです。Pythonを使った因果推定に関しては以下の本が参考になります。




参考Web