読者です 読者をやめる 読者になる 読者になる

洋食の日記

洋食のことではなく、技術メモを書きます。たまにどうでも良いことも書きます。

scikit-learnで近似最近傍探索したいときはLSHForestがある

scikit-learnでは、ver. 0.16から近似最近傍探索手法のLSHForestが実装されている。LSHForestは、ハッシングによる近似最近傍探索の代表的な手法であるLocality Sensitive Hashing(LSH)をベースにした、木構造の近似最近傍探索手法である。LSHは特徴ベクトルをRandom Projectionと閾値処理で0と1の短いバイナリベクトルに変換し、これをキーとしてハッシュテーブルを作る。LSHForestでは、LSHによって得られれたバイナリベクトルから木構造(あるビットが0か1で二分木が作れる、これをLSHTreeと呼ぶ)を複数個つくる。検索クエリが与えられると、それぞれのLSHTreeから候補を割り出して、それら候補を距離で並べることで検索結果を得る。ちなみに、コサイン距離(1-コサイン類似度)による検索しか行えない。この点では、FLANNの方が柔軟である。 それでは、検索インデックスを作成するスクリプトは、次のようになる。これをgen_idx.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sklearn.datasets import load_svmlight_file
from sklearn.neighbors import LSHForest
from sklearn.externals import joblib

def main():
  # MNIST(手書き数字画像)データセットを読み込む。
  # ※検索対象の例として使う。
  targets, _ = load_svmlight_file('mnist.scale')

  # LSHはアルゴリズム中でRandom Projectionを用いるので、
  # 再現性を確保したい場合は、random_stateに何か与えると良い。
  search_idx = LSHForest(random_state=1984)
  
  # 検索インデックスを作成する。
  search_idx.fit(targets)

  # 検索インデックスを保存する。
  joblib.dump(search_idx, 'lshtree.pkl.cmp', compress=True)

if __name__ == '__main__':
  main()

検索インデックスを読み込み、検索するスクリプトは次のようになる。これをsearch.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sklearn.datasets import load_svmlight_file
from sklearn.neighbors import LSHForest
from sklearn.externals import joblib

def main():
  # 検索クエリとして、MNISTのテストデータを読み込む。
  queries, q_labels = load_svmlight_file('mnist.scale.t')

  # 検索対象データのラベルを読み込む。
  # ※あとで検索結果を確認したいためで、検索には必要ない。
  _, t_labels = load_svmlight_file('mnist.scale')
  
  # 検索インデックスを読み込む。
  search_idx = joblib.load('lshtree.pkl.cmp')

  # 各クエリの5-近傍を検索する。
  dists, ids = search_idx.kneighbors(queries, n_neighbors=5)

  # 検索クエリの一番目のデータで、検索結果を確認する。
  print(q_labels[0])
  print(t_labels[ids[0,:]])
  print(ids[0,:])
  print(dists[0,:])

if __name__ == '__main__':
  main()

実行してみると、検索クエリと同じ「7」のMNISTデータが検索できているのがわかる。

$ ./gen_idx.py
$ ./search.py
7.0
[ 7.  7.  7.  7.  7.]
[15260 16186 14563  9724 31073]
[ 0.08516171  0.09843745  0.10159849  0.10445509  0.10882645]

LSHForestにも、LSHTreeの数などのパラメータがあるが、論文中で書かれている値であったりと、デフォルトで問題なさそう。 内部のLSHとしては、32ビットのバイナリベクトルに変換する様子。 FLANNと違って、検索時に、元の検索対象データを必要としないのが良い。 ちなみに、検索インデックスに新たにデータを加えるときは、partial_fitメソッドを使う。