洋食の日記

洋食のことではなく、技術メモを書きます。たまにどうでも良いことも書きます。

rb-libsvmとlibsvmloaderを使ったRubyでのカーネル非線形SVMによる分類

はじめに

RubyLIBSVMバインドであるrb-libsvmと、libsvm形式のデータセットを読み書きするlibsvmloaderで、カーネル非線形SVMで分類する例です。カーネル非線形と明記するのは、libsvmの姉妹品であるliblinearが線形SVMなため。

準備

まずLIBSVMそのものをインストールする必要がある。macでhomebrewなら次のとおり。

$ brew install libsvm

次にGemをインストールする。libsvmloaderがnmatrixに依存するので、一応nmatrixつけてみました。SciRubyにようこそ的なメッセージがでるのが良いよね〜。

source 'https://rubygems.org'

gem 'nmatrix'
gem 'rb-libsvm'
gem 'libsvmloader'
$ bundle install

最後にサンプルで使うデータセットlibsvmのサイトからダウンロードする。MNISTが有名だけど、大きいので、同じ手書き文字認識データセットのpendigitsを使う。

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t

コード

「1. 訓練データセットで、SVMを学習する。2. テストデータセットを読み込んで、ラベルを推定する。 3. Accuracyを計算して出力する。」この一連の流れを一つにまとめた。

require 'libsvmloader'
require 'libsvm'

# 訓練データを読み込む
samples, labels = LibSVMLoader.load_libsvm_file('pendigits')

# 訓練データをrb-libsvmの特徴ベクトル形式に変換する
examples = samples.each_row.map { |v| Libsvm::Node.features(v.to_a) }

# カーネル非線形SVMのパラメータを設定する
params = Libsvm::SvmParameter.new.tap do |p|
  p.cache_size = 1000                      # キャッシュメモリサイズ [MB]
  p.svm_type = Libsvm::SvmType::C_SVC      # SVM分類器
  p.kernel_type = Libsvm::KernelType::RBF  # RBFカーネル
  p.gamma = 0.0001                         # RBFカーネルのパラメータ
  p.c = 1.0                                # SVMの正則化パラメータ
  p.eps = 0.001                            # SVMの最適化を終了する閾値
end

# カーネル非線形SVMを訓練する
problem = Libsvm::Problem.new
problem.set_examples(labels.to_flat_a, examples)
model = Libsvm::Model.train(problem, params)

# テストデータを読み込む
samples, labels = LibSVMLoader.load_libsvm_file('pendigits.t')

# テストデータのラベルを推定する
preds = samples.each_row.map { |v| model.predict(Libsvm::Node.features(v.to_a)) }

# Accuracyを出力する
n_hits = preds.map.with_index { |l, n| 1 if l == labels[n] }.compact.sum
n_samples = labels.size
accuracy = 100.0 * (n_hits / n_samples.to_f)
puts(format('Accuracy = %.4f%% (%d/%d)', accuracy, n_hits, n_samples))

これを実行すると、pendigitsデータセットの学習と分類が行われて、バーンとAccuracyが表示される。

$ ruby hoge.rb
Accuracy = 98.2847% (3438/3498)

おわりに

RubyにはLIBSVMをバインドしたGemが他にもある。それぞれに使い方が異なるので、チームやプロジェクトにあったものを選ぶのが良いと思う。Pythonのscikit-learnmみたいな、デファクトスタンダード機械学習ライブラリはRubyにはない(2017/09/16現在)。みんな自由にやってる感じ。でも、それはそれで良い雰囲気だと思う。そんなわけで、個人的にscikit-learnインスパイアなRubyのベーシックな機械学習ライブラリを作ることを考えている(特に野心はなく単純な興味から)。深層学習は「流行ってるから、もう誰かライブラリ作ってるでしょ〜」といった気持ち。

ホントにスゴイ人気!!