洋食の日記

「だ・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

Pythonで近似最近傍探索を試したいときはpyflannがちょうど良い

近似最近傍探索とは近似的に近いものを検索してくる技術で、普通に距離を計算して並べて近くにあるものを探すより速い。代表的なライブラリにFLANN(Fast Library for Approximate Nearest Neighbors)があり、これのPythonバインディングがpyflannになる。FLANNの開発は2013年から止まっているのに(もともとブリティッシュコロンビア大の研究がベースなので研究プロジェクトが一段落したんだと思われる)、pyflannは今でも開発されているのがおもしろい。FLANN自体は、Debian GNU/Linuxとかでもパッケージになってて、pyflannもpipにあるのでインストールは楽ちん。枯れ具合がちょうど良い。

$ sudo apt-get insatll libflann-dev
$ sudo pip install pyflann

では、まず、インデックスを作成する。これをgen_idx.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pyflann
from sklearn.datasets import load_svmlight_file

def main():
  # とりあえず検索対象のデータとしてMNISTのデータを使う。
  # MNISTのデータはLIBSVM Dataのページからダウンロードできる。
  targets, _ = load_svmlight_file('mnist.scale')
  targets = targets.toarray()

  # 距離を設定する。一般的なユークリッド距離(euclidean)の他に、
  # マンハッタン距離(manhattan)とか
  # ヒストグラムインターセクションカーネル(hik)とかがある。
  pyflann.set_distance_type('euclidean')

  # 検索インデックスを作成する。
  # algorithmはHierarchical K-Means Clustering Treeを選択した。
  # 他にもRandomized KD-Treeなどがある。
  # centers_initはK-Meansの初期値の設定方法を指定する(デフォルトはランダム)。
  # 再現性を確保したい場合にはrandom_seedを指定すると良い。
  search_idx = pyflann.FLANN()
  params = search_idx.build_index(targets, algorithm='kmeans', 
    centers_init='kmeanspp', random_seed=1984)
  
  # 検索インデックスのパラメータを見てみる。
  print(params)

  # 作成した検索インデックスを保存する。
  search_idx.save_index('mnist.idx')

if __name__ == '__main__':
  main()

これを実行すると、検索インデックスが作成される。

$ ./gen_idx.py
{'branching': 32, 'cb_index': 0.5, 'centers_init': 'default', 'log_level': 'warning', 'algorithm': 'kmeans', 
...
'target_precision': 0.8999999761581421, 'sample_fraction': 0.10000000149011612, 'iterations': 5, 'random_seed': 1984, 
'checks': 32}

作成した検索インデックスを読み込んで検索を行う。これをsearch.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pyflann
from sklearn.datasets import load_svmlight_file

def main():
  # 検索対象となるMNISTデータを読み込む.
  targets, t_labels = load_svmlight_file('mnist.scale')
  targets = targets.toarray()

  # 検索クエリとなるMNISTデータを読み込む.
  queries, q_labels = load_svmlight_file('mnist.scale.t')
  queries = queries.toarray()

  # 距離を設定する。
  pyflann.set_distance_type('euclidean')

  # 作成した検索インデックスを読み込む。
  search_idx = pyflann.FLANN()
  search_idx.load_index('mnist.idx', targets)

  # 近似最近傍探索を行う。ここでは5-近傍を探索している。
  result, dists = search_idx.nn_index(queries, num_neighbors=5)

  # 結果を表示する。1番目の検索クエリのラベルと、その5-近傍のラベルと距離を見てみる。
  print(q_labels[0])
  print(t_labels[result[0,:]])
  print(result[0,:])
  print(dists[0,:])

if __name__ == '__main__':
  main()

実行してみると、検索クエリと同じ「7」のMNISTデータが検索できているのがわかる。

$ ./search.py
7.0
[ 7.  7.  7.  7.  7.]
[53843 38620 16186 27059 30502]
[  7.0398414    9.69496007  11.4449976   11.49352929  14.02158225]

他にも細かくパラメータを設定できる。 詳細は、FLANNの公式ページのユーザマニュアルを見ると良い。 MATLABバインディングの説明が参考になる。 FLANNの検索では、(当然だけど)検索インデックスの他に検索対象データも必要で、 検索対象データが高次元かつ大規模であると、検索対象データ自体がメモリに乗るかという問題が発生する。 対策は計算機環境によって色々考えられるだろうけど、どうしても検索対象データ自体を小さくしたい場合は、 Locality Sensitive Hashingに代表されるような、ハッシングによる近似最近傍探索を検討すると良い。