洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Rumaleの乱数生成を無難なものにした

はじめに

機械学習アルゴリズムでは、乱数でベクトルを初期化したり、ランダムサンプリングしたりなど、乱数生成をアルゴリズム中に含むものが多い。Rumaleの多くもそんな感じで、クラスのインスタンス変数にRandomクラスによる乱数生成器を持っている。これの扱いを良い感じにした。

rumale | RubyGems.org | your community gem host

良い感じにしたとは?

Rumaleのクラス内で、乱数生成器は、簡単に書くと以下のようになっていた。

class Hoge
  attr_reader :weight

  def initialize(random_seed: 0)
    @rng = Random.new(random_seed)
  end

  def fit
    # 重みを初期化する
    @weight = @rng.rand
  end
end

インスタンス生成として乱数生成器を使いまわすかたちになっている。例えば、乱数で重みベクトルを初期化するようなアルゴリズムの場合に、学習データが同じであっても、fitメソッドを呼び出すたびに、学習結果が異なる状態にあった。

> h = Hoge.new
> h.fit
> h.weight
=> 0.5488135039273248
> h.fit
> h.weight
=> 0.7151893663724195

同じデータ・パラメータでfitメソッドを何度も呼び出すコトがあるか?は置いておいて、同じデータを入力したら同じ結果が出て欲しいのが人情なので、fitメソッドを以下のように修正した。

def fit
  sub_rng = @rng.dup
  # 重みを初期化する
  @weight = sub_rng.rand
end

initializeで作成した乱数生成器をコピーするようにした。これで、fitメソッドを呼び出すたびに、学習結果(乱数での初期化やランダムサンプリング)が変化することはなくなった。

> h = Hoge.new
> h.fit
> h.weight
=> 0.5488135039273248
> h.fit
> h.weight
=> 0.5488135039273248

他にこの現象の回避方法はいくらでもあるだろうが、すでにMarshal.dumpされた学習済みモデルがあることを考えると、最短距離の対処法だと思う。

おわりに

こういう下回り的なところも気をつけていかないと。

github.com