洋食の日記

記事をです・ます調で書き始めれば良かったと後悔している人のブログです

Red DatasetsとSVMKitを使ってIrisデータセットでの線形SVMの分類精度を確認する

はじめに

Red Datasetsは、IrisやMNISTといった公開されているデータセットを、Rubyで簡単に扱えるようにするプロジェクトである(Pythonでいえば、scikit-learnのsklearn.datasetsや、Kerasのkeras.datasetsに近い)。本記事では、Red DatasetsでIrisデータセットを読み込み、SVMKitで線形SVMによる分類精度の交差検定を行う。SVMKitは、データをNumo::NArrayで扱うので、そこの変換が必要になる。

インストール

Red DatasetsとSVMkitともに、gemで簡単にインストールできる。

$ gem install red-datasets svmkit

Red Datasetsの簡単な使い方

使い方は、Red DatasetsのUsageがわかりやすい。 Red Datasetsは、データセットを、関係データベースの様な複数レコードからなるテーブルで表現する。

require 'datasets'

# Irisデータセットをnewする。
iris = Datasets::Iris.new

# Irisデータセットは、花のアヤメの種類を表すラベルと、ガク片と花びらの長さ・幅による特徴量からなる。
# eachでそれらを1つずつ見ていく。
iris.each do |r| 
  puts "#{r.label}, #{r.sepal_length}, #{r.sepal_width}, #{r.petal_length}, #{r.petal_width}"
end

これを実行すると以下のようになる。便利!!

$ ruby red_test.rb
Iris-setosa, 5.1, 3.5, 1.4, 0.2
Iris-setosa, 4.9, 3.0, 1.4, 0.2
Iris-setosa, 4.7, 3.2, 1.3, 0.2

... (省略)

コード

Red Datasetsで読み込んだIrisデータセットで、線形SVMの分類精度の交差検定を行う。テーブルによるデータの取り出しを試してみた。

require 'datasets'
require 'svmkit'
require 'numo/narray'

# Irisデータセットを読み込む。
iris = Datasets::Iris.new

# テーブルを取得する。
iris_table = iris.to_table

# ラベルと特徴量に分けてとりだす。
iris_labels = iris_table[:label]
iris_attrs = iris_table.fetch_values(
  :sepal_length, :sepal_width, :petal_length, :petal_width).transpose

# Irisデータセットの文字列によるラベルを整数値のラベル (Numo::Int32) に変換する。
encoder = SVMKit::Preprocessing::LabelEncoder.new
labels = encoder.fit_transform(iris_labels)

# Irisデータセットの特徴量をNumo::DFloatに変換する。
samples = Numo::DFloat[*iris_attrs]

# 線形SVMの5-fold分割による交差検定を定義する。
svc = SVMKit::LinearModel::SVC.new(
  reg_param: 0.0001, fit_bias: true, max_iter: 3000, random_seed: 1)
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, random_seed: 1)
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)

# 交差検定を実行する。
report = cv.perform(samples, labels)

# 平均Accuracyを出力する。
mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
puts("Mean Accuracy: %.1f%%" % (100.0 * mean_accuracy))

これを実行すると以下のとおり。Irisデータセットにおける線形SVMの分類精度を、5交差検定で確認すると94.7%であるとわかる。

Mean Accuracy: 94.7%

おわりに

公開されているデータセットは、ファイル形式が独自のバイナリだったりして、まず扱えるようにするまでが大変だったりするものも多い。Red Datasetsのように、データセットを統一された使い方で扱えるのはとてもありがたい。何かデータセットを追加して貢献できたらな〜。