はじめに
Rumaleの決定木で、どうしても普通にRubyを使っては高速化できない箇所があり、そこをExtensionで実装した。これをバージョン0.9.0としてリリースした。
rumale | RubyGems.org | your community gem host
決定木では、特徴軸ごとに、不純度にもとづいて最適な分割をさがす。 訓練データがN個あり、特徴量の値が重複しないとすると、N-1個の値が分割を決める閾値の候補となり、それだけforループを回す必要がある(特徴ベクトルがD次元とすると探索ループの外側にD個分のforループがある)。 Rubyに限らず、スクリプト言語の多くは、多重のforループはどうしても遅くなってしまうので、ここをExtensionで実装した。
簡単な実装の話
Ruby Extensionで、ExtDecisionTreeClassifierモジュールを作り、そこに最適な分割を探索するfind_split_paramsメソッドを生やした(回帰についても同様)。
void Init_rumale(void) { VALUE mRumale = rb_define_module("Rumale"); VALUE mTree = rb_define_module_under(mRumale, "Tree"); VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier"); rb_define_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6); // ... (省略) }
その上で、Ruby側のDecisionTreeClassifierクラスでincludeし、分割を探索する部分でfind_split_paramsメソッドを用いた。
module Rumale module Tree class DecisionTreeClassifier < BaseDecisionTree include Base::Classifier include ExtDecisionTreeClassifier # ... (省略)
これで一部だけRuby Extension化できる。
実験
0.8.4と0.9.0で簡単な決定木の実行速度の比較を行った。今回、データセットの読込には、red-datasetsを用いた。データセットのダウンロードも自動化されており、簡単にデータセットを扱うことができる。
$ gem install rumale red-datasets red-datasets-numo-narray
コードは以下のようになる。データセットには、LIBSVM DATAで提供されている、7,291個256次元のUSPSを用いた。
require 'benchmark' require 'rumale' require 'datasets' require 'datasets-numo-narray' # データセットを読み込む usps = Datasets::LIBSVM.new('usps').to_narray labels = Numo::Int32.cast(usps[nil, 0]) samples = Numo::DFloat.cast(usps[nil, 1..-1]) # データセットの10%をテストセットとして、ランダムに訓練とテストにわける ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.1, random_seed: 1) train_ids, test_ids = ss.split(samples, labels).first # 決定木の訓練とテストを10回おこない実行速度を確認する Benchmark.bm 10 do |r| r.report 'dtree' do est = Rumale::Tree::DecisionTreeClassifier.new(random_seed: 1) est.fit(samples[train_ids, true], labels[train_ids]) est.predict(samples[test_ids, true]) end end
これをRumaleのバージョン0.8.4と0.9.0で実行すると、次のようになる。だいたい100倍ほど速くなっているのがわかる(※注: データ数や特徴ベクトルの次元数が小さい場合これほどの差はつかない)。
# ver. 0.8.4 user system total real dtree 4923.530000 198.290000 5121.820000 (7917.143121)
# ver. 0.9.0 user system total real dtree 35.210000 0.310000 35.520000 ( 37.066986)
決定木が速くなると、Random ForestやAdaboost(Rumaleでは弱学習器に決定木を使用している)も速くなるので、この改善は大きい。
おわりに
当初Extensionを、C99なC言語で書いていたが、Travis CIでコケてしまった。コンパイルオプションで「-std=gnu99」とか付けることも考えたが、環境によってどうなるかわからなかったので、C90で書き直した。ExtensionはC90で書くのが安全と思われる。またMakefileなどのビルドの部分は、rake-compilerが良い感じにしてくれるので、簡単だった。たぶん、MacとかLinuxでは無事にインストールできる...と思う。
Rumaleの0.9系は、どうしても高速化できない部分(主に行列・ベクトル演算で書くことが難しくNumo::NArrayの恩恵が得られない部分)をExtensionで書き直すことを考えている。また、並列化できる箇所もあるので、Parallel gemなどを導入することを考えている。