distance dependent Chinese Restaurant Process

お久しぶりです。nokunoさんにも紹介されてしまったので頑張って月1ぐらいは更新したいと思ってます…。今回は面白かった論文の紹介です。去年のICMLのBleiの論文で、相変わらずCRPとかです。ICML版はこちらで、longer versionもあり、こちらからダウンロード出来ます。ICML版でほとんどの部分は説明されてて、理論的に詳しいところが知りたい人はlonger versionも補足的に読むといいかもしれません。
以前のエントリーで、DPMを説明するときに、CRPを介して説明出来るということを書きました。これはつまり、データをクラスタリングする場合、データの事前分布にCRPを仮定し、CRPの事後分布(レストランの状態の分布)がどうなるかを考え、同じテーブルに座ったデータ=客を同じクラスタとすることでクラスタリング出来ることを意味しています。この場合、背後にあるDPという構造を考えずに、単にCRPを道具としてクラスタ数を制限しないクラスタリングを考えています。distance dependent CRP(以下dd-CRP)はこの考えを推し進め、もはやDPとの等価性は存在しないが、より柔軟なクラスタリングを可能にしようというモデルです。
CRPでは、一人の客=データは、どこかのテーブル=クラスタに対応付けられました。dd-CRPでは、一人の客は、もう一人の別の客に対応付けられます。つまり、CRPではテーブル1に客{1,2,3}が座っているという状況は、dd-CRPでは例えば、テーブル1 <- 客1 <- 客2 <- 客3、というリンク構造で表されます。ここで矢印で繋がっている客が1つのクラスタを形成します。なんのためにこんなことをするかというと、例えば系列データをクラスタリングする場合、距離が近いデータは同じクラスタに含まれやすいだろうという直感を入れることが出来ます。例えばニュース記事などが時系列に並んでいてクラスタリングする場合、同じ時期の記事は同じクラスタになりやすい、という予想を組み込むことが出来ます。これはdd-CRPの生成過程において、新しい客が行く先の客の分布に対して、近くの客ほど行きやすい、というような分布を仮定したり、行く先の客との距離に閾値を設けたりすることで達成されます(論文ではdecay function)。ここで面白いのが、この閾値を定めず、今より前の客のところであればどこでも同じ確率で行く、とすると、これは従来のCRPと同じものになります。ただdd-CRP自体は、見て分かるように、客=データの位置関係によって同時確率が定まるため、これは交換可能な列とはなりません。つまり背後にDPは仮定出来ないということで、ノンパラベイズの枠組みとも異なる、と書かれています。
この事後分布はGibbsで推定出来ます。データを抜き取り、客の再配置を行うわけですが、このときその客には、他の客からのリンクが貼られている可能性があります。つまり5 <- {3,6,8 <- {2}}という具合。この場合{2,3,5,6,8}は同じクラスタを形成し、5に3,6,8がくっつき、8が2にくっついています。ここで5の再配置を行う場合、それが従えるデータの集合{2,3,5,6,8}を全て一緒に再配置する、ということになります。これが従来のCRPだと、5の再配置を行う場合、5が座るテーブルを変えるだけで、他の{2,3,6,8}については変えません。そのため、dd-CRPを従来のCRPの置き換えとして使ってこの推定を行うと、より速く混合が進み、マルコフ連鎖の収束が早まるらしいです。他の実験では(これがメインですが)、一応新聞記事や時系列に並んだNIPSデータセットで、このモデルが従来のCRPよりも低いパープレキシティを示すと書いてありました。
個人的にはこの収束が早まるというのがこのモデルの面白いところかなと思いました。現状dd-CRPは単なるCRP、DPMのモデルの置き換えとして使えるようですが、現在主に使われるのは階層モデル(HDP)だと思うので、HDPについてもdd-CRPのような等価な表現が出てくれば、色々嬉しいことがありそうです。特に隠れ変数が強い相関を持つモデルの場合従来のCRP、Gibbsでは収束が遅いことがボトルネックなので、この辺が改善されれば面白そうです。

ちなみに、Blei本人がこれのプログラム(R)を公開しています。
http://www.cs.princeton.edu/~blei/downloads/ddcrp.tgz
ちょっとだけ触ってみましたが、従来のCRPとの収束の違いなどはシミュレート出来ないようです。(いくつかパッケージが入っていることが前提です)また、僕の場合safelogという関数が定義されていない、と怒られたのですが、よく分からないのでここにあるsafeLogという関数を取ってきて名前をsafelogに変えました(いいのか?)。

> library(plyr)
> library(lda)
> library(Matrix)

# coraはldaに入っている論文のabstractを集めたデータセット
> data(cora.document) # 各論文毎にBOW形式で入っている
> data(cora.vocab)    # スカラと単語との対応データ

> source('data-modeling.R')
> source('ddcrp-inference.R')

> dat <- corpus.to.matrix(cora.documents[1:100],cora.vocab) # 文章データをBOWの疎行列に変換
> res <- ddcrp.gibbs(dat=dat[1:100,], dist.fn=seq.dist, alpha=10, # run gibbs
                     decay.fn=exp.decay(5),
                     doc.lhood.fn(0.5), 10, summary.fn = ncomp.summary)

> res$summary # クラスタ数の遷移を表示
        iter ncomp
summary    0   100
           1    10
           2     8
           3     9
           4    11
           5     9
           6     7
           7     7
           8     5
           9     7
          10     6
          11     6
          12     6
          13     7
          14     6
          15     6
          16     6
          17     7
          18     8
          19     9
          20     6
          
> res$map.state # 最後の状態を表示
    idx cluster customer
1     1       1        1
2     2       1        1
3     3       3        3
4     4       1        1
5     5       1        2
6     6       1        1
7     7       1        4
8     8       1        2
9     9       1        7
10   10       1        9
11   11       1        7
12   12       1       10
13   13       1        6
14   14       1       11
15   15       1       10
...
47   47       1       42
48   48      16       45
49   49       1       47
50   50      50       50
51   51       1       47
52   52       1       51
53   53       1       51
54   54       1       53
55   55       1       52
56   56      16       46
57   57       1       54
58   58       1       57
59   59      16       56
60   60       1       49
61   61      16       59
62   62      16       61
63   63       1       54
64   64      16       61
65   65       1       63
66   66       1       60
67   67      16       61
68   68       1       66
69   69       1       63
70   70       1       58
...

まあ、これだけじゃよく分からないですが…。最後の出力は、最後のレストランの状態を示していて、左からデータ=客のインデックス、その客が属するクラスタ番号、対応する(くっついている)客のインデックス、を示しています。idxとcustomerを見ると、近くの客のところに行きやすいようになっていることが分かります。
ちなみにcoraのデータセットを使ってるのは、使えるのがこれしか見当たらなかったからで、これは時系列データではないのでその点で微妙です。推定のddcrp.gibbs()では、いくつか引数を指定します。decay.fnは、上で述べたdecay functionで、ここではexp.decayという指数的に遠くの客に行きにくくなるものを使っています。decay.fn=window.decay(dim(docs)[1])とすると、前に存在する全ての客に等しく遷移するという、従来のCRPと等価なdd-CRPのサンプルが得られます。
Rは久々に触りましたが、パッケージが色々なソフトに分散している状況は何とかならないのですかね。まあ現状matlabもRもnumpyもほとんど触っていない(使えない)のですが。自分で書くのはどれか一つでいいにしろ、どの言語もある程度使えるようになっておく必要性というのは感じます。