洋食の日記

洋食のことではなく、技術メモを書きます。たまにどうでも良いことも書きます。

機械学習ライブラリSVMKitにカーネルSVMを追加した

はじめに

Pure Ruby機械学習ライブラリSVMKitにカーネルSVMを追加しました。カーネルSVMは、Pure Rubyでは速度的にツラいものがあるかな?と思っていたが、機械学習ライブラリとしては実装されているべきものなので追加した。※それ以前にLogistic Regressionを追加したけどブログに書くのを失念した...

svmkit | RubyGems.org | your community gem host

インストール

行列・ベクトルをあつかうのでNMatrixに依存する。

$ gem install nmatrix svmkit

使い方

Pythonのscikit-learnライクを意識してきたけど、今回はちょっと違うものにしてみた。 入力データとして、特徴ベクトルを与えるのが機械学習ライブラリでは一般的だが、SVMKitではカーネル行列を与える形にした。libsvmでいうところのprecomputed kernelという形式となる。 これは、世にある全てのカーネル関数を実装するのは難しいことと、ライブラリ利用者が適切なカーネルを選択できるように、という思いから。RBFカーネルやシグモイドカーネルは実装してある。

libsvm dataのサイトにあるpendigitsデータセットを読み込んで、分類するコードは以下の通り。 まずは、分類機の訓練から。

require 'svmkit'
require 'libsvmloader'

# libsvm形式の訓練データセットを読み込む。
samples, labels = LibSVMLoader::load_libsvm_file('pendigits', stype: :dense)

# RBFカーネル行列を計算する。※パラメータγは0.005とした。
kernel_matrix = SVMKit::PairwiseMetric::rbf_kernel(samples, nil, 0.005)

# カーネルSVMを用意する。
base_classifier =
  SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)

# one-vs-restで多値分類器とする。
classifier = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_classifier)

# 分類器を訓練する。
classifier.fit(kernel_matrix, labels)

# 分類器を保存する。
File.open('trained_classifier.dat', 'wb') { |f| f.write(Marshal.dump(classifier)) }

そして、訓練済みの分類器を読み込んで、テストデータを分類するコード。

require 'svmkit'
require 'libsvmloader'

# libsvm形式のテストデータセットを読み込む。
samples, labels = LibSVMLoader::load_libsvm_file('pendigits.t', stype: :dense)

# カーネル行列を計算するために、訓練データセットも読み込む。※ラベル情報は不要
tr_samples, = LibSVMLoader::load_libsvm_file('pendigits', stype: :dense)

# 訓練済みの分類器を読み込む。
classifier = Marshal.load(File.binread('trained_classifier.dat'))

# テストデータ-訓練データ間でRBFカーネル行列を計算する。
kernel_matrix = SVMKit::PairwiseMetric::rbf_kernel(samples, tr_samples, 0.005)

# テストデータのラベルを推定し、分類結果のAccuracyを出力する。
puts(sprintf("Accuracy: %.1f%%", 100.0 * classifier.score(kernel_matrix, labels)))

結果としてAccuracyは97.6%となった。

トイデータでの実験

訓練データがこれで、

f:id:yoshoku:20171021152814p:plain

テストデータがこれ。

f:id:yoshoku:20171021152829p:plain

テストデータのラベルを、SVMKitのカーネルSVMで学習して推定すると、

f:id:yoshoku:20171021152844p:plain

こうじゃ!!うまく非線形データを分類できている。

おわりに

つまらないものですが、よろしくお願い致します。