2017/8/13追記: LSHForestはパフォーマンスがよろしくないため、0.19からDEPRECATEDとなった。0.21から削除されるようなので、使用しないほうが良い。
scikit-learnでは、ver. 0.16から近似最近傍探索手法のLSHForestが実装されている。LSHForestは、ハッシングによる近似最近傍探索の代表的な手法であるLocality Sensitive Hashing(LSH)をベースにした、木構造の近似最近傍探索手法である。LSHは特徴ベクトルをRandom Projectionと閾値処理で0と1の短いバイナリベクトルに変換し、これをキーとしてハッシュテーブルを作る。LSHForestでは、LSHによって得られれたバイナリベクトルから木構造(あるビットが0か1で二分木が作れる、これをLSHTreeと呼ぶ)を複数個つくる。検索クエリが与えられると、それぞれのLSHTreeから候補を割り出して、それら候補を距離で並べることで検索結果を得る。ちなみに、コサイン距離(1-コサイン類似度)による検索しか行えない。この点では、FLANNの方が柔軟である。 それでは、検索インデックスを作成するスクリプトは、次のようになる。これをgen_idx.pyとする。
#!/usr/bin/env python # -*- coding: utf-8 -*- from sklearn.datasets import load_svmlight_file from sklearn.neighbors import LSHForest from sklearn.externals import joblib def main(): # MNIST(手書き数字画像)データセットを読み込む。 # ※検索対象の例として使う。 targets, _ = load_svmlight_file('mnist.scale') # LSHはアルゴリズム中でRandom Projectionを用いるので、 # 再現性を確保したい場合は、random_stateに何か与えると良い。 search_idx = LSHForest(random_state=1984) # 検索インデックスを作成する。 search_idx.fit(targets) # 検索インデックスを保存する。 joblib.dump(search_idx, 'lshtree.pkl.cmp', compress=True) if __name__ == '__main__': main()
検索インデックスを読み込み、検索するスクリプトは次のようになる。これをsearch.pyとする。
#!/usr/bin/env python # -*- coding: utf-8 -*- from sklearn.datasets import load_svmlight_file from sklearn.neighbors import LSHForest from sklearn.externals import joblib def main(): # 検索クエリとして、MNISTのテストデータを読み込む。 queries, q_labels = load_svmlight_file('mnist.scale.t') # 検索対象データのラベルを読み込む。 # ※あとで検索結果を確認したいためで、検索には必要ない。 _, t_labels = load_svmlight_file('mnist.scale') # 検索インデックスを読み込む。 search_idx = joblib.load('lshtree.pkl.cmp') # 各クエリの5-近傍を検索する。 dists, ids = search_idx.kneighbors(queries, n_neighbors=5) # 検索クエリの一番目のデータで、検索結果を確認する。 print(q_labels[0]) print(t_labels[ids[0,:]]) print(ids[0,:]) print(dists[0,:]) if __name__ == '__main__': main()
実行してみると、検索クエリと同じ「7」のMNISTデータが検索できているのがわかる。
$ ./gen_idx.py $ ./search.py 7.0 [ 7. 7. 7. 7. 7.] [15260 16186 14563 9724 31073] [ 0.08516171 0.09843745 0.10159849 0.10445509 0.10882645]
LSHForestにも、LSHTreeの数などのパラメータがあるが、論文中で書かれている値であったりと、デフォルトで問題なさそう。 内部のLSHとしては、32ビットのバイナリベクトルに変換する様子。 FLANNと違って、検索時に、元の検索対象データを必要としないのが良い。 ちなみに、検索インデックスに新たにデータを加えるときは、partial_fitメソッドを使う。