洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Rumaleと同様のインターフェースでtorch.rbが使えるRumale::Torchを作った

はじめに

Rubyには深層学習をあつかうgemがいくつかある。そのなかでtorch.rbは、riceを利用してLibTorchをbinding libraryしたものである。LibTorchは、PyTorchのC++ APIといえるもので、version 1.5からstableなAPIとして提供されており、今後の発展と開発の継続が期待できる。Rumaleでも何かできないかと考え、torch.rbをwrapして、Rumaleに実装された機械学習アルゴリズムと同様に、fitとpredictで分類や回帰ができるものを作ってみた。

rumale-torch | RubyGems.org | your community gem host

使い方

torch.rbは、LibTorchを必要とする。Macでhomebrewを使っていれば、以下のコマンドでインストールできる。 ここで、一緒にインストールしているautomakeは、torch.rbが依存しているriceで必要となる。 その他の環境でのインストールは、torch.rbのREADMEに詳しく書かれている

$ brew install automake libtorch

Rumale::Torchはgemコマンドでインストールできる。Runtime依存で、Rumaleとtorch.rbがインストールされる。

$ gem install rumale-torch

torch.rbとRumale::Torchを使った、単純な多層パーセプトロンによる分類の例を示す。 まずは、実験のためのデータセットをダウンロードする。 LIBSVM Dataにあるpendigitsデータセットを用いる。

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t

訓練のためのコードは、以下の通りである。

require 'rumale'
require 'rumale/torch'

# torch.rbの乱数の種を固定する.
Torch.manual_seed(1)

# 使用するデバイスを定義する. 
device = Torch.device('cpu')

# 訓練データを読み込む.
# pendigitsデータセットは16次元の特徴ベクトルが10個のクラスに分けられている.
x, y = Rumale::Dataset.load_libsvm_file('pendigits')

# torch.rbの作法に従ってニューラルネットを定義する.
class MyNet < Torch::NN::Module
  def initialize
    super
    @dropout = Torch::NN::Dropout.new(p: 0.5)
    @fc1 = Torch::NN::Linear.new(16, 128)
    @fc2 = Torch::NN::Linear.new(128, 10)
  end

  def forward(x)
    x = @fc1.call(x)
    x = Torch::NN::F.relu(x)
    x = @dropout.call(x)
    x = @fc2.call(x)
    Torch::NN::F.softmax(x)
  end
end

# 定義したニューラルネットを作成する.
net = MyNet.new.to(device)

# ニューラルネットをRumale::Torchに渡して分類器を作成する.
classifier = Rumale::Torch::NeuralNetClassifier.new(
  model: net, device: device,
  batch_size: 10, max_epoch: 50, validation_split: 0.1,
  verbose: true
)

# 分類器を学習する.
classifier.fit(x, y)

# torch.rbとRumale::Torchのものそれぞれ, 学習した分類器を保存する.
Torch.save(net.state_dict, 'pendigits.pth')
File.binwrite('pendigits.dat', Marshal.dump(classifier))

これを実行すると、torch.rbで定義した多層パーセプトロンによる分類器が学習される。 verboseをtrueとしたので、各epochでロス関数の値などが表示される。

$ ruby train.rb
epoch:  1/50 - loss: 0.2073 - accuracy: 0.3885 - val_loss: 0.2074 - val_accuracy: 0.3853
epoch:  2/50 - loss: 0.1973 - accuracy: 0.4883 - val_loss: 0.1970 - val_accuracy: 0.4893
epoch:  3/50 - loss: 0.1962 - accuracy: 0.4997 - val_loss: 0.1959 - val_accuracy: 0.5013

...

epoch: 50/50 - loss: 0.1542 - accuracy: 0.9199 - val_loss: 0.1531 - val_accuracy: 0.9293

次に、テストデータセットのラベルを推定する。コードは以下の通りである。

require 'rumale'
require 'rumale/torch'

# 訓練のときと同様にtorch.rbの作法に従いニューラルネットを定義する.
class MyNet < Torch::NN::Module
  def initialize
    super
    @dropout = Torch::NN::Dropout.new(p: 0.5)
    @fc1 = Torch::NN::Linear.new(16, 128)
    @fc2 = Torch::NN::Linear.new(128, 10)
  end

  def forward(x)
    x = @fc1.call(x)
    x = Torch::NN::F.relu(x)
    # x = @dropout.call(x)
    x = @fc2.call(x)
    Torch::NN::F.softmax(x)
  end
end

# 定義したニューラルネットのインスタンスを作成し,
# 保存したものを読み込む.
net = MyNet.new
net.load_state_dict(Torch.load('pendigits.pth'))

# Rumale::Torch側も学習したものを読み込む.
# model = でニューラルネットをセットする.
classifier = Marshal.load(File.binread('pendigits.dat'))
classifier.model = net

# テストデータセットを読み込む.
x_test, y_test = Rumale::Dataset.load_libsvm_file('pendigits.t')

# テストデータのラベルを推定する.
p_test = classifier.predict(x_test)

# 評価のために正確度を計算される.
accuracy = Rumale::EvaluationMeasure::Accuracy.new.score(y_test, p_test)
puts(format("Accuracy: %2.1f%%", accuracy * 100))

これを実行すると、テストデータでの分類の正確度が出力される。

$ ruby test.rb
Accuracy: 91.2%

回帰のためのNeuralNetworkRegressorもあり、使い方は同様である。 torch.rbで定義したニューラルネットを渡すと、Rumaleと同様に、fitとpredictで学習と推定ができる。 torch.rbでは、行列などの表現にTensorクラスを利用するが、Rumale::Torchを利用している限りは、Numo::NArrayだけでよい。

おわりに

PyTorchは、柔軟性が高く、深層学習の実装において、細かいところまで作り込むことができる。そのぶん、コードの記述量は多く、多層パーセプトロンのようなシンプルなものでも、コードが冗長なものになる(torch.rbでも同様である)。これに対して、PyTorchLightningやskorchといったものが開発されている。深層学習ライブラリは、Theanoの頃から、Lasagneやnolearn、Tensorflow統合前のKeras、Tensorflowではtflearnといった、記述を容易にするライブラリが開発されている。Rumale::Torchもこの流れのなかにあるが、できることは全然ないといっていいほど少ない状態なので、今後も地道に開発を続けていく。

github.com