洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

SVMKitの二値分類器にOne-vs-the-restを内包した

はじめに

SVMKitのSVMKit::LinearModel::SVCなどの二値分類器で、多値のラベルを与えれた場合に、自動的にOne-vs-the-rest法で多値分類器化するように修正を加えた。

svmkit | RubyGems.org | your community gem host

これにより、SVMKit::Multiclass::OneVsRestClassifierを使うことなく、多値分類が可能となる。あわせて、SVCとLogisticRegressionの実装を修正した。Numo::Narrayに不慣れなときに実装したものだったので、今回の修正でわずかだかパフォーマンスも向上した。

使い方

LinearModel::SVCの他、LinearModel::LogisticRegression、KernelMachine::KernelSVC、PolynomialModel::FactorizationMachineClassifierで、明示的な多値分類器化が必要なくなる。

require 'svmkit'

# LIBSVM形式のデータを読み込む.
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')

# 線形SVM分類器を定義する.
svc = SVMKit::LinearModel::SVC.new

# 以前は、明示的に多値分類器化する必要があったが、これがなくなる.
# ovr_svc = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: svc)

# 5-交差検定で評価する.
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
report = cv.perform(samples, labels)

# 結果を出力する.
mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
puts(sprintf("Accuracy: %.1f%%", 100.0 * mean_accuracy))

おわりに

0.2系では細かい修正を行い、0.3系から回帰かクラスタリングの実装をしていく予定にしている。

つまらないものですが、よろしくお願い致します。