洋食の日記

記事をです・ます調で書き始めれば良かったと後悔している人のブログです

Rumaleに多層パーセプトロンなニューラルネットワークを追加した

はじめに

Rumaleに多層パーセプトロンによる分類・回帰を追加した。活性化関数にはReLUを、正則化にはDropout、最適化にはAdamというスタンダードな構成にした。あわせて、入力のバリデーションをNumo::NArrayだけでなくRuby Arrayも受け付けるように修正して、version 0.14.0としてリリースした。

rumale | RubyGems.org | your community gem host

使い方

Rumaleはgemコマンドでインストールできる。データの取得にred-datasetsを使いたいので、一緒にインストールする。

$ gem install rumale red-datasets-numo-narray

USPSという手描き数字画像によるデータセットで分類の例を示す。

require 'datasets-numo-narray'
require 'rumale'

# データを読み込む.
dataset = Datasets::LIBSVM.new('usps').to_narray
labels = Numo::Int32.cast(dataset[true, 0])
samples = Numo::DFloat.cast(dataset[true, 1..-1])

# データセットをランダムに訓練とテストに分割する.
ss = Rumale::ModelSelection::ShuffleSplit.new(
  n_splits: 1, test_size: 0.2, random_seed: 1
)
train_ids, test_ids = ss.split(samples, labels).first

train_s = samples[train_ids, true]
train_l = labels[train_ids]
test_s = samples[test_ids, true]
test_l = labels[test_ids]

# 多層パーセプトロンによる分類器を用意する.
# ※
# 隠れ層のユニット数はhidden_unitsで与える.
# 以下の例では, 2層の隠れ層を持ち, ユニット数がそれぞれ256と128になる。
# 繰り返し回数max_iterとミニバッチの大きさbatch_sizeは, 
# 隠れ層と同様に慎重に設定したほうが良い.
# verboseをtrueにすると学習過程のロス関数の値が表示される.
mlp = Rumale::NeuralNetwork::MLPClassifier.new(
  hidden_units: [256, 128],
  max_iter: 1000, batch_size: 50,
  verbose: true, 
  random_seed: 1
)

# 多層パーセプトロンによる分類器を学習する.
mlp.fit(train_s, train_l)

# テストセットで分類の正確度を計算する.
puts("Accuracy: %.4f" % mlp.score(test_s, test_l))

これを実行すると、次のようになる。学習が進むにつれてロスが小さくなり、テストセットでは96%程度の正確さで分類できている。

[MLPClassifier] Loss after 10 iterations: 2.2093608343275157
[MLPClassifier] Loss after 20 iterations: 1.8448493780493846
[MLPClassifier] Loss after 30 iterations: 1.5683160562017802
...
[MLPClassifier] Loss after 980 iterations: 0.2764881083070085
[MLPClassifier] Loss after 990 iterations: 0.1944084487714543
[MLPClassifier] Loss after 1000 iterations: 0.11276357417884826
Accuracy: 0.9616

あわせて

入力のバリデーションを緩めた。今までは、サンプルなどはNumo::NArrayでないと弾いていたが、Ruby Arrayでも受け入れるようにした(内部の計算ではNumo::NArrayに変換したものを使う)。また、実数なハイパーパラメータはFloatでないと弾いていたがこれもやめた。Scikit-learnもわりと緩い感じで、変なものが入力されたら、問答無用でコケていた。「APIドキュメントは書いてるから想定外の入力きてコケても許してね」ぐらいの気持ちで、利便性のためバリデーションを緩めた。

おわりに

Rumaleも様々な手法があり、そこそこ大きなライブラリとなってきた。一度開発の手は止めて、yardocによるAPIドキュメントだけでなく、ユーザーガイドを書いていこうと思っている。

github.com