洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Rumaleの決定木をRuby拡張を使って速くした

はじめに

Rumaleの決定木で、どうしても普通にRubyを使っては高速化できない箇所があり、そこをExtensionで実装した。これをバージョン0.9.0としてリリースした。

rumale | RubyGems.org | your community gem host

決定木では、特徴軸ごとに、不純度にもとづいて最適な分割をさがす。 訓練データがN個あり、特徴量の値が重複しないとすると、N-1個の値が分割を決める閾値の候補となり、それだけforループを回す必要がある(特徴ベクトルがD次元とすると探索ループの外側にD個分のforループがある)。 Rubyに限らず、スクリプト言語の多くは、多重のforループはどうしても遅くなってしまうので、ここをExtensionで実装した。

簡単な実装の話

Ruby Extensionで、ExtDecisionTreeClassifierモジュールを作り、そこに最適な分割を探索するfind_split_paramsメソッドを生やした(回帰についても同様)。

void Init_rumale(void)
{
  VALUE mRumale = rb_define_module("Rumale");
  VALUE mTree = rb_define_module_under(mRumale, "Tree");
  VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
  rb_define_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
  
  // ... (省略)
}

その上で、Ruby側のDecisionTreeClassifierクラスでincludeし、分割を探索する部分でfind_split_paramsメソッドを用いた。

module Rumale
  module Tree
    class DecisionTreeClassifier < BaseDecisionTree
      include Base::Classifier
      include ExtDecisionTreeClassifier

      # ... (省略)

これで一部だけRuby Extension化できる。

実験

0.8.4と0.9.0で簡単な決定木の実行速度の比較を行った。今回、データセットの読込には、red-datasetsを用いた。データセットのダウンロードも自動化されており、簡単にデータセットを扱うことができる。

$ gem install rumale red-datasets red-datasets-numo-narray

コードは以下のようになる。データセットには、LIBSVM DATAで提供されている、7,291個256次元のUSPSを用いた。

require 'benchmark'
require 'rumale'
require 'datasets'
require 'datasets-numo-narray'

# データセットを読み込む
usps = Datasets::LIBSVM.new('usps').to_narray
labels = Numo::Int32.cast(usps[nil, 0])
samples = Numo::DFloat.cast(usps[nil, 1..-1])

# データセットの10%をテストセットとして、ランダムに訓練とテストにわける
ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.1, random_seed: 1)
train_ids, test_ids = ss.split(samples, labels).first

# 決定木の訓練とテストを10回おこない実行速度を確認する
Benchmark.bm 10 do |r|
  r.report 'dtree' do
    est = Rumale::Tree::DecisionTreeClassifier.new(random_seed: 1)
    est.fit(samples[train_ids, true], labels[train_ids])
    est.predict(samples[test_ids, true])
  end
end

これをRumaleのバージョン0.8.4と0.9.0で実行すると、次のようになる。だいたい100倍ほど速くなっているのがわかる(※注: データ数や特徴ベクトルの次元数が小さい場合これほどの差はつかない)。

# ver. 0.8.4
                 user     system      total        real
dtree      4923.530000 198.290000 5121.820000 (7917.143121)
# ver. 0.9.0
                 user     system      total        real
dtree       35.210000   0.310000  35.520000 ( 37.066986)

決定木が速くなると、Random ForestやAdaboost(Rumaleでは弱学習器に決定木を使用している)も速くなるので、この改善は大きい。

おわりに

当初Extensionを、C99なC言語で書いていたが、Travis CIでコケてしまった。コンパイルオプションで「-std=gnu99」とか付けることも考えたが、環境によってどうなるかわからなかったので、C90で書き直した。ExtensionはC90で書くのが安全と思われる。またMakefileなどのビルドの部分は、rake-compilerが良い感じにしてくれるので、簡単だった。たぶん、MacとかLinuxでは無事にインストールできる...と思う。

Rumaleの0.9系は、どうしても高速化できない部分(主に行列・ベクトル演算で書くことが難しくNumo::NArrayの恩恵が得られない部分)をExtensionで書き直すことを考えている。また、並列化できる箇所もあるので、Parallel gemなどを導入することを考えている。