洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

SVMKitにPipeline機能を追加した

はじめに

SVMKitで、一通りベーシックな機械学習アルゴリズムの実装を終えたので、しばらく便利機能の追加を予定している。バージョン0.7.2ではPipelineを実装した。Pipelineを使うことで、正規化して主成分分析してSVMで分類といった連結処理を定義できる。

svmkit | RubyGems.org | your community gem host

使い方

gemコマンドでSVMKitをインストールする。

$ gem install svmkit

例で使うデータセットLIBSVM Dataからpendigitsデータセットを取得する。

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits

RBFカーネル近似を行い、線形SVM分類器で分類することをPipelineを使って実装すると次のようになる。

require 'svmkit'

# pendigitsデータを読み込む。
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')
samples = Numo::DFloat.cast(samples)

# カーネル近似と線形SVMによる分類器を構成する。
# Pipelineの各ステップはHashで定義する。
rbf = SVMKit::KernelApproximation::RBF.new(gamma: 0.0001, n_components: 800, random_seed: 1)
svc = SVMKit::LinearModel::SVC.new(reg_param: 0.0001, max_iter: 1000, random_seed: 1)
pipeline = SVMKit::Pipeline::Pipeline.new(steps: { hoge: rbf, fuga: svc })

# 5-交差検定を実施する。
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: pipeline, splitter: kf)
report = cv.perform(samples, labels)

# 結果を出力する。
mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
puts("5-CV mean accuracy: %.1f %%" % (mean_accuracy * 100.0))

これを実行すると次のようになる。

$ ruby pipeline.rb
5-CV mean accuracy: 99.2 %

Pipelineは、他の機械学習アルゴリズム同様にMarshal.dumpとloadで、モデルの書き出しと読み込みが行える。

おわりに

SVMKitも多機能になり、名前と内容が一致していないため、改名を考えている。0.7系でユーティリティ的なのを実装したら変えるつもりです。