洋食の日記

洋食のことではなく、技術メモを書きます。たまにどうでも良いことも書きます。

SVMKitにK分割交差検証を追加した

はじめに

SVMKitで「LIBSVM相当のことができるように」と思い、K分割交差検証(K-fold cross validation)を追加した。一度、cross validationするためのデータを分割するクラスを追加した段階で「これでminimum viable productかな」と思って、0.2.2としてリリースした。その後、予想していたよりもサクッとcross validationが実装できたので、0.2.3としてリリースした。

svmkit | RubyGems.org | your community gem host

使い方

データ分割クラスStratifiedKFold(もしくはKFold)と適当な分類器クラスを、CrossValidationクラスに渡して、performメソッドを実行すると交差検証が始まる形とした。performメソッドは、scikit-learnと同様に、実行時間やテストデータセットのスコアが配列で入ったHashを返す。

require 'svmkit'

# LIBSVM形式のデータを読み込む(LIBSVM Dataのpendigitsデータセット).
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')

# カーネルSVMをOne-vs-Restで多値分類器にする.
kernel_svc =
  SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)
ovr_kernel_svc = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: kernel_svc)

# StratifiedなK-fold分割を行うクラスを生成する(シャッフルして、各クラスで5分割する).
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)

# カーネルSVMの性能を交差検証で確認する.
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: ovr_kernel_svc, splitter: kf)
kernel_mat = SVMKit::PairwiseMetric::rbf_kernel(samples, nil, 0.005)
report = cv.perform(kernel_mat, labels)

# 平均正確度を出力する.
mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
puts(sprintf("Mean Accuracy: %.1f%%", 100.0 * mean_accuracy))

同じ様なことを、CrossValidationクラスではなく、データ分割のStratifiedKFoldクラスだけを使って行うと次のようになる。

require 'svmkit'

# LIBSVM形式のデータを読み込む(LIBSVM Dataのpendigitsデータセット).
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')

# StratifiedなK-fold分割を行うクラスを生成する(シャッフルして、各クラスで5分割する).
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)

# K-fold cross validation法で分類精度を評価する.
scores = kf.split(samples, labels).map do |train_ids, test_ids|
  # 訓練データセットとテストデータセットに分ける.
  train_samples = samples[train_ids, true]
  train_labels = labels[train_ids]
  test_samples = samples[test_ids, true]
  test_labels = labels[test_ids]
  # 訓練データでカーネルSVMを学習する.
  kernel_matrix = SVMKit::PairwiseMetric::rbf_kernel(train_samples, nil, 0.005)
  base_classifier =
    SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)
  classifier = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_classifier)
  classifier.fit(kernel_matrix, train_labels)
  # テストデータで学習したカーネルSVMの分類精度を評価する.
  kernel_matrix = SVMKit::PairwiseMetric::rbf_kernel(test_samples, train_samples, 0.005)
  classifier.score(kernel_matrix, test_labels)
end

# 平均正確度を出力する.
mean_accuracy = scores.inject(:+) / kf.n_splits
puts sprintf("Accuracy: %.1f%%", 100.0 * mean_accuracy)

これらを実行すると、以下のような5分割交差検証の分類精度が出力される。

$ ruby svmkit_validation.rb
Accuracy: 98.3%

おわりに

つまらないものですが、よろしくお願い致します。