はじめに
特徴ベクトルやラベルをNumo::NArray形式で扱えるLIBSVMのbinding gemが欲しかったので、Ruby拡張ライブラリの勉強もかねて作成した。開発に際しては、Numo::FFTWをとても参考にさせて頂いた。gem名は、勝手ながらnumo-libsvmとした。
numo-libsvm | RubyGems.org | your community gem host
使い方
準備
多くのLIBSVM関連ライブラリがLIBSVMのコード自体を同梱しているが、Numo::Libsvmでは同梱しないことにした。brewやaptなどのパッケージマネージャでインストールしたLIBSVMと、バージョンがズレるのはどうかなと思ったためである。というわけで、LIBSVMをインストールする必要がある(必要になるのはlibsvm.soとsvm.hである)。
例えば、macOSであれば、
$ brew install libsvm
Ubuntuであれば、
$ sudo apt-get install libsvm-dev
となる。
インストール
numo-libsvmは、gemコマンドでインストールできる。
$ gem install numo-libsvm
C-SVCによる分類
LIBSVM DataにあるPendigitsデータセットを使って、C-SVCによる分類を行う。データの取得にはred-datasetsを用いる。
$ gem install red-datasets-numo-narray
では、まず、C-SVCによる分類器を訓練する。
require 'numo/narray' require 'numo/libsvm' require 'datasets-numo-narray' # Pendigitsデータの訓練データセットをダウンロードする. puts 'Download dataset.' pendigits = Datasets::LIBSVM.new('pendigits').to_narray x = pendigits[true, 1..-1] # 特徴ベクトル y = pendigits[true, 0] # ラベル # RBFカーネルによるC-SVCを実現するパラメータを定義する. param = { svm_type: Numo::Libsvm::SvmType::C_SVC, kernel_type: Numo::Libsvm::KernelType::RBF, gamma: 0.0001, # RBFカーネルのパラメータ C: 10, # C-SVCのパラメータ shrinking: true } # C-SVCを訓練する. puts 'Train support vector machine.' model = Numo::Libsvm.train(x, y, param) # パラメータと訓練したモデルをMarshalでファイルに保存する。 puts 'Save parameters and model with Marshal.' File.open('pendigits.dat', 'wb') { |f| f.write(Marshal.dump([param, model])) }
これを実行すると、以下のようになる。
$ ruby train.rb Download dataset. Train support vector machine. Save paramters and model with Marshal.
次に、訓練したモデルで、テストデータの分類を行う。
require 'numo/narray' require 'numo/libsvm' require 'datasets-numo-narray' # Pendigitsデータのテストデータセットをダウンロードする. puts 'Download dataset.' pendigits_test = Datasets::LIBSVM.new('pendigits', note: 'testing').to_narray x = pendigits_test[true, 1..-1] y = pendigits_test[true, 0] # パラメータと訓練したモデルをMarshalで読み込む. puts 'Load parameter and model.' param, model = Marshal.load(File.binread('pendigits.dat')) # テストデータのラベルを推定する. puts 'Predict labels.' predicted = Numo::Libsvm.predict(x, param, model) # 推定結果を正確度で評価する. mean_accuracy = y.eq(predicted).count.fdiv(y.size) puts "Accuracy: %.1f %%" % (100 * mean_accuracy)
これを実行すると、以下のようになる。正確度(Accuracy)が98.3%となり、上手く分類できていることがわかる。
$ ruby test.rb Download dataset. Load parameter and model. Predict labels. Accuracy: 98.3 %
おわりに
ひとまず、最低限の動きができるようになったので公開した。これから、ドキュメントの整備も含めてアップデートしていきたい。また、LIBSVMのコードを同梱していないので、RubyInstallerによるWindows環境では動作しないかもしれない(libsvmのライブラリとヘッダーが指定できればWindowsでも動くと思われる)。
Rumaleなインターフェースも用意しようかと考えたが、それは今後、例えばrumale-libsvmのような形で、Rumale側でできればと思う。