洋食の日記

「た・である」調ではなく「です・ます」調で書きはじめれば良かったなと後悔してる人のブログです

DaruとRumaleを使ってKaggleのTitanicコンペにRubyで挑戦する

はじめに

Kaggleのチュートリアルコンペぐらいなら、Rubyでもイケるんじゃないかと思って、データ分析ライブラリのDaruと機械学習ライブラリのRumaleでTitanicコンペに挑戦してみた。

Titanic: Machine Learning from Disaster | Kaggle

DaruはPythonでいうところのPandasで、これで特徴エンジニアリングを行い、RumaleのRandom Forestで推定を行う。 DaruもRumaleもgemコマンドでインストールできる。

$ gem install daru rumale

データの読み込みと欠損値の補完

まず、KaggleのTitanicコンペからダウンロードしてきた訓練データセット(train.csv)とテストデータセット(test.csv)を読み込む。Daruでは、from_csvメソッドで、CSVファイルをDataFrameにできる。

require 'daru'
require 'rumale'

# データセットを読み込む
train_df = Daru::DataFrame.from_csv('train.csv', headers: false)
test_df = Daru::DataFrame.from_csv('test.csv', headers: false)

Titanicコンペでは、乗客が生存したかどうかを推定する。評価のために提出するCSVファイルは、テストデータセットの乗客IDに生存したかどうかのラベルを紐づけたものとなる。※ちなみに、コンペのもとになったタイタニック号沈没事故は痛ましい事故なので、Titanicコンペの特徴エンジニアリングを考えるときは感情移入し過ぎるとよくない。

# 訓練データから生存に関するラベルを取り出す
target_vals = train_df['Survived']
train_df.delete_vectors('Survived', 'PassengerId')

# テストデータから乗客IDを取り出す
test_pids = test_df['PassengerId']
test_df.delete_vectors('PassengerId')

データセットの一部には欠損値があり、これはDaruのDataFrame上ではnilとなる。replace_nils!メソッドにより任意の値で埋めることができる。

# カテゴリデータの欠損値をUnknownを意味するUで埋める
train_df['Cabin'].replace_nils!('U')
test_df['Cabin'].replace_nils!('U')
train_df['Ticket'].replace_nils!('U')
test_df['Ticket'].replace_nils!('U')
train_df['Embarked'].replace_nils!('U')
test_df['Embarked'].replace_nils!('U')

# 数量データの欠損値を平均で埋める
mean_age = train_df['Age'].mean
train_df['Age'].replace_nils!(mean_age)
test_df['Age'].replace_nils!(mean_age)
mean_fare = train_df['Fare'].mean
train_df['Fare'].replace_nils!(mean_fare)
test_df['Fare'].replace_nils!(mean_fare)

カテゴリデータを数値に変換する

RumaleのLabelEncoderを使用してカテゴリに数値を割り当てていく。Rumaleはデータの表現にNumo::NArrayを利用しているが、DaruはNumo::NArrayに対応していない。Rumaleとのデータの受け渡しでは、適宜to_aメソッドでArrayに変換した。

# Rumaleのラベルエンコーダーを作成する
encoder = Rumale::Preprocessing::LabelEncoder.new

# DaruのVectorをto_aメソッドでArrayにして、Rumaleのfit_trasformやtransformメソッドにわたす
# fit_transformやtransformメソッドは、Numo::Int32型を返すので、to_aメソッドでArrayにする
train_df['Embarked'] = encoder.fit_transform(train_df['Embarked'].to_a).to_a
test_df['Embarked'] = encoder.transform(test_df['Embarked'].to_a).to_a

train_df['Cabin'] = train_df['Cabin'].map { |v| v[0] }
test_df['Cabin'] = test_df['Cabin'].map { |v| v[0] }
train_df['Cabin'] = encoder.fit_transform(train_df['Cabin'].to_a).to_a
test_df['Cabin'] = encoder.transform(test_df['Cabin'].to_a).to_a

train_df['Ticket'] = train_df['Ticket'].map { |v| v.to_s[0] }
test_df['Ticket'] = test_df['Ticket'].map { |v| v.to_s[0] }
train_df['Ticket'] = encoder.fit_transform(train_df['Ticket'].to_a).to_a
test_df['Ticket'] = encoder.transform(test_df['Ticket'].to_a).to_a

バイナリ変数を追加する

Wikipediaの記事やKaggleのkernelを見ると、女性や少年は生存しているようだ。

# 女性であるか
train_df['IsFemale'] = train_df['Sex'].map { |v| v == 'female' ? 1 : 0 }
test_df['IsFemale'] = test_df['Sex'].map { |v| v == 'female' ? 1 : 0 }

# 少年(敬称がMaster)であるか
train_df['IsMaster'] = train_df['Name'].map { |v| v.split(',')[1].split('.')[0].strip == 'Master' ? 1 : 0 }
test_df['IsMaster'] = test_df['Name'].map { |v| v.split(',')[1].split('.')[0].strip == 'Master' ? 1 : 0 }

数量データの一部を量子化する

特徴量としては、実数値より、ざっくりとヒストグラムで表現したほうが良い場合がある。ただ、Daruには、Pandasのqcutやcutがないようなので、なにかしらの方法で量子化する必要がある。

# 家族の人数を計算する
train_df['FamilySize'] = train_df['Parch'] + train_df['SibSp'] + 1
test_df['FamilySize'] = test_df['Parch'] + test_df['SibSp'] + 1

# 料金を家族の人数で割る
train_df['MeanFare'] = train_df['Fare'] / train_df['FamilySize']
test_df['MeanFare'] = test_df['Fare'] / test_df['FamilySize']

# 料金に関する変数を50段階にする
max_fare = train_df['Fare'].max.to_f
min_fare = train_df['Fare'].min.to_f
train_df['Fare'] = (((train_df['Fare'] - min_fare) / (max_fare - min_fare)) * 50.0).round
test_df['Fare'] = (((test_df['Fare'] - min_fare) / (max_fare - min_fare)) * 50.0).round

max_mean_fare = train_df['MeanFare'].max.to_f
min_mean_fare = train_df['MeanFare'].min.to_f
train_df['MeanFare'] = (((train_df['MeanFare'] - min_mean_fare) / (max_mean_fare - min_mean_fare)) * 50.0).round
test_df['MeanFare'] = (((test_df['MeanFare'] - min_mean_fare) / (max_mean_fare - min_mean_fare)) * 50.0).round

# 年齢に関する変数を10段階にする
# ※50段階や10段階にしたのは単に「思いついた数字」で特に意図はない
max_age = train_df['Age'].max.to_f
min_age = train_df['Age'].min.to_f
train_df['Age'] = (((train_df['Age'] - min_age) / (max_age - min_age)) * 10.0).round
test_df['Age'] = (((test_df['Age'] - min_age) / (max_age - min_age)) * 10.0).round

不要な特徴量を削除してNumo::NArray形式に変換する

Rumaleでは、データの表現にNumo::NArrayを利用している。DaruのDataFrameはto_matrixメソッドでMatrixに、Vectorはto_aメソッドでArrayに変換し、これらをNumo::NArrayにわたすことで、Numo::DFloatやNumo::Int32に変換する。

# 名前などはもう使わないので削除する
del_cols = ['Name', 'SibSp', 'Parch', 'Sex']
train_df.delete_vectors(*del_cols)
test_df.delete_vectors(*del_cols)

# DataFrameをMatrixに、VectorをArrayに変換して、Numo::NArray形式にする
samples = Numo::DFloat[*train_df.to_matrix]
labels = Numo::Int32[*target_vals.to_a]
test_samples = Numo::DFloat[*test_df.to_matrix]

交差検定で確認する

Rumaleを使って交差検定を行う。特徴量の重要度も確認したいので、分類器にはRandom Forestを用いる。

# 結果を保存する変数を初期化する
imp = Numo::DFloat.zeros(n_features)
sum_accuracy = 0.0

# 10-交差検定の分割をおこなう
kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 10, shuffle: true, random_seed: 1)

kf.split(samples, labels).each do |train_ids, valid_ids|
  # 訓練データと検証データを得る
  train_samples = samples[train_ids, true]
  train_labels = labels[train_ids]
  valid_samples = samples[valid_ids, true]
  valid_labels = labels[valid_ids]
  # Random Forestで学習する
  clf = Rumale::Ensemble::RandomForestClassifier.new(n_estimators: 100, max_features: 2, random_seed: 1)
  clf.fit(train_samples, train_labels)
  # 特徴量の重要度を得る
  imp += clf.feature_importances
  # 正確度を得る
  sum_accuracy += clf.score(valid_samples, valid_labels)
end

# 平均の正確度を出力する
mean_accuracy = sum_accuracy / kf.n_splits
puts
puts sprintf("Mean Accuracy: %.5f", mean_accuracy)

# 特徴量の重要度を出力する
puts
puts "Feature importances:"
train_df.vectors.to_a.each_with_index { |col, i| puts("#{col}: #{imp[i] / kf.n_splits}") }

これを実行すると次のようになる。8割ほどの正確度が得られ、最も重要な特徴は女性であるかどうかになった。部屋や料金も重要らしい。

Mean Accuracy: 0.82270

Feature importances:
Pclass: 0.08757945953461778
Age: 0.05991353734583812
Ticket: 0.06425842380708825
Fare: 0.08939140299941103
Cabin: 0.09128477545493736
Embarked: 0.02360076253503096
IsFemale: 0.41341081597428586
IsMaster: 0.027338294741599423
FamilySize: 0.07416848338510751
MeanFare: 0.06905404422208355

提出用CSVファイルを作成する

Daruはwrite_csvメソッドで、DataFrameをCSVファイルに書き出せる。乗客IDと推定結果を結びつけて、提出用ファイルを作成する。

# テストデータセットのラベルを推定する
clf = Rumale::Ensemble::RandomForestClassifier.new(n_estimators: 100, max_features: 2, random_seed: 1)
clf.fit(samples, labels)
prediction = clf.predict(test_samples)

# 提出用ファイルを作成する
submission = Daru::DataFrame.new({'PassengerId': test_pids, 'Survived': prediction })
submission.write_csv('submission.csv')

これを提出すると、Public Leaderboardで0.79425だった。サンプルのgender_submission.csvがたしか0.76とかだったので、それよりは良い感じ。

おわりに

「Rumaleを作るだけじゃなく使ってみないと」と思ってTitanicコンペに挑戦してみた。特徴エンジニアリングとかは、思いつきで適当に行ったが、Titanicコンペのような小さいデータセットであれば、DaruとRumaleで十分にデータ分析などができる印象を持った。ただ、データセットが大きくなると、実行速度の面で厳しい様に思う。Rumaleは、並列に処理できる箇所があるので、Parallel gemで高速化できないか考えている。Daruは、Daru::Viewに力をいれている様子で、Daru本体のリリースが一年近くないのがちょっと気がかり 🤔

ちなみにKaggleだが、それなりのスペックのマシンを用意することをオススメする。私はEarly 2016なMacBookしかもっておらず、一時期KaggleのKernelでガンバったが、タイムアウトとかメモリ超過で強制終了になることがほとんどで、結局あきらめてしまった 😓 いつかお金と時間に余裕ができたら再挑戦したい。

Rumaleに改名するときにやらかしたコト

はじめに

Rumaleに改名する際に以下のページをまずみた。Googleで「gem rename」で上位に来たので。

stackoverflow.com

ざっと読んで、旧名のGemのリポジトリとかREADMEとかに説明をつけて、新しいの出せば良いんだな〜と思った。FactoryBot(旧FactoryGirl)はなんかGithubリポジトリも引き継いでる感じだけど、なんか特殊な方法をやってるのだろうな〜と思った。 だが、実際は特殊な方法ではなくて、GithubRubyGemsがいい感じにやってくれるというコトを教わった。またまた、id:mrkn さんにお世話になりつつ作業をすすめた 🙏 改めて感謝申し上げます 🙏

SVMKitからRumaleへ

Githubでは、リポジトリ名はSettingsでRenameできる。改名前のURLでアクセスしたとしても、Githubがいい感じにリダイレクトしてくれる。新しい名前のリポジトリを用意するのに比べて、Starとかのステータスが維持されるのが良い。

Gitコマンドレベルの作業では次のようにした。手元にSVMKitとRumaleの作業コピーがあるとする。Rumaleの内容でSVMKitを上書きするかたちになる。

$ cd svmkit
$ git rm -rf --ignore-unmatch *
$ cp -pr ../rumale/* .
$ cp ../rumale/.{.coveralls.yml,.gitignore,.rspec*,.rubocop*,.travis.yml} .
$ git add .
$ git commit -m ':rocket: Rename to Rumale'

置き換えたのは version 0.8.0 なので、v0.8.0のタグを付け替えてpushしようと思ったが、なにか手違いがあったみたいで、tagがforceな感じで置き換えれなかった。というわけで、まず普通にpushした。

$ git push origin master

これでSVMKitの中身がRumaleのものになる。 次に、Github上でrumaleリポジトリを削除して、svmkiitリポジトリをrumaleにRenameした。 このとき、Githubで「rumaleはすでにあります」みたいなエラーが出たり、アクセスしたら前のモノが残ってるなど、変なコトは起きなかった。意図通りの変更ができた。

そして、pendingしてたタグ付けだが、以下の準備をしておいて、

$ cd ../
$ mv rumale rumale.old
$ git clone https://github.com/yoshoku/rumale.git
$ cd rumale
$ git tag -a v0.8.0 -m 'Version 0.8.0' 最新のコミット

そしてGithub上で v0.8.0 のタグを削除して、速攻pushした。我ながらヤベーヤツだ。

$ git push origin v0.8.0

これでGithub上でSVMKitがRumaleになった。RubyGemsのRumaleのページに行くと、SVMKitをRumaleに改名したものが紐付いていた。

SVMKitの後始末

SVMKitをRumaleと置き換えるだけで、SVMKitのコードはそのまま動く。なので、実質 SVMKit = Rumale という定数定義だけのRumaleに依存する空Gemを作って、それをSVMKitのversion 0.8.1としてリリースした。

$ bundle gem svmkit
$ cd svmkit

旧SVMKitからlibやspecディレクトリ以外のものをコピーしてきた。svmkit.gemspecファイルは以下のようにした。versionを直打ちにして、rumaleに依存させる。

lib = File.expand_path('lib', __dir__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)

Gem::Specification.new do |spec|
  spec.name          = 'svmkit'
  spec.version       = '0.8.1'
... 略
  spec.add_runtime_dependency 'rumale', '~> 0.8'
... 略

その上でsvmkit/version.rbを削除した。

$ cd lib
$ rm -rf svmkit

そして、svmkit.rbを次のようにした。

warn 'SVMKit has been deprecated; You should migrate to Rumale.'

require 'rumale'

SVMKit = Rumale

これで、gemファイルを作り、RubyGemsに上げる。Githubにpushされたりすると困るといういか、リポジトリがないので、rake releaseではなく、gem pushを使った。

$ rake build
svmkit 0.8.1 built to pkg/svmkit-0.8.1.gem.

$ gem push pkg/svmkit-0.8.1.gem
Pushing gem to https://rubygems.org...
Successfully registered gem: svmkit (0.8.1)

追記:

念のため、version 0.8.0はyankした...

$ gem yank svmkit -v0.8.0
Yanking gem from https://rubygems.org...
Successfully deleted gem: svmkit (0.8.0)

Travis CIやCoveralls

Travis CIはSettingsページにある、Sync accountを押して、ページをリロードすると、svmkitが消えていた。 Coverallsは、SVMKitのSettingsページに行き、DELETE REPOSITORYした。そして、RumaleのSettingsページで、Resync with githubした。 画面を見た感じ上手く切り替えれたっぽい。

追記:

結局、RumaleもDELETE REPOSITORYして、ADD REPOSページでSYNC REPOSボタンを押したあとで、改めてRumaleのリポジトリを追加した。そのうえでTravis CIでRebuildした。

おわりに

やってしまいました...気をつけたいと思います...

SVMKitをRumaleに改名した

SVMKitをRumale (Ruby machine learning) に改名した。SVMKitにサポートベクターマシン以外のアルゴリズムを実装するようになってから、ずっと考えていたが、いい名前が思いつかず放置していた。Rumaleの命名には、 Red Data Toolsid:mrkn さんや kou さんにご協力頂いた(というか私は本当にノーアイディアだった...) 🙏 改めて感謝申し上げます 🙏

rubygems.org

github.com

改名に向かってやるべきこと(改名したよのメッセージとかリンクはったりとか)は、RubyのFactoryBotやTerrapin、JavascriptのPugを参考にした。

SVMKitもRumaleもバージョン0.8.0では同様の内容となっている。ただ、RumaleはRuby 2.3以上での利用を想定している。これはsafe navigation演算子とかを使いたかったことなどがある。今後SVMKitは、bugfixのみをリリースする予定。

Rumaleを育てていくぞ〜 💪

SVMKitにGrid Seachを実装した

はじめに

SVMKitに、ハイパーパラメータの探索手法として定番のGrid Searchを実装した。Scikit-learnのGrid Searchと同様に、交差検定をベースにした探索を行う。与えられたハイパーパラメータの値のすべての組み合わせで、交差検定を行い、テストでのスコアが最大(もしくは最小)のものを最適なハイパーパラメータとする。

svmkit | RubyGems.org | your community gem host

使い方

サンプルコードでは、LIBSVM DataのpendigitsデータセットとEunite 2001データセットを用いる。

$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/pendigits.t
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/eunite2001
$ wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/eunite2001.t
決定木による分類

決定木には、枝を分岐する評価基準や、木の深さなどのハイパーパラメータがある。これをGrid Searchで最適化する。

require 'pp'
require 'svmkit'

# データを読み込む。
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')
samples = Numo::DFloat.cast(samples)

# 決定木を定義する。
dt = SVMKit::Tree::DecisionTreeClassifier.new(random_seed: 1)

# ハイパーパラメータで確認したい値を、パラメータ名をkey、確認したい値のArrayをvalueとするHashで定義する。
pg = { criterion: ['gini', 'entropy'], max_depth: [4, 8] }

# 5-交差検定によるGrid Searchを実行する。
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: dt, param_grid: pg, splitter: kf)
gs.fit(samples, labels)

# 結果を表示する。
pp gs.cv_results
puts '---'
pp gs.best_params

# テストデータセットで推定する。
puts '---'
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits.t')
samples = Numo::DFloat.cast(samples)
puts("Test Dataset Accuracy: %.1f %%" % (gs.score(samples, labels) * 100.0))

これを実行すると次の様になる。cv_resultsには、各パラメータの組み合わせ(params)と、それに対応する交差検定のスコア(mean_test_scoreなど)が入っている。best_paramsはスコアが最大となる(この場合DecisionTreeClassifierのscoreメソッドが実行されAccuracyが最大となる)パラメータが入っている。

$ ruby svmkit_gs_example.rb
{:mean_test_score=>
  [0.7272529937832951,
   0.9286126486405282,
   0.7272529937832951,
   0.9286126486405282],
... 略
 :params=>
  [{:criterion=>"gini", :max_depth=>4},
   {:criterion=>"gini", :max_depth=>8},
   {:criterion=>"entropy", :max_depth=>4},
   {:criterion=>"entropy", :max_depth=>8}]}
---
{:criterion=>"gini", :max_depth=>8}
---
Test Dataset Accuracy: 87.4 %
PipelineでのGrid Search

Pipelineは特徴変換と分類器で構成される。それぞれのハイパーパラメータをGrid Searchで最適化するコードは以下の様になる。

require 'pp'
require 'svmkit'

# データを読み込む。
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')
samples = Numo::DFloat.cast(samples)

# カーネル近似とサポートベクターマシンとの分類によるパイプラインを作る。
rbf = SVMKit::KernelApproximation::RBF.new(random_seed: 1)
svc = SVMKit::LinearModel::SVC.new(random_seed: 1)
pipe = SVMKit::Pipeline::Pipeline.new(steps: { foo: rbf, bar: svc })

# カーネル近似のハイパーパラメータであるガンマと成分数、
# サポートベクターマシンのハイパーパラメータである正則化係数で、
# Grid Searchで探索したい値をHashで定義する。
# それぞれアンダーバー2つで名前とパラメータを指定する。
pg = { foo__gamma: [0.1, 0.0001], foo__n_components: [512, 1024], 
       bar__reg_param: [1.0, 0.0001] }

# 5交差検定で確認するため、分割を定義する。
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)

# Grid Search を実行する。
gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
gs.fit(samples, labels)

# 結果を表示する。
pp gs.cv_results
puts '---'
pp gs.best_params

# テストデータセットで推定する。
puts '---'
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits.t')
samples = Numo::DFloat.cast(samples)
puts("Test Dataset Accuracy: %.1f %%" % (gs.score(samples, labels) * 100.0))

これを実行すると次の様になる。

$ ruby svmkit_gs_example.rb
{:mean_test_score=>
  [0.09995046835385611,
   0.10233753516837311,
   0.09994663406224738,
... 略
 :params=>
  [{:rbf__gamma=>0.1, :rbf__n_components=>512, :svc__reg_param=>1.0},
   {:rbf__gamma=>0.1, :rbf__n_components=>512, :svc__reg_param=>0.0001},
   {:rbf__gamma=>0.1, :rbf__n_components=>1024, :svc__reg_param=>1.0},
... 略
---
{:rbf__gamma=>0.0001, :rbf__n_components=>1024, :svc__reg_param=>0.0001}
--
Test Dataset Accuracy: 98.1 %
回帰

回帰の場合は以下の様になる。評価尺度にはMean Squared Errorを用いる。これは小さいほど回帰としては良いと判断できる。小さいほどよいという評価を使うばあいは、greater_is_better引数にfalseを与える(デフォルトではtrue)。

require 'pp'
require 'svmkit'

# データを読み込む。
samples, values = SVMKit::Dataset.load_libsvm_file('eunite2001')
samples = Numo::DFloat.cast(samples)
values = Numo::DFloat.cast(values)

# 決定木による回帰を定義する。
dt = SVMKit::Tree::DecisionTreeRegressor.new(random_seed: 1)

# 確認したいハイパーパラメータを定義する。
pg = { max_depth: [4, 8], max_features: [2, 4] }

# 評価尺度にはMean Squared Error (MSE) を用いる。
ev = SVMKit::EvaluationMeasure::MeanSquaredError.new

# 5-交差検定でGrid Searchを行う。
# MSEは小さいほどよいので、greater_is_betterにfalseを与える。
kf = SVMKit::ModelSelection::KFold.new(n_splits: 5)
gs = SVMKit::ModelSelection::GridSearchCV.new(
  estimator: dt, param_grid: pg, splitter: kf, evaluator: ev, greater_is_better: false)
gs.fit(samples, values)

# 結果を表示する。
pp gs.cv_results
puts '---'
pp gs.best_params

# テストデータセットで推定する。
puts '---'
samples, values = SVMKit::Dataset.load_libsvm_file('eunite2001.t')
samples = Numo::DFloat.cast(samples)
values = Numo::DFloat.cast(values)
puts("Test Dataset MSE: %.4f" % ev.score(values, gs.predict(samples)))

これを実行すると以下の様になる。必ずしも深い木が良いわけではないのがおもしろい。

$ ruby svmkit_gs_example.rb
{:mean_test_score=>
  [1502.7172630764019,
   989.053627383465,
   1456.5703139919217,
   1074.9447023536181],
...略
 :params=>
  [{:max_depth=>4, :max_features=>2},
   {:max_depth=>4, :max_features=>4},
   {:max_depth=>8, :max_features=>2},
   {:max_depth=>8, :max_features=>4}]}
---
{:max_depth=>4, :max_features=>4}
---
Test Dataset MSE: 779.5281

おわりに

Grid Searchの実装により、機械学習ライブラリとして基本的なことがだいたいできるようになった。一方で、少しずつSVMKitのリファクタリングを進めていて、ある程度整えば、いよいよ改名したいと考えている。改名の理由は、SVMだけでなく様々なアルゴリズムを実装してしまったことで、これ以上この名前でアルゴリズムを追加するのは難しい(もともとJavaのMALLETぐらいのサイズ感を考えていた)。ただ、改名後もSVMKitのbugfixなどのメンテナンスは続けるし、高速化などはバックポートしたい。

Rubyでキーワード引数の引数名と値の一覧をHashで得る

はじめに

Ruby 1.9以前のHashでキーワード引数を再現してたような感じで、Ruby 2.0以降のキーワード引数でも、引数の名前とその値の一覧をHashで得たいときがある。

コード

def foo(a: 'yes', b: 'no')
  keywd_args = method(__callee__).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h
  p keywd_args
end

これをirbで動かすと以下の様な感じ。キーワード引数の名前とその値を、Hashで得ることができている。

irb(main):005:0> foo
{:a=>"yes", :b=>"no"}
=> {:a=>"yes", :b=>"no"}

irb(main):006:0> foo(a: '123', b: '456')
{:a=>"123", :b=>"456"}

irb(main):007:0> foo(a: 'bar')
{:a=>"bar", :b=>"no"}
=> {:a=>"bar", :b=>"no"}

簡単な解説

  • __callee__が返すメソッド名のシンボルをmethodメソッドに渡してMethodオブジェクトを得る
  • Methodオブジェクトのparametersメソッドが返す引数の種類と名前によるArrayをmapで展開する
  • bindingのlocal_variable_getメソッドに引数の名前を与えて引数の値を得る
  • 引数名と値のペアによるArrayをto_hでHashにする

おわりに

「前にどうしたっけな〜」とよく忘れてしまう自分のためのメモ 📝