洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

データセットをNumo::NArrayであつかうLIBSVMのGemを作成した

はじめに

特徴ベクトルやラベルをNumo::NArray形式で扱えるLIBSVMのbinding gemが欲しかったので、Ruby拡張ライブラリの勉強もかねて作成した。開発に際しては、Numo::FFTWをとても参考にさせて頂いた。gem名は、勝手ながらnumo-libsvmとした。

numo-libsvm | RubyGems.org | your community gem host

使い方

準備

多くのLIBSVM関連ライブラリがLIBSVMのコード自体を同梱しているが、Numo::Libsvmでは同梱しないことにした。brewやaptなどのパッケージマネージャでインストールしたLIBSVMと、バージョンがズレるのはどうかなと思ったためである。というわけで、LIBSVMをインストールする必要がある(必要になるのはlibsvm.soとsvm.hである)。

例えば、macOSであれば、

$ brew install libsvm

Ubuntuであれば、

$ sudo apt-get install libsvm-dev

となる。

インストール

numo-libsvmは、gemコマンドでインストールできる。

$ gem install numo-libsvm

C-SVCによる分類

LIBSVM DataにあるPendigitsデータセットを使って、C-SVCによる分類を行う。データの取得にはred-datasetsを用いる。

$ gem install red-datasets-numo-narray 

では、まず、C-SVCによる分類器を訓練する。

require 'numo/narray'
require 'numo/libsvm'
require 'datasets-numo-narray'

# Pendigitsデータの訓練データセットをダウンロードする.
puts 'Download dataset.'
pendigits = Datasets::LIBSVM.new('pendigits').to_narray
x = pendigits[true, 1..-1] # 特徴ベクトル
y = pendigits[true, 0]     # ラベル

# RBFカーネルによるC-SVCを実現するパラメータを定義する.
param = {
  svm_type: Numo::Libsvm::SvmType::C_SVC,
  kernel_type: Numo::Libsvm::KernelType::RBF,
  gamma: 0.0001, # RBFカーネルのパラメータ
  C: 10, # C-SVCのパラメータ
  shrinking: true
}

# C-SVCを訓練する.
puts 'Train support vector machine.'
model = Numo::Libsvm.train(x, y, param)

# パラメータと訓練したモデルをMarshalでファイルに保存する。
puts 'Save parameters and model with Marshal.'
File.open('pendigits.dat', 'wb') { |f| f.write(Marshal.dump([param, model])) }

これを実行すると、以下のようになる。

$ ruby train.rb
Download dataset.
Train support vector machine.
Save paramters and model with Marshal.

次に、訓練したモデルで、テストデータの分類を行う。

require 'numo/narray'
require 'numo/libsvm'
require 'datasets-numo-narray'

# Pendigitsデータのテストデータセットをダウンロードする.
puts 'Download dataset.'
pendigits_test = Datasets::LIBSVM.new('pendigits', note: 'testing').to_narray
x = pendigits_test[true, 1..-1]
y = pendigits_test[true, 0]

# パラメータと訓練したモデルをMarshalで読み込む.
puts 'Load parameter and model.'
param, model = Marshal.load(File.binread('pendigits.dat'))

# テストデータのラベルを推定する.
puts 'Predict labels.'
predicted = Numo::Libsvm.predict(x, param, model)

# 推定結果を正確度で評価する.
mean_accuracy = y.eq(predicted).count.fdiv(y.size)
puts "Accuracy: %.1f %%" % (100 * mean_accuracy)

これを実行すると、以下のようになる。正確度(Accuracy)が98.3%となり、上手く分類できていることがわかる。

$ ruby test.rb
Download dataset.
Load parameter and model.
Predict labels.
Accuracy: 98.3 %

おわりに

ひとまず、最低限の動きができるようになったので公開した。これから、ドキュメントの整備も含めてアップデートしていきたい。また、LIBSVMのコードを同梱していないので、RubyInstallerによるWindows環境では動作しないかもしれない(libsvmのライブラリとヘッダーが指定できればWindowsでも動くと思われる)。

Rumaleなインターフェースも用意しようかと考えたが、それは今後、例えばrumale-libsvmのような形で、Rumale側でできればと思う。

github.com