読者です 読者をやめる 読者になる 読者になる

old school magic

機械学習と統計とプログラミングについてちょっとずつ勉強していきます。

PyStanでMCMC入門

Python 機械学習 統計

概要

PyStan は Stan というMCMC計算用言語の Python インターフェイスです。

Stan
http://mc-stan.org/

PyStan
http://pystan.readthedocs.org/en/latest/index.html

MCMCを計算できるソフトはいくつかあるのですが、Stan は

  • C++で実装されているため高速
  • 最近のサンプリング法を実装している

といった特徴があります。特に速度には目を見張るものがあります。

前回までは PyMC3 をいじっていたのですが、他のソフトにも触ってみようと思い、今回は PyStan でモデリングをしてみました。

PyStan のインストール

Anaconda を入れればもれなく一緒にインストールされます。
Anaconda
https://store.continuum.io/cshop/anaconda/

参考
MacでPythonの機械学習環境構築(2014年5月版) - old school magic


もしくは pip でインストールします。

pip install pystan

参考
http://pystan.readthedocs.org/en/latest/getting_started.html


(※追記 Anacondaだとデフォルトで入ってませんでした... pip でインストールするのが正解ですね)

データの生成

今回モデリングするのは、ガウス混合モデルというモデルです。
その名の通りいくつかのガウス分布を混合したモデルで、今回は3個の1次元ガウス分布を混合しています。
まずは真の分布からデータを生成します。

import pystan
import numpy as np
import matplotlib.pyplot as plt

mean1 = 10
mean2 = -10
mean3 = 0
num1 = 200
num2 = 300
num3 = 500

X = np.concatenate([
	np.random.normal(loc=mean1, scale=1, size=num1), 
	np.random.normal(loc=mean2, scale=1, size=num2),
	np.random.normal(loc=mean3, scale=1, size=num3)
	])

np.random.shuffle(X)

真の平均はそれぞれ{(10, -10, 0)}、真の混合比は {2:3:5} です。標準偏差は簡単のため全て1としています。
ヒストグラムにするとこんな感じのデータです。

f:id:breakbee:20140810032241p:plain

このデータを用いてMCMCサンプリングを行い、平均と混合比の事後分布を推定する、というのが今回の目的です。

Stan によるモデリング

PyStan は Stan言語のインターフェイスなので、Stan言語でモデリングを行い、Pythonから呼び出します。
今回のガウス混合モデルのモデリングはこんな感じです。これを gmm_mcmc.stan というファイル名で保存します。

data {
    int<lower=1> N;
    int<lower=1> k;
    real X[N];
}
parameters {
    simplex[k] theta;
    real mu[k];
}
model {
    real ps[k];
    for (i in 1:k){
        mu[i] ~ normal(0, 1.0e+2);
    }
    for(i in 1:N){
        for(j in 1:k){
            ps[j] <- log(theta[j]) + normal_log(X[i], mu[j], 1.0);
        }
        increment_log_prob(log_sum_exp(ps));
    }
}

Stan はC言語のような手続き型の言語です。Pythonとは異なりコードブロックを用いています。

Stan には、基本的に次の4つのブロックがあります。

  • data
  • parameters
  • transformed parameters
  • model

data ブロックは用いるデータを宣言しておくブロックです。
このブロックにPythonなどのインターフェイスからデータを渡すことになります。
今回は、データと、データのサイズと、混合数を渡しています。

parameters ブロックはモデルのパラメータを宣言するブロックです。
今回の例で言えば、平均 mu と混合比 theta です。
theta の型である simplex というのは、ディリクレ分布を事前分布に設定した時に使う特別な型です。(後述)

transformed parameters ブロックは、parameters ブロックで宣言したパラメータを用いて新しいパラメータを定義するブロックです。
今回は用いていませんが、例えば{a_1, a_2, a_3}というパラメータがあった時、{b = a_1 + a_2 * a_3} などと宣言します。

model ブロックは、モデルを記述するブロックです。
モデルの記述方法には、

  • Log Probability Increment
  • Sampling Statement

の2種類があります。

平均 mu のサンプリングには Sampling Statement を用いています。

    for (i in 1:k){
        mu[i] ~ normal(0, 1.0e+2);
    }

このように、~ の左側にパラメータ、右側に従う分布を描く方法です。
(モデルの)平均の事前分布には、平均0、標準偏差100の正規分布を用いています。

混合比 theta のサンプリングには Log Probability Increment を用いています。

    real ps[k];
    for(i in 1:N){
        for(j in 1:k){
            ps[j] <- log(theta[j]) + normal_log(X[i], mu[j], 1.0);
        }
        increment_log_prob(log_sum_exp(ps));
    }

対数をとった確率(Log Probability)の式を記述することによって計算を行う方法です。
基本的にどちらの記述を用いても良いのですが、前述した simplex 型、つまり混合比の事前分布であるディリクレ分布に従うパラメータは、制約上 Log Probability Increment でしか記述できないようです。

参考
マニュアル(pdf)
https://github.com/stan-dev/stan/releases/download/v2.4.0/stan-reference-2.4.0.pdf

  • 9.2(混合分布について)
  • 21.3(モデルの記述方法について)
  • 44(simplex 型について)

PyStan からの呼び出し

では、この Stan によるモデリングPython から呼び出してみましょう。

N = X.shape[0]
k = 3

stan_data = {'N': N, 'k': k, 'X': X}

fit = pystan.stan(file='gmm_mcmc.stan', data=stan_data, iter=10000, chains=1)

print('Sampling finished.')

# 可視化
fit.plot()
plt.show()

Python側からデータを渡す時、Stan の data ブロックで宣言した名前をキーにした辞書型にして渡します。
今回の例でいうところの stan_data です。データ数と混合数、データを辞書にして渡しています。

今回のサンプリングでは、サンプリング回数(iter)は10000回、サンプリング系列数(chains)は1としています。(通常は複数のサンプリングを走らせます。)
そのうち最初の5000回(ちょうど半分)は burn-in という処理の対象となります。簡単にいうと、サンプリングの最初の方は精度が悪いので捨てます。この捨てる数はもちろん調整可能です。(デフォルトで半分。)

サンプリングの結果はこんな感じです。

f:id:breakbee:20140810035826p:plain

混合比も平均も、おおよそ真値を中心にうまく分布しているのが見て取れますね。

サンプリング系列数(chains)を4にするとこんな感じです。
f:id:breakbee:20140810042940p:plain

今回使用したコードはこちらです。(Gist)
PyStan Code
PyStan Code for GMM
Stan Code
Gaussian Mixture Model for Stan

感想

PyMC と比べて、PyStan は高速だけど記述が面倒、といった印象を受けました。
PyMC3 は少しバグが多いので、マニュアルがしっかりしてて安定してる Stan は魅力的でした。
しかし、PyMCに比べて記述量が増えるので、Stan言語の記述に慣れるかどうかが1つのポイントかもしれません。
書いててC言語を思い出しました。

PyStan の日本語資料はあまりなかったのですが、 RStan は結構ありました。
両方 Stan のインターフェイスで、Stanコードは共通なので、割りと調べやすかったです。

参考資料

公式の入門手引きの解説もとてもわかり易いです。
PyStan Getting started
http://pystan.readthedocs.org/en/latest/getting_started.html

PyStan の貴重な日本語の活用事例です。
Pystanで自然言語処理 scikit.learnのdatasetで試す - xiangze's sparse blog

マニュアルも非常に細かく書かれています。
https://github.com/stan-dev/stan/releases/download/v2.4.0/stan-reference-2.4.0.pdf

日本語でマニュアルの解説をしている方の記事です。
Stanのマニュアルの8章~12章の私的メモ

Stan のチュートリアルを作成してくださった方のスライドです。
http://www.slideshare.net/teitonakagawa/stantutorialj