洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Rumale::SVMに線形One-Class SVMを追加した

はじめに

LIBLINEARがバージョンアップして、線形One-Class SVM(Linear one-class support vector machine, LOCSVM)が追加された。これに合わせて、Numo::LiblinearとRumale::SVMをアップデートして、LOCSVMに対応した。LOCSVMは、データの分布を推定するような、教師なし学習である。原点からデータまでの距離をマージンと考え、SVMを学習することで、原点に向かってデータの端に決定境界ができるような感じになる。与えられたデータが、学習したデータ分布にちかしいかを判定できるので、応用として外れ値検出がある。

numo-liblinear | RubyGems.org | your community gem host

rumale-svm | RubyGems.org | your community gem host

Numo::Liblinearは、RubyとLIBLINEARのデータの受け渡しに、Numo::NArrayを使う薄いラッパーのようなものなので、以降、APIをRumaleに合わせたRumale::SVMをつかってLOCSVMを試す。

使い方

Gemコマンドでインストールできる。別の外部ライブラリをインストールする必要はない。

$ gem install rumale-svm

LOCSVMでの外れ値検出を試すために人工データを作る。原点から離れて正常データがあり、そこから、離れたところに外れ値データがある(外れ値データであるかは本来不明)。

require 'rumale'
require 'rumale/svm'

# プロットのためにnumo-gnuplotを用いる.
# $ brew install gnuplot
# $ gem install numo-gnuplot
require 'numo/gnuplot' 

# 人工データを作る.
x = Rumale::Utils.rand_normal([90, 2], Random.new(1), 3.0, 0.5)
x_out = Rumale::Utils.rand_normal([10, 2], Random.new(1), 1.0, 0.1)
samples = Numo::NArray.vstack([x, x_out])

# 人工データをpngで出力する.
Numo.gnuplot do
  set(terminal: 'png')
  set(output: 'ocsvm_.png')
  plot('[0:5] [0:5]', samples[true, 0], samples[true, 1], pt: 6, ps: 1)
end

f:id:yoshoku:20200815222421p:plain
人工データ

作成した人工データをLOCSVMに与えてデータ分布を学習する。LOCSVMのハイパーパラメータにnuがある。これは正則化のためのパラメータだが、外れ値がどの程度含まれているかを表す。人工データは100点で、そのうち10点が外れ値なので、0.1を与えた。これらは本来わからないので、実データでは、このハイパーパラメータのnuの調整が重要となる。

# LOCSVMを定義する.
ocsvm = Rumale::SVM::LinearOneClassSVM.new(nu: 0.1)

# LOCSVMで人工データの分布を学習する.
ocsvm.fit(samples)

# 人工データのラベルを推定する.
# 1であれば正常データ, -1であれば外れ値データといったところ.
labels = ocsvm.predict(samples)

# 結果をpngで出力する.
a = samples[true, 0]
b = samples[true, 1]
plots = labels.to_a.uniq.sort.map do |l|
  [a[labels.eq(l)], b[labels.eq(l)], t: l.to_s, ps: 2]
end

Numo.gnuplot do
  set(terminal: 'png')
  set(output: 'ocsvm.png')
  plot('[0:5] [0:5]', *plots)
end

これを実行すると、以下のようになる。正常値・外れ値としたものがきれいにわかれている。

f:id:yoshoku:20200815225724p:plain
推定結果

当たり前だが、実際には訓練に使ったものをテストに使うことはない。正常値・外れ値に見立てたデータを与えみても上手くいった。 特徴抽出なり変換なりで、はっきり外れ値が分かれるような場合には、LOCSVMは有効だと考える。

pp ocsvm.predict([[2.3, 2.8], [0.8, 0.5]])
# => 
# Numo::Int32#shape=[2]
# [1, -1]

おわりに

正直、LIBLINEARがバージョンアップされて、新しいアルゴリズムが追加されるとは思わなかった。One-class SVMは、外れ値検出が有名だが、文書分類に応用した研究もあるので、使い方によってはおもしろいことができると思う。あと、deep化したものもあったりして、研究としても続いているようだ。

github.com