洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

RumaleにExtra-Treesによる分類と回帰を追加した

はじめに

新しい時代になったので、Rumaleに新しいアルゴリズムを追加してバージョンを0.9.1に上げてみた。あわせて、バージョン0.9.0で導入した決定木のC拡張もリファクタリングして、少しだけ速くなっている。

rumale | RubyGems.org | your community gem host

Extra-Trees

Extra-Trees(Extremely randomized trees)はおもしろいアルゴリズムで、決定木では特徴軸を分割する際に、Gini係数やEntropyなどを基準に、利得が最大になる特徴とその分割の閾値を選択するが、Extra-Treesはそれらをランダムに選択する。このランダムな木を、Random Forestと同じように複数用意してBaggingするのだが、それぞれの木を学習する際に、Bootstrapサンプリングはせずに訓練データ全てを用いる。シンプルなので高速に動くし、分類精度もRandom Forestに匹敵する値となる。

link.springer.com

使い方

Rumaleはgemコマンドでインストールできる。Numo::NArrayに依存している。

$ gem install rumale

Rumaleは基本的にはscikit-learnのインターフェースに合わせている。fitしてpredictする感じ。 scoreメソッドを使うとAccyracyを計算する。

require 'rumale'

# データの取得にはred-datasetsを用いた. gemコマンドでインストールできる.
#   gem install red-datasets-numo-narray
# USPSという手書き文字データセットを読み込む.
require 'datasets'
require 'datasets-numo-narray'

usps = Datasets::LIBSVM.new('usps').to_narray
labels = Numo::Int32.cast(usps[true, 0])
samples = Numo::DFloat.cast(usps[true, 1..-1])

# ランダム分割で訓練とテストに分ける.
ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.1, random_seed: 1)
train_ids, test_ids = ss.split(samples, labels).first

# 訓練データセットでExtra-Treesによる分類器を学習する.
est = Rumale::Ensemble::ExtraTreesClassifier.new(random_seed: 1)
est.fit(samples[train_ids, true], labels[train_ids])

# テストデータセットで分類性能を評価する.
puts("Accuracy: %.4f" % est.score(samples[test_ids, true], labels[test_ids]))

分類精度の比較

上記のコードを実行すると、次のような結果になる。手書き文字認識の正確度が95%とランダムながらそれなりの精度となる。

$ ruby tree.rb
Accuracy: 0.9547

コードを一行変えてRandom Forestを試してみる

# 訓練データセットでExtra-Treesによる分類器を学習する.
# est = Rumale::Ensemble::ExtraTreesClassifier.new(random_seed: 1)
est = Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)

これを実行すると、次のような結果になる。わずかにExtra-Treesのほうが正確度が高い。 Extra-Treesはランダムに特徴と閾値の選択をするため、本来であれば、 乱数のシードを変えつつ複数回実行した平均値を比較すべきだが、Random Forestと同程度の結果が得られることがわかる。

$ ruby tree.rb
Accuracy: 0.9438

Extra-TreesおよびRandom Forestにはハイパーパラメータがあるので、それらを調整するとまた結果が変わってくるだろう。

実行速度の比較

データセットの分割の後ろに以下のコードを追加して、実行速度を計測した。 Extra-Treesは、ランダムに選択される閾値によっては、運悪く木が深くなりすぎる可能性があるので、 木の深さを表すmax_depthパラメータを10とした(Accuracyは0.93程度となる)。

# 省略

# ランダム分割で訓練とテストに分ける.
ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.1, random_seed: 1)
train_ids, test_ids = ss.split(samples, labels).first

# Benchmarkを使って訓練・テストの実行速度を計測する.
require 'benchmark'

Benchmark.bm 10 do |r|
  r.report 'extra-trees' do
    est = Rumale::Ensemble::ExtraTreesClassifier.new(max_depth: 10, random_seed: 1)
    est.fit(samples[train_ids, true], labels[train_ids])
    est.predict(samples[test_ids, true])
  end

  r.report 'random forest' do
    est = Rumale::Ensemble::RandomForestClassifier.new(max_depth: 10, random_seed: 1)
    est.fit(samples[train_ids, true], labels[train_ids])
    est.predict(samples[test_ids, true])
  end
end

これを実行すると以下のようになる。Random Forestの方が速いことがわかる。しかし、Random Forestで使用している決定木の特徴軸の分割はRuby拡張(C言語)で、Extra-TreesはPure Rubyで実装していることを考えると、Extra-Treesは速いと感じる。

$ ruby tree.rb
                 user     system      total        real
extra-trees  8.780000   0.110000   8.890000 (  8.955922)
random forest  6.660000   0.090000   6.750000 (  6.781255)

おわりに

Extra-Treesは、Kaggleでもstackingの最終層の分類器などで使われる。Extra-Treesでは、決定木の分割において、ランダムに特徴と閾値を選択するが、Random Forestと同等の分類精度を得られるのが興味ぶかい。Extra-Treesだけでなく、Random recursive SVMやEcho state networkなど、ランダム要素を含む機械学習アルゴリズムは多くあり、いずれも調べていると深みにはまるおもしろさがある(うまくいく乱数のシードを探しはじめたり...)。

github.com