洋食の日記

「だ・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Rumale::SVMにLocally Linear SVMによる分類器を追加した

はじめに

Locally Linear Support Vector Machine(LL-SVM)は、多様体学習の考え方を利用して、線形SVM非線形分類器を実現する手法である。Rumale::SVMは、Ruby機械学習ライブラリであるRumaleと同様のインターフェースで、LIBLINEARやLIBSVMで実装されているSVMアルゴリズムによる分類器や回帰を利用できるものである。LIBLINEARやLIBSVMにアップデートがないと、Rumale::SVMもアップデートする機会がないのだが、SVM全般のgemと拡大解釈して、LL-SVMを実装してみることにした。

Ladicky, L., and Torr, P H.S., "Locally Linear Support Vector Machines," Proc. ICML'11, pp. 985--992, 2011.

インストールと使い方

LL-SVMの実行には、Numo::Linalg(もしくはNumo::TinyLinalg)とlbfgsbを必要とする。Rumale::SVMに含まれるその他のアルゴリズムは必要としないので、runtime-dependencyに含めていない。これらを一緒にインストールする。

gem install rumale-svm numo-tiny_linalg lbfgsb

例として、LIBSVM Dataにある手書き文字画像のデータセットPendigitsの分類を行う。

wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t

分類性能の評価のためrumale-evaluation_measureをインストールする。

gem install rumale-evaluation_measure

Pendigitsのデータを分類して、その分類の正確度を出力するスクリプトは以下の様になる。

require 'lbfgsb'
require 'numo/tiny_linalg'
Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)

require 'rumale/dataset'
require 'rumale/evaluation_measure'
require 'rumale/svm'

# Pendigitsデータを読み込む.
x_train, y_train = Rumale::Dataset.load_libsvm_file('pendigits')
x_test, y_test = Rumale::Dataset.load_libsvm_file('pendigits.t')

# LL-SVMによる分類器を生成する.
# パラメータである代表点数は128, 近傍数は8とした.
classifier = Rumale::SVM::LocallyLinearSVC.new(n_anchors: 128, n_neighbors: 8, random_seed: 42)

# 学習データで学習する.
classifier.fit(x_train, y_train)

# テストデータのラベルを推定する.
y_pred = classifier.predict(x_test)

# 推定結果の正確度を出力する.
acc = Rumale::EvaluationMeasure::Accuracy.new
puts format('Accuracy: %.1f%%', (100.0 * acc.score(y_test, y_pred)))

これを実行すると正確度は95.7%となった。

$ ruby example.rb
Accuracy: 95.7%

比較のために線形SVMやRBFカーネルによるカーネルSVMを試してみる。分類器の生成の箇所を以下で置き換えればよい。

# 線形SVMによる分類器を生成する.
classifier = Rumale::SVM::LinearSVC.new(random_seed: 42)
# カーネルSVMによる分類機を生成する.
# classifier = Rumale::SVM::SVC.new(kernel: 'rbf', gamma: 1e-3, random_seed: 42)

結果として、線形SVMの正確度は89.7%、カーネルSVMの正確度は95.4%となった。

$ ruby example.rb
Accuracy: 89.7%

パラメータを調整したりすると、また変わってくると思うが、LL-SVMでいい感じに分類できることがわかる。

アルゴリズムの説明

LL-SVMについてザックリ説明する。LL-SVMの発想としては、多様体学習のLocally Linear Embeddingと似ていて、局所的な線形SVMをなめらかにつなぎあわせることで、結果として非線形な決定境界を得ようというものである。N個のサンプル\mathbf{x}_{1},\ldots,\mathbf{x}_{N}があり、それぞれに二値のラベルy_1,\ldots,y_N\in\lbrace -1,1\rbraceが付与されているとする。LL-SVMのコスト関数は次のとおりである。

 \displaystyle
\mathrm{argmin}_{W,\mathbf{b}}\frac{\lambda}{2}||W||^2+\frac{1}{N} \sum_{k=1}^{N}\max(0, 1-y_k H_{W,\mathbf{b}}(\mathbf{x}_k))

ヒンジ損失によるSVMと似ているが、重みベクトルではなく重み行列Wとなっていて、重みベクトルが複数あることを示している。さらに肝心なのはHで、これは以下のように定義される。

 \displaystyle
H_{W,\mathbf{b}}(\mathbf{x}_k)=\gamma(\mathbf{x}_k)^{\top}W\mathbf{x}_k+\gamma(\mathbf{x}_k)^{\top}\mathbf{b}

\gamma(\mathbf{x})は、local codingと呼ばれるもので、\mathbf{x}を、その近傍に位置する代表点\mathbf{v}の線形結合で近似する際の係数である。

 \displaystyle
\mathbf{x} \approx \sum_{\mathbf{v}\in C}\gamma_{\mathbf{v}}(\mathbf{x})\mathbf{v}

重みとバイアスを、係数\gammaにより線形結合したものを、\mathbf{w}_\gamma^{\top}=\gamma(\mathbf{x}_k)^{\top}Wb_\gamma=\gamma(\mathbf{x}_k)^{\top}\mathbf{b}とすると、線形SVMによる識別関数と同様になる。

 \displaystyle
H_{W,\mathbf{b}}(\mathbf{x}_k)=\mathbf{w}_\gamma^{\top}\mathbf{x}+b_\gamma

こうして、代表点ごとに学習した線形SVMが、local codingによりつなぎあわせられる。代表点には、k-means法によるセントロイドが用いられる。

LL-SVMの重要なパラメータとしては、代表点数とlocal codingのための近傍数となる。よくある2-moonsデータで、代表点数と近傍数による決定境界の違いを見てみる。

代表点16個で近傍数8 - 半円の端が誤分類されてしまう

代表点16個で近傍数4 - 近傍数を小さくすることである程度データ分布に沿った決定境界を得られている

代表点128個で近傍数8 - データ分布を捉えることができていて決定境界もなめらかになっている

論文では、確率的勾配降下法により重み行列を求める方法が提案されている。ただ、確率的勾配降下法は、それなりのデータ量とイテレーション回数、また学習率の調整が必要で、いい感じの結果を得るのが難しい。また、Rubyで実装するとなると、timesやeachによるループが必要になるので、実行速度が遅くなる。そこで、二乗ヒンジ損失にして、L-BFGS法で最適化する方法で実装している。

 \displaystyle
J(W,\mathbf{b})=\frac{\lambda}{2}||W||^2+\frac{1}{|S|} \sum_{k\in S}\max(0, 1-y_k H_{W,\mathbf{b}}(\mathbf{x}_k))^{2}

 \displaystyle
\frac{\partial J}{\partial W}=\lambda W-\frac{2}{|S|} \sum_{k\in S}\max(0, 1-y_k H_{W,\mathbf{b}}(\mathbf{x}_k))y_k \gamma(\mathbf{x}_k)\mathbf{x}_k^{\top}

おわりに

深層学習が牛耳っている様な時代に、10年以上も前のSVMを実装したのは、Google Scholarで探したりすると、最近でもSVMに関する論文がチラホラ発表されていて、懐かしい気持ちになったからである。ニューラルネットワークが冬の時代から復活したように、今は熱心に研究されていない機械学習の手法も、効果的な学習方法が発見されて大躍進をとげる...かもしれない。

github.com