はじめに
SVMKitで「LIBSVM相当のことができるように」と思い、K分割交差検証(K-fold cross validation)を追加した。一度、cross validationするためのデータを分割するクラスを追加した段階で「これでminimum viable productかな」と思って、0.2.2としてリリースした。その後、予想していたよりもサクッとcross validationが実装できたので、0.2.3としてリリースした。
svmkit | RubyGems.org | your community gem host
使い方
データ分割クラスStratifiedKFold(もしくはKFold)と適当な分類器クラスを、CrossValidationクラスに渡して、performメソッドを実行すると交差検証が始まる形とした。performメソッドは、scikit-learnと同様に、実行時間やテストデータセットのスコアが配列で入ったHashを返す。
require 'svmkit' # LIBSVM形式のデータを読み込む(LIBSVM Dataのpendigitsデータセット). samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits') # カーネルSVMをOne-vs-Restで多値分類器にする. kernel_svc = SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1) ovr_kernel_svc = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: kernel_svc) # StratifiedなK-fold分割を行うクラスを生成する(シャッフルして、各クラスで5分割する). kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1) # カーネルSVMの性能を交差検証で確認する. cv = SVMKit::ModelSelection::CrossValidation.new(estimator: ovr_kernel_svc, splitter: kf) kernel_mat = SVMKit::PairwiseMetric::rbf_kernel(samples, nil, 0.005) report = cv.perform(kernel_mat, labels) # 平均正確度を出力する. mean_accuracy = report[:test_score].inject(:+) / kf.n_splits puts(sprintf("Mean Accuracy: %.1f%%", 100.0 * mean_accuracy))
同じ様なことを、CrossValidationクラスではなく、データ分割のStratifiedKFoldクラスだけを使って行うと次のようになる。
require 'svmkit' # LIBSVM形式のデータを読み込む(LIBSVM Dataのpendigitsデータセット). samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits') # StratifiedなK-fold分割を行うクラスを生成する(シャッフルして、各クラスで5分割する). kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1) # K-fold cross validation法で分類精度を評価する. scores = kf.split(samples, labels).map do |train_ids, test_ids| # 訓練データセットとテストデータセットに分ける. train_samples = samples[train_ids, true] train_labels = labels[train_ids] test_samples = samples[test_ids, true] test_labels = labels[test_ids] # 訓練データでカーネルSVMを学習する. kernel_matrix = SVMKit::PairwiseMetric::rbf_kernel(train_samples, nil, 0.005) base_classifier = SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1) classifier = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_classifier) classifier.fit(kernel_matrix, train_labels) # テストデータで学習したカーネルSVMの分類精度を評価する. kernel_matrix = SVMKit::PairwiseMetric::rbf_kernel(test_samples, train_samples, 0.005) classifier.score(kernel_matrix, test_labels) end # 平均正確度を出力する. mean_accuracy = scores.inject(:+) / kf.n_splits puts sprintf("Accuracy: %.1f%%", 100.0 * mean_accuracy)
これらを実行すると、以下のような5分割交差検証の分類精度が出力される。
$ ruby svmkit_validation.rb Accuracy: 98.3%
おわりに
つまらないものですが、よろしくお願い致します。