洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Rumaleのロジスティック回帰の最適化手法にL-BFGSを追加した

はじめに

Rumaleのロジスティック回帰のsolverにL-BFGSを追加し、これをデフォルトとして、ver. 0.22.0 としてリリースした。Rumaleのロジスティック回帰では、最適化に確率的勾配降下法(Stochastic Gradient Descent, SGD)を用いていた。データによって、反復回数やミニバッチの大きさなど、ハイパーパラメータを調整する必要があり、手軽に使える感じではなかった。また、Scikit-learnのロジスティック回帰では、L-BFGSを用いるのがデフォルトになっている(以前はLIBLINEARによるものがデフォルトだった)。そこで、Rumaleのロジスティック回帰にもL-BFGSのものを追加した。また、L-BFGSを用いる場合には、多値分類を多項ロジスティック回帰で実現する。

rumale | RubyGems.org | your community gem host

使い方

Rumaleをインストールすれと、L-BFGSのために依存でlbfgb.rbも一緒にインストールされる。

$ gem install rumale

多値分類の例として、LIBSVM DATAのpendigitsデータセットを分類する。

require 'rumale'

# LIBSVM形式のpendigitsデータを読み込む.
x, y = Rumale::Dataset.load_libsvm_file('pendigits')

# ランダム分割で訓練とテストに分割する.
ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.2, random_seed: 1)
train_ids, test_ids = ss.split(x, y).first
x_train = x[train_ids, true]
y_train = y[train_ids]
x_test = x[test_ids, true]
y_test = y[test_ids]

# ロジスティック回帰による分類器を学習する.
cls = Rumale::LinearModel::LogisticRegression.new
cls.fit(x_train, y_train)

# テストデータで正確度を計算する.
puts "Accuracy: %.3f" % cls.score(x_test, y_test)
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
$ time ruby logit.rb
Accuracy: 0.956
ruby logit.rb  1.15s user 0.16s system 99% cpu 1.317 total

従来のSGDを用いた学習を行う場合はsolverオプションに 'sgd' を渡す。

# SGDはNumo::NArrayでの行列積や内積を行うので, 
# 高速化のために, Numo::Linalgをロードして, OpenBLASを叩くようにする.
require 'numo/openblas'
# SGDでは多値分類はone-vs-rest法を用いる. 
# 各2値分類器の学習を並行して行えるようにParallel gemを使う.
require 'parallel'

...

cls = Rumale::LinearModel::LogisticRegression.new(
  solver: 'sgd',
  learning_rate: 0.001, momentum: 0.9,
  max_iter: 1000, batch_size: 10,
  n_jobs: -1,
  random_seed: 1
)

...

SGDによる最適化の実行時間は、max_iterとbatch_sizeにより変化するが、Numo::LinalgやParallelを使ってもL-BFGSのよりも遅くなっている。

$ time ruby logit.rb
Accuracy: 0.935
ruby logit.rb  394.07s user 56.87s system 574% cpu 1:18.49 total

おわりに

月に一度はRumaleの新しいバージョンをリリースしたいと思っていたが、11月はギリギリになってしまった。最適化にL-BFGSを使うアイディアは以前からあったが、そのベースとなるL-BFGS-Bをgem化するのに、思ったよりも時間がかかってしまった。Numo::NArrayとNumo::Linalg、そしてlbfgs.rbを使えば、Rubyで多くの機械学習アルゴリズムを実装できると考える。Rumaleも色々と発展を予定している。

github.com