はじめに
SVMKitの開発は、何かしら毎月バージョンアップしようという思いで進めており、無事に2月も0.2.4をリリースすることができた。
svmkit | RubyGems.org | your community gem host
SVMKitにFactorization Machine(FM)による分類器を追加した。FMは、一般的な線形分類器に、特徴ベクトル間の相互関係を(低次元な)潜在ベクトルの内積に因子分解することで捉える項を足したような形となる(と一文で説明するのは難しいので詳しくは元論文を参照してください。そんなに難しい論文ではないです)。 FMは、スパースな特徴ベクトルに効果的とされ、一般に推薦などに使われるが、普通に分類器として使うこともできる。 実装したのは、確率的勾配降下法(Stochastic Gradient Descent, SGD)による、Hinge lossなFMである。 SGDによる各パラメータの最適化部分は、論文に掲載されているものから改良して、Mini-Batchなものにした。
その他に、SVMKit 0.2.4では、評価尺度を計算するEvaluatorモジュールを追加した。これにより、評価尺度の計算を分離することができ、これまでAccuracyの計算のみだったが、Precision、Recall、F値を計算できるようにした(今後もLog-Lossなどを追加していく予定)。
使い方
まずSVMKitをインストールする。線形代数の計算で使用しているNumo::NArrayもインストールされる。
$ gem install svmkit
次に、データを用意する。今回は、LIBSVM DATAから、手書き数字のデータセットであるpendigitsをとってきた。
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
Factorization Machineの分類精度を、5交差検定でF値でみる場合、以下のようになる。
require 'svmkit' # LIBSVM形式のデータを読み込む. samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits') # FMを定義する. # 各パラメータの意味はドキュメンを参照ください(http://www.rubydoc.info/gems/svmkit/0.2.4) factm = SVMKit::PolynomialModel::FactorizationMachineClassifier.new( n_factors: 4, init_std: 0.001, reg_param_bias: 1.0, reg_param_weight: 1.0, reg_param_factor: 10000.0, max_iter: 1000, batch_size: 50, random_seed: 1) # One-vs-restで多値分類器にする. ovr_factm = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: factm) # 評価尺度はマクロ平均なF値で. ev = SVMKit::EvaluationMeasure::FScore.new(average: 'macro') # 5-交差検定で評価する. kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1) cv = SVMKit::ModelSelection::CrossValidation.new(estimator: ovr_factm, splitter: kf, evaluator: ev) report = cv.perform(samples, labels) # 結果を出力する. mean_f_score = report[:test_score].inject(:+) / kf.n_splits puts(sprintf("Mean F1-Score: %.1f%%", 100.0 * mean_f_score))
これを実行すると以下の様になる。
$ ruby svmkit_fm_cv.rb Mean F1-Score: 0.886
ちなみに、線形SVMでは、F値が0.601だった。 個人的な知見だが、FMは因子分解なパラメータが、オーバーフィットしやすいように思うので、そこの正則化パラメータ(↑の例ではreg_param_factor)を線形項の正則化パラメータ(↑の例ではreg_param_biasやreg_param_weight)よりも、大きくしたほうが多くの場合で上手くいく。
おわりに
つまらないものですが、よろしくお願い致します(SVMKitって名前なのに、SVM関係なくなってきてる...)