洋食の日記

洋食のことではなく、技術メモを書きます。たまにどうでも良いことも書きます。

Pandasで「ビジネス活用事例で学ぶデータサイエンス入門」を勉強する(第6章)

はじめに

マーケティング寄りのデータ分析の知識を補うため、以下の本で勉強を始めた。事例ベースな内容で、とても読みやすい。 Pandasも習得したいので、Pandasに翻訳しながら読み進めている。今回は第6章を勉強した。

ビジネス活用事例で学ぶ データサイエンス入門

ビジネス活用事例で学ぶ データサイエンス入門

ちなみに、CSVファイルなどのデータは、本のサポートページ(SBクリエイティブ:ビジネス活用事例で学ぶ データサイエンス入門)で配布されている。

翻訳したコード

第6章では、重回帰分析を扱っている。作業はJupyter Notebook上で行った。回帰にはScikit-Learnを用いる。 本ではひとまず、散布図を出力して、テレビ広告・雑誌広告とアプリのインストール数との関係を見ている。

%matplotlib inline # Jupyter Notebookでmatplotlibの図が表示されないときにつけるおまじない
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression

# データを読み込む。
ad_data = pd.read_csv('ad_result.csv')
ad_data.head(2)

 

month tvcm magazine install
0 2013-01 6358 5955 53948
1 2013-02 8176 6069 57300

 

# 散布図の出力にはscatterメソッドを用いる。
# x軸とy軸は、データのヘッダにある列名で指定できる。
ad_data.plot.scatter(x='tvcm', y='install')
ad_data.plot.scatter(x='magazine', y='install')

f:id:yoshoku:20170716150546p:plain f:id:yoshoku:20170716150557p:plain

散布図から何かしらの相関があることが伺える。回帰分析によりモデル化する。

# Numpy形式でデータを取り出す。
X=np.array(ad_data[['tvcm', 'magazine']])
y=np.array(ad_data['install'])

# 回帰モデルを得る。
model=LinearRegression().fit(X, y)

# 回帰モデルの切片と係数を確認する。
#  インストール数とテレビ広告・雑誌広告には以下の関係があることがわかる。
#  「インストール数 = 188.174 + 1.361 x テレビ広告 +7.250 x 雑誌広告」
print(model.intercept_)
print(model.coef_)
# 188.17427483
# [ 1.3609213   7.24980915]

# 決定係数を確認する。
#  1に近いので、当てはまりが良いことがわかる。
print(model.score(X,y))
# 0.937901430104

別の選択肢としては、StatsModelのOLS(Ordinary Least Squares)を用いる方法がある。 StatsModelの回帰分析は、summaryメソッドを持っており、R言語のlm関数に似た結果を得られる。

from statsmodels.regression.linear_model import OLS
results=model = OLS(y, X).fit()
print(results.summary())
#                             OLS Regression Results                            
# ==============================================================================
# Dep. Variable:                      y   R-squared:                       1.000
# Model:                            OLS   Adj. R-squared:                  0.999
# Method:                 Least Squares   F-statistic:                     8403.
# Date:                Sun, 16 Jul 2017   Prob (F-statistic):           5.12e-14
# Time:                        16:15:12   Log-Likelihood:                -84.758
# No. Observations:                  10   AIC:                             173.5
# Df Residuals:                       8   BIC:                             174.1
# Df Model:                           2                                         
# Covariance Type:            nonrobust                                         
# ==============================================================================
#                  coef    std err          t      P>|t|      [0.025      0.975]
# ------------------------------------------------------------------------------
# x1             1.3540      0.405      3.347      0.010       0.421       2.287
# x2             7.2892      0.476     15.320      0.000       6.192       8.386
# ==============================================================================
# Omnibus:                        1.009   Durbin-Watson:                   0.876
# Prob(Omnibus):                  0.604   Jarque-Bera (JB):                0.804
# Skew:                           0.539   Prob(JB):                        0.669
# Kurtosis:                       2.123   Cond. No.                         14.0
# ==============================================================================

おわりに

6章から、機械学習に関係する話がでてくる。PandasというよりもScikit-Learnが活躍する感じ。7章にある「本当の意味での正解データがないなかで、なんらかの(ビジネス上意味のある)示唆を出さなければならない」という記述は、「うっ」という苦しい気持ちになる。

Pandasで「ビジネス活用事例で学ぶデータサイエンス入門」を勉強する(第5章)

はじめに

マーケティング寄りのデータ分析の知識を補うため、以下の本で勉強を始めた。事例ベースな内容で、とても読みやすい。 Pandasも習得したいので、Pandasに翻訳しながら読み進めている。今回は第5章を勉強した。

ビジネス活用事例で学ぶ データサイエンス入門

ビジネス活用事例で学ぶ データサイエンス入門

ちなみに、CSVファイルなどのデータは、本のサポートページ(SBクリエイティブ:ビジネス活用事例で学ぶ データサイエンス入門)で配布されている。

翻訳したコード

第5章では、Web系のデータ分析では定番?のバナー広告のA/Bテストを扱っている。作業はJupyter Notebook上で行った。 カイ二乗検定が使われているが、Pandasにはカイ二乗検定はない様子なので、SciPyのstats.chi2_contingencyを使う。

%matplotlib inline # Jupyter Notebookでmatplotlibの図が表示されないときにつけるおまじない
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy.stats import chi2_contingency

# データを読み込む。
ab_test_imp = pd.read_csv('section5-ab_test_imp.csv')
ab_test_goal = pd.read_csv('section5-ab_test_goal.csv')

# transition_idが同じものをくっつける。
ab_test_imp_goal = pd.merge(ab_test_imp, ab_test_goal, on=['transaction_id'], how='outer', suffixes=['','.g'])

# user_id.gがNaNだった場合に「0」それ以外では「1」として、クリックされたかを判定するフラグとする。
ab_test_imp_goal['is_goal'] = 1 - ab_test_imp_goal['user_id.g'].apply(np.isnan).astype(int)

# 集計してクリック率を計算する。
sum_goal = ab_test_imp_goal.groupby('test_case')['is_goal'].sum()
sz_user = ab_test_imp.groupby('test_case')['user_id'].size()
sum_goal / sz_user
# ※ Jupyter Notebook上には以下の用にクリック率が表示される。
#    Aが約8%で、Bが約11%となった。
# test_case
# A    0.080256
# B    0.115460
# dtype: float64

# カイ二乗検定により、クリック率に統計的な差があるかを確認する。
cr = ab_test_imp_goal.pivot_table(index='test_case',columns='is_goal', values='user_id', aggfunc='count')
# ※ crは以下のようなクロス集計結果となる
# is_goal    0        1
# test_case
# A          40592    3542
# B          38734    5056
#
chi2, p, dof, expected = chi2_contingency(cr.as_matrix())
print(chi2)
# 308.375052893
print(p)
# 4.93413963379e-69

chi2の値が本と同様の値となり、p値も本と同様にとても小さな値となった。 p値が小さいということで 「バナーAとバナーBでは、クリックされなかった(0)とクリックされた(1)の割合に差がない」という帰無仮説が棄却され、広告をバナーAとバナーBに分けたことで、クリック率が変わったと判断できる。 本では、この後、テストケースごとのクリック率を計算したりするが、今回は省略。

おわりに

Rのchisq.test関数と勝手が違ったので、手間取ってしまった。本ではカイ二乗検定の詳しい説明がないので、気になる方は別な統計本を買って勉強する必要がある。「統計的に有意な差」と「ビジネス的に意味があるか」という2つの観点が出てきたりして、ちょっと、この章は読みづらい感がある。

Pandasで「ビジネス活用事例で学ぶデータサイエンス入門」を勉強する(第4章)

はじめに

マーケティング寄りのデータ分析の知識を補うため、以下の本で勉強を始めた。事例ベースな内容で、とても読みやすい。 Pandasも習得したいので、Pandasに翻訳しながら読み進めている。今回は第4章を勉強した。

ビジネス活用事例で学ぶ データサイエンス入門

ビジネス活用事例で学ぶ データサイエンス入門

ちなみに、CSVファイルなどのデータは、本のサポートページ(SBクリエイティブ:ビジネス活用事例で学ぶ データサイエンス入門)で配布されている。

翻訳したコード

第4章では、クロス集計を扱っている。作業はJupyter Notebook上で行った。Pandasなコードは以下のとおり。

%matplotlib inline # Jupyter Notebookでmatplotlibの図が表示されないときにつけるおまじない
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# データを読み込む。
dau = pd.read_csv('section4-dau.csv')
user_info = pd.read_csv('section4-user_info.csv')

# user_idとapp_nameが同じものをくっつける。
dau_user_info = pd.merge(dau, user_info, on=['user_id', 'app_name'])

# 月次で集計するため、日付を月までにする。
dau_user_info['log_month'] = dau_user_info['log_date'].str.slice(0, 7)

# 性別で集計する。
dau_user_info.pivot_table(index='log_month',columns='gender', values='user_id', aggfunc='count')

 

gender F M
log_month
2013-08 47343 46842
2013-09 38027 38148

 

# 年代で集計する(表は省略)。
dau_user_info.pivot_table(index='log_month',columns='generation', values='user_id', aggfunc='count')

# 性別と年代で集計する(表は省略)。
dau_user_info.pivot_table(index='log_month',columns=['gender', 'generation'], values='user_id', aggfunc='count')

# デバイスで集計する(表は省略)。
dau_user_info.pivot_table(index='log_month',columns='device_type', values='user_id', aggfunc='count')

# 集計結果を可視化する。
dau_user_info.groupby(['log_date', 'device_type'])['app_name'].count().groupby('device_type').plot(legend=True,rot=45)

バイスでわけて、日付ごとのユーザー数を集計し、可視化したものが以下のとおり。 本と同様に、9月になってから、Androidユーザーが少なくなっている。

f:id:yoshoku:20170702213153p:plain

おわりに

pivot_tableで多重なクロス集計もできるので便利。今回も手こずったのは、plotの部分で、x軸のラベルを日付だけにしたかったけど諦め。細かいことは、matplotlib使えば良いんだろうけど、勉強中なので、しばらくはPandasのplotにこだわりたい。

Pandasで「ビジネス活用事例で学ぶデータサイエンス入門」を勉強する(第3章)

はじめに

マーケティング寄りのデータ分析の知識を補うため、勉強を開始した。「チュートリアル的な事例ベースの教材がないかな〜」と色々と探していたところ、ぴったりの良い本が見つかった。第1章と第2章には、データ分析がどういう仕事か書かれている。事例は第3章からとなる。

ビジネス活用事例で学ぶ データサイエンス入門

ビジネス活用事例で学ぶ データサイエンス入門

この本のみならず、多くの(マーケティング寄りの)データ分析の本では、R言語が使われている。PythonのPandasも習得したかったので、Pandasで翻訳しながら読み進めることにした。ちなみに、CSVファイルなどのデータは、本のサポートページ(SBクリエイティブ:ビジネス活用事例で学ぶ データサイエンス入門)で配布されている。

翻訳したコード

第3章では、売上を可視化・比較するため、ヒストグラムを作成する。作業はJupyter Notebook上で行った。Pandasのコードは以下のとおり。

%matplotlib inline # Jupyter Notebookでmatplotlibの図が表示されないときにつけるおまじない
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# データを読み込む。
dau = pd.read_csv('section3-dau.csv')
dpu = pd.read_csv('section3-dpu.csv')
install = pd.read_csv('section3-install.csv')

# DAUデータとInstallデータで、user_idとapp_nameが同じものをくっつける。
dau_install = pd.merge(dau, install, on=['user_id', 'app_name'])

# さらにDPUデータをくっつける。
dau_install_payment = pd.merge(dau_install, dpu, on=['log_date', 'app_name', 'user_id'], how='outer')

# 未課金ユーザの課金額をNAから0とする。
dau_install_payment = dau_install_payment.fillna(0)

# 月次で集計するため、日付を月までにする。
dau_install_payment['log_month'] = dau_install_payment['log_date'].str.slice(0, 7)
dau_install_payment['install_month'] = dau_install_payment['install_date'].str.slice(0, 7)

# 利用月・利用開始月・ユーザーIDでまとめる。
mau_payment = dau_install_payment.groupby(['log_month', 'user_id', 'install_month']).sum().reset_index()

# 利用月と利用開始月を確認することで、新規ユーザーか既存ユーザーかを判定する。
mau_payment['user_type'] = np.where(mau_payment['install_month']==mau_payment['log_month'], 'install', 'existing')

# 課金額を集計して可視化する。
mau_payment_summery = mau_payment.groupby(['log_month', 'user_type']).sum().reset_index()
mau_payment_summery.pivot('log_month', 'user_type').plot.bar(y='payment', stacked=True)

可視化した先月と今月の売上状況は以下のとおり。本では、この後、新規ユーザーの売上比較も行うが、今回は省略。 f:id:yoshoku:20170702153333p:plain

おわりに

Numpy/Scipy脳の私には、Pandasは、ちょっととっつきにくい。一番手こずったのは、plotの部分で、なかなか思ったような図が出せなかった。このへんは「慣れ」な気もしなくもない。逆にR言語に慣れてる人は、関数名の違いさえ把握してしまえば、Pandasを使いこなすのは余裕なのでは?

PyCallを使えばRubyでもKerasでDeep Learningができる

mrknさんが開発しているPyCallを使うと、RubyからPythonオブジェクトを操作できる。 Rubyから、Python機械学習・統計分析のツールを利用することを目的としており、 ネット上にもnumpyやscikit-learnを実行する例があがっている。

Rubyist Magazine - PyCall があれば Ruby で機械学習ができる

このPyCallで、Kerasを叩くことができれば、RubyでもDeep Learningできると思い試してみた。 まずはインストールから。

$ gem install --pre pycall

ちなみに、実行環境をまとめると、Ruby 2.4.0、PyCall 0.1.0.alpha.20170317、Python 3.6.0、Theano 0.9.0、Keras 2.0.2である。 試したのは、MNISTの手書き数字画像を、畳み込みニューラルネットで認識するサンプルコードである。

keras/mnist_cnn.py at master · fchollet/keras · GitHub

これを、細かいところは適宜はぶきながら、Ruby+PyCallで移植すると次のようになる。

require 'pycall/import'
include PyCall::Import

# Kerasの必要なものをimportする.
pyimport 'keras'
pyfrom 'keras.datasets', import: 'mnist'
pyfrom 'keras.models', import: 'Sequential'
pyfrom 'keras.layers', import: ['Dense', 'Dropout', 'Flatten']
pyfrom 'keras.layers', import: ['Conv2D', 'MaxPooling2D']

# MNISTは28x28の大きさの手書き数字画像で、各画像が10個のクラスにわけられている.
nb_classes = 10
img_rows = 28
img_cols = 28

# MNISTデータセットを読み込む.初回実行時はダウンロードするところから始まる.
(x_train, y_train), (x_test, y_test) = mnist.load_data.()

# データのreshapeは、元のコードでは...
#   pyfrom 'keras', import: 'backend'
#   backend.image_data_format.()
# の結果で処理を分けている.
# 試してみたところ "channels_last" だったので、
# そちらの処理を移植した.
x_train = x_train.reshape.(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape.(x_test.shape[0], img_rows, img_cols, 1)
# 型をfloat32にして、要素を[0.0,1.0]にする.
x_train = x_train.astype.('float32')
x_test = x_test.astype.('float32')
x_train /= 255
x_test /= 255

# ラベル情報をクラスベクトル形式にする.
y_train = keras.utils.to_categorical.(y_train, nb_classes)
y_test = keras.utils.to_categorical.(y_test, nb_classes)

# ネットワークを定義する.
model = Sequential.()
model.add.(Conv2D.(32, kernel_size: [3, 3], activation: 'relu', 
                   input_shape: [img_rows, img_cols, 1]))
model.add.(Conv2D.(64, kernel_size: [3, 3], activation: 'relu'))
model.add.(MaxPooling2D.(pool_size: [2, 2]))
model.add.(Dropout.(0.25))
model.add.(Flatten.())
model.add.(Dense.(128, activation: 'relu'))
model.add.(Dropout.(0.5))
model.add.(Dense.(nb_classes, activation: 'softmax'))

# ネットワークをコンパイルする.初回実行時はそれなりに時間がかかる.
model.compile.(loss: keras.losses.categorical_crossentropy,
              optimizer: keras.optimizers.Adadelta.(),
              metrics: ['accuracy'])

# ネットワークを学習する.
model.fit.(x_train, y_train,
          batch_size: 128,
          epochs: 10,
          verbose: 1,
          validation_data: [x_test, y_test])

# 分類性能を評価する.
score = model.evaluate.(x_test, y_test, verbose: 0)
print(sprintf("Test loss: %.6f\n", score[0]))
print(sprintf("Test accuracy: %.6f\n", score[1]))

これを実行すると、問題なくネットワークの学習が動き始める!

$ ruby keras_test.rb
Using Theano backend.
Using cuDNN version 5005 on context None
Mapped name None to device cuda: GeForce GTX 1080 (0000:06:00.0)
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
60000/60000 [==============================] - 27s - loss: 0.3227 - acc: 0.9030 - val_loss: 0.0730 - val_acc: 0.9760
...
Epoch 10/10
60000/60000 [==============================] - 26s - loss: 0.0410 - acc: 0.9881 - val_loss: 0.0298 - val_acc: 0.9890
Test loss: 0.029782
Test accuracy: 0.989000

Kerasは、ブロックをつなげる感じで、簡単にネットワーク構造を定義できる。 Pythonに詳しくないRubyエンジニアが、Deep Learningを試してみるには、PyCall+Kerasが最高な気がする。 また、Pythonで学習したネットワークを、Ruby+PyCallで読み込んで使うという形にすれば、 容易にDeep Learningな機能をRailsアプリに組み込める、という感じで夢が広がる。

scikit-learnで近似最近傍探索したいときはLSHForestがある

scikit-learnでは、ver. 0.16から近似最近傍探索手法のLSHForestが実装されている。LSHForestは、ハッシングによる近似最近傍探索の代表的な手法であるLocality Sensitive Hashing(LSH)をベースにした、木構造の近似最近傍探索手法である。LSHは特徴ベクトルをRandom Projectionと閾値処理で0と1の短いバイナリベクトルに変換し、これをキーとしてハッシュテーブルを作る。LSHForestでは、LSHによって得られれたバイナリベクトルから木構造(あるビットが0か1で二分木が作れる、これをLSHTreeと呼ぶ)を複数個つくる。検索クエリが与えられると、それぞれのLSHTreeから候補を割り出して、それら候補を距離で並べることで検索結果を得る。ちなみに、コサイン距離(1-コサイン類似度)による検索しか行えない。この点では、FLANNの方が柔軟である。 それでは、検索インデックスを作成するスクリプトは、次のようになる。これをgen_idx.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sklearn.datasets import load_svmlight_file
from sklearn.neighbors import LSHForest
from sklearn.externals import joblib

def main():
  # MNIST(手書き数字画像)データセットを読み込む。
  # ※検索対象の例として使う。
  targets, _ = load_svmlight_file('mnist.scale')

  # LSHはアルゴリズム中でRandom Projectionを用いるので、
  # 再現性を確保したい場合は、random_stateに何か与えると良い。
  search_idx = LSHForest(random_state=1984)
  
  # 検索インデックスを作成する。
  search_idx.fit(targets)

  # 検索インデックスを保存する。
  joblib.dump(search_idx, 'lshtree.pkl.cmp', compress=True)

if __name__ == '__main__':
  main()

検索インデックスを読み込み、検索するスクリプトは次のようになる。これをsearch.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from sklearn.datasets import load_svmlight_file
from sklearn.neighbors import LSHForest
from sklearn.externals import joblib

def main():
  # 検索クエリとして、MNISTのテストデータを読み込む。
  queries, q_labels = load_svmlight_file('mnist.scale.t')

  # 検索対象データのラベルを読み込む。
  # ※あとで検索結果を確認したいためで、検索には必要ない。
  _, t_labels = load_svmlight_file('mnist.scale')
  
  # 検索インデックスを読み込む。
  search_idx = joblib.load('lshtree.pkl.cmp')

  # 各クエリの5-近傍を検索する。
  dists, ids = search_idx.kneighbors(queries, n_neighbors=5)

  # 検索クエリの一番目のデータで、検索結果を確認する。
  print(q_labels[0])
  print(t_labels[ids[0,:]])
  print(ids[0,:])
  print(dists[0,:])

if __name__ == '__main__':
  main()

実行してみると、検索クエリと同じ「7」のMNISTデータが検索できているのがわかる。

$ ./gen_idx.py
$ ./search.py
7.0
[ 7.  7.  7.  7.  7.]
[15260 16186 14563  9724 31073]
[ 0.08516171  0.09843745  0.10159849  0.10445509  0.10882645]

LSHForestにも、LSHTreeの数などのパラメータがあるが、論文中で書かれている値であったりと、デフォルトで問題なさそう。 内部のLSHとしては、32ビットのバイナリベクトルに変換する様子。 FLANNと違って、検索時に、元の検索対象データを必要としないのが良い。 ちなみに、検索インデックスに新たにデータを加えるときは、partial_fitメソッドを使う。

Pythonで近似最近傍探索を試したいときはpyflannがちょうど良い

近似最近傍探索とは近似的に近いものを検索してくる技術で、普通に距離を計算して並べて近くにあるものを探すより速い。代表的なライブラリにFLANN(Fast Library for Approximate Nearest Neighbors)があり、これのPythonバインディングがpyflannになる。FLANNの開発は2013年から止まっているのに(もともとブリティッシュコロンビア大の研究がベースなので研究プロジェクトが一段落したんだと思われる)、pyflannは今でも開発されているのがおもしろい。FLANN自体は、Debian GNU/Linuxとかでもパッケージになってて、pyflannもpipにあるのでインストールは楽ちん。枯れ具合がちょうど良い。

$ sudo apt-get insatll libflann-dev
$ sudo pip install pyflann

では、まず、インデックスを作成する。これをgen_idx.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pyflann
from sklearn.datasets import load_svmlight_file

def main():
  # とりあえず検索対象のデータとしてMNISTのデータを使う。
  # MNISTのデータはLIBSVM Dataのページからダウンロードできる。
  targets, _ = load_svmlight_file('mnist.scale')
  targets = targets.toarray()

  # 距離を設定する。一般的なユークリッド距離(euclidean)の他に、
  # マンハッタン距離(manhattan)とか
  # ヒストグラムインターセクションカーネル(hik)とかがある。
  pyflann.set_distance_type('euclidean')

  # 検索インデックスを作成する。
  # algorithmはHierarchical K-Means Clustering Treeを選択した。
  # 他にもRandomized KD-Treeなどがある。
  # centers_initはK-Meansの初期値の設定方法を指定する(デフォルトはランダム)。
  # 再現性を確保したい場合にはrandom_seedを指定すると良い。
  search_idx = pyflann.FLANN()
  params = search_idx.build_index(targets, algorithm='kmeans', 
    centers_init='kmeanspp', random_seed=1984)
  
  # 検索インデックスのパラメータを見てみる。
  print(params)

  # 作成した検索インデックスを保存する。
  search_idx.save_index('mnist.idx')

if __name__ == '__main__':
  main()

これを実行すると、検索インデックスが作成される。

$ ./gen_idx.py
{'branching': 32, 'cb_index': 0.5, 'centers_init': 'default', 'log_level': 'warning', 'algorithm': 'kmeans', 
...
'target_precision': 0.8999999761581421, 'sample_fraction': 0.10000000149011612, 'iterations': 5, 'random_seed': 1984, 
'checks': 32}

作成した検索インデックスを読み込んで検索を行う。これをsearch.pyとする。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pyflann
from sklearn.datasets import load_svmlight_file

def main():
  # 検索対象となるMNISTデータを読み込む.
  targets, t_labels = load_svmlight_file('mnist.scale')
  targets = targets.toarray()

  # 検索クエリとなるMNISTデータを読み込む.
  queries, q_labels = load_svmlight_file('mnist.scale.t')
  queries = queries.toarray()

  # 距離を設定する。
  pyflann.set_distance_type('euclidean')

  # 作成した検索インデックスを読み込む。
  search_idx = pyflann.FLANN()
  search_idx.load_index('mnist.idx', targets)

  # 近似最近傍探索を行う。ここでは5-近傍を探索している。
  result, dists = search_idx.nn_index(queries, num_neighbors=5)

  # 結果を表示する。1番目の検索クエリのラベルと、その5-近傍のラベルと距離を見てみる。
  print(q_labels[0])
  print(t_labels[result[0,:]])
  print(result[0,:])
  print(dists[0,:])

if __name__ == '__main__':
  main()

実行してみると、検索クエリと同じ「7」のMNISTデータが検索できているのがわかる。

$ ./search.py
7.0
[ 7.  7.  7.  7.  7.]
[53843 38620 16186 27059 30502]
[  7.0398414    9.69496007  11.4449976   11.49352929  14.02158225]

他にも細かくパラメータを設定できる。 詳細は、FLANNの公式ページのユーザマニュアルを見ると良い。 MATLABバインディングの説明が参考になる。 FLANNの検索では、(当然だけど)検索インデックスの他に検索対象データも必要で、 検索対象データが高次元かつ大規模であると、検索対象データ自体がメモリに乗るかという問題が発生する。 対策は計算機環境によって色々考えられるだろうけど、どうしても検索対象データ自体を小さくしたい場合は、 Locality Sensitive Hashingに代表されるような、ハッシングによる近似最近傍探索を検討すると良い。