【論文紹介】A speed-optimized and cache-friendly implementation of stochastic gradient-boosted decision trees for multivariate classification
論文のダウンロード先
論文の内容
Gradient Boosting Decision Treeの高速化を、特徴量の構造を工夫してCPUキャッシュヒットしやすくすることによって、より高速化したという論文。
XGBoostよりシングルコア比較で2.2倍速く,マルチコア比較で3.8倍速い。また他のライブラリ(e.g. xgboost)より推定精度も向上したとのこと。
内容
- 多クラス分類問題とか、線形問題においてGradient Boosting Decision Treeが最近よく使われている。
- 本論文では計算処理を最適化、キャッシュ処理の最適化をとくに考慮したFastBSD(の提案手法及び実装)を提案。
- 高速化手法の具体的な方法は以下の通り
- 特徴量の整数化、一様分布化
- CPUキャッシュの特徴を考慮した推定
- 従来の特徴量のメモリ内での持ち方はstruct of arrays。これをFASTBSDではarray of structs構造を採用。これをやることによってCPUのメモリキャッシュのヒット率を上げる。
評価
- gradient-boosted decision treesを実装したライブラリを評価
- scikit-learn
- XGBoost(シングルスレッド)
- TMVA
- XGBoost-i7(マルチコア)
- FastBSD(提案手法)
- 評価に利用したデータ・セット
- 100件のデータポイント、35次元の特徴量。人工データ(モンテカルロを利用)
- データ・セットを各ライブラリ向けのフォーマットに変換
- ハイパーパラメータは以下の通り
- depth of the trees = 3
- number of trees = 100
- number of features = 35
- number of training data-points = 500000 • sampling-rate = 0.5
- shrinkage = 0.1
結果
- 推定モデルの生成速度
XGBoostよりシングルコア比較で2.2倍速く,マルチコア比較で3.8倍速い
- 推定精度
基本的に他の手法を上回る。例外としてはtreeを深くしていった場合。
過学習がおきてしまって精度がxgboostに劣る。理由はxgboostの何らかの過学習対策が寄与している?
- ツールとして実装でできていないこと。
- Negative Weights
- 欠損値のサポート
- 特徴量重要度の表示
所感
- xgboostを利用してhyperparameterチューニングとかCVとかやると計算が膨大になるので、そういった意味で高速化手法が出てくるのは嬉しい。
- なぜ精度がよくなるかがいまいち不明。
- 論文では整数化&特徴量の一様分布化をすることによって、過学習がおきなくなっている、といっているけど確信がないので、追加調査が必要
pythonとnetworkxを使ったHITS解析
networkxにはHITS解析の実装も存在します。そこで同ライブラリを利用してHITSによるグラフ解析手法を紹介します。
利用するライブラリ
前回の記事で紹介した通り、networkxとその依存ライブラリであるnumpy,scipyがインストールされている事が前提です。
今回解析する対象のグラフ
解析するグラフも前回の記事と同じグラフを解析することとします。
ソースコード
# -*- coding: utf-8 -*- import networkx as nx #有向グラフ g = nx.DiGraph() #ノードの追加(省略可能) g.add_node(1) g.add_node(2) g.add_node(3) g.add_node(4) #エッジの追加 g.add_edge(1,2) g.add_edge(1,3) g.add_edge(1,4) g.add_edge(2,3) g.add_edge(3,4) #pagerank値の計算 h,a=nx.hits(g) #pagerank値の計算(numpyを利用) nh,na=nx.hits_numpy(g) #pagerank値の計算(scipyを利用) sh,sa=nx.hits_scipy(g) #計算結果表示 print("-----HITS-----") print("--hub--") print(h) print("--authorities--") print(a) print("-----HITS(numpy)-----") print("--hub--") print(nh) print("--authorities--") print(na) print("-----HITS(scipy)-----") print("--hub--") print(sh) print("--authorities--") print(sa)
結果
-----HITS----- --hub-- {1: 0.5773502689711713, 2: 0.2113248655144144, 3: 0.2113248655144144, 4: 0.0} --authorities-- {1: 0.0, 2: 0.26794919177575915, 3: 0.3660254041121205, 4: 0.3660254041121205} -----HITS(numpy)----- --hub-- {1: 0.5773502691896258, 2: 0.2113248654051871, 3: 0.2113248654051871, 4: 0.0} --authorities-- {1: -0.0, 2: 0.26794919243112264, 3: 0.3660254037844386, 4: 0.36602540378443876} -----HITS(scipy)----- --hub-- {1: 0.5773502722323048, 2: 0.21132486388384755, 3: 0.21132486388384755, 4: 0.0} --authorities-- {1: 0.0, 2: 0.2679492015591601, 3: 0.36602539922042, 4: 0.36602539922042}
pythonとnetworkxを使ったPageRank解析
pythonにはnetowrkxという便利なグラフ解析ライブラリが存在します。
このライブラリはかなり良くできていて
・Rのigraph等のライブラリと比較してサブグラフの出力結果へのアクセス手法が用意
・numpyやscipy等との連携がかなり進んでおり、大規模なグラフや高速演算を求められるグラフにおいても適している。
という特徴があります。
ここでは、同ライブラリを使ってPageRankによるグラフ解析を実施する手法を紹介します。
利用ライブラリ
networkxを利用するにあたり必要なライブラリは以下の通りです。
networkx
numpy
scipy
numpyとscipyはnetowrkxを利用するにあたり必須のライブラリではありませんが、numpyやscipyを入れておくと疎行列を扱う事ができるようになるなどメリットが大きいです。。
(というか、numpyやscipy無しの場合、グラフのノード数やエッジ数のサイズが数百ぐらいまでと限定されます)
あと、networkxやnumpyはpipコマンド一発でインストールできますが、scipyについては、内部でBLASやLAPACを利用している関係からか、pipでインストールする前に関連ライブラリをインストールしておく必要があります。
インストール方法は環境によって多様ですが、基本的にはscipyのインストール手法(http://www.scipy.org/Installing_SciPy)を参考にしてください。
今回解析する対象のグラフ
以下のグラフを対象とします。
ソースコード
この例ではダンピングファクターの値を0.9としています。
# -*- coding: utf-8 -*- import networkx as nx #有向グラフ g = nx.DiGraph() #ノードの追加(省略可能) g.add_node(1) g.add_node(2) g.add_node(3) g.add_node(4) #エッジの追加 g.add_edge(1,2) g.add_edge(1,3) g.add_edge(1,4) g.add_edge(2,3) g.add_edge(3,4) #pagerank値の計算 pr=nx.pagerank(g,alpha=0.9) #pagerank値の計算(numpyを利用) prn=nx.pagerank_numpy(g,alpha=0.9) #pagerank値の計算(scipyを利用) prc=nx.pagerank_scipy(g,alpha=0.9) #計算結果表示 print("-----pagerank-----") print(pr) print("-----pagerank(numpy)-----") print(prn) print("-----pagerank(scipy)-----") print(prc)
実行結果
-----pagerank----- {1: 0.12058362438281975, 2: 0.15675871212474798, 3: 0.2978415539438116, 4: 0.4248161095486206} -----pagerank(numpy)----- {1: 0.12058362474375982, 2: 0.15675871216688764, 3: 0.29784155311708677, 4: 0.42481610997226577} -----pagerank(scipy)----- {1: 0.12058369059454467, 2: 0.15675865745841386, 3: 0.29784125143334383, 4: 0.42481640051369757}