はじめに
SVMKitで、一通りベーシックな機械学習アルゴリズムの実装を終えたので、しばらく便利機能の追加を予定している。バージョン0.7.2ではPipelineを実装した。Pipelineを使うことで、正規化して主成分分析してSVMで分類といった連結処理を定義できる。
svmkit | RubyGems.org | your community gem host
使い方
gemコマンドでSVMKitをインストールする。
$ gem install svmkit
例で使うデータセットはLIBSVM Dataからpendigitsデータセットを取得する。
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
RBFカーネル近似を行い、線形SVM分類器で分類することをPipelineを使って実装すると次のようになる。
require 'svmkit' # pendigitsデータを読み込む。 samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits') samples = Numo::DFloat.cast(samples) # カーネル近似と線形SVMによる分類器を構成する。 # Pipelineの各ステップはHashで定義する。 rbf = SVMKit::KernelApproximation::RBF.new(gamma: 0.0001, n_components: 800, random_seed: 1) svc = SVMKit::LinearModel::SVC.new(reg_param: 0.0001, max_iter: 1000, random_seed: 1) pipeline = SVMKit::Pipeline::Pipeline.new(steps: { hoge: rbf, fuga: svc }) # 5-交差検定を実施する。 kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1) cv = SVMKit::ModelSelection::CrossValidation.new(estimator: pipeline, splitter: kf) report = cv.perform(samples, labels) # 結果を出力する。 mean_accuracy = report[:test_score].inject(:+) / kf.n_splits puts("5-CV mean accuracy: %.1f %%" % (mean_accuracy * 100.0))
これを実行すると次のようになる。
$ ruby pipeline.rb 5-CV mean accuracy: 99.2 %
Pipelineは、他の機械学習アルゴリズム同様にMarshal.dumpとloadで、モデルの書き出しと読み込みが行える。
おわりに
SVMKitも多機能になり、名前と内容が一致していないため、改名を考えている。0.7系でユーティリティ的なのを実装したら変えるつもりです。