DeepGBMの論文まとめ
DeepGBMとは
KDD 2019において発表された,LightGBMからNNへの蒸留とCategorical NNを用いたGBDT+NNのモデル
GBDTとNNのいいとこ取りしたモデルとなっており,Kaggle competitionで使われたdatasetを用いた予測問題では既存手法(LightGBM)よりスコアが改善したとのこと
authorにはLightGBMの開発陣が加わっており,LightGBM+Pytorchを用いたGithub repositoryが提供されている.
本記事ではKDD 2019で発表された論文の内容を説明する.
(下記リンクのdownloadを押せば入手できます)
モデルについて
GBDT
pros
- dense data (tabular numerical features)に強い
sparse dataに弱い反面,dense dataには強い傾向にある.
cons
- sparse data (categorical features)に弱い
gainベースで分岐を行うため,0-1のカテゴリカルデータは過学習しやすい.
kaggler界隈ではsparseなデータを特徴抽出してdenseなデータにしたり,stacking等の手法を使って特徴選択することが多い.
- オンライン学習に不利
GBDTなどの加法モデルは学習データが増えてくると,都度新たなtreeを追加してロス関数を下げていく手法である.
そのためオンライン学習では新しいデータに過学習しやすくなり,バッチ処理でのオフライン学習に比べて精度が劣る.
DNN
pros
- categorical dataに強い
embedding層を用いることでcategorical dataに対応できる
- オンライン学習に有利
back propagationによってロス関数を下げる手法であることから,例え1つのデータでも過学習は起きにくい
cons
- dense dataに弱い
dense dataにはFCNN(fully connected neural network)が用いられるが,複雑な最適化となることから局所解に陥りやすい
DeepGBM
上記のGBDTとDNNのpros&consを踏まえて,DeepGBMではdense dataにはGBDT,sparse dataにはNN(CatNN)を用いる.
GBDTはNNへと蒸留され,このモデルはGBDT2NNと呼ばれる.
更に,蒸留によって全体の構造がNNとなることで,オンライン学習が可能となる.
GBDT2NN
学習済みGBDTの木構造に近づけるようにFCNNへと蒸留が行われる.
そして,ここが肝なのだが,GBDTのleaf indexがはじめにleaf embeddingへの蒸留を行い,one-hotなleaf indexをdenseなembeddingへと変換する.
次にleaf embeddingは維持したまま,GBDTの木構造に近づけるためにNN(FCNN)を学習する
この蒸留によって得られたモデルはGBDT2NNと呼ばれる.最終的にはNN+embeddingの構成となる.
GBDTの複数木である場合,一つの木からでなく複数個の木を纏めたグループからNNモデルへ蒸留を行う. これをTree Groupingという.
個の木を 個のグループに分ける場合を考えると,各グループには 個の木が含まれる.
今,入力データについて,tree group に使われる特徴量indicesを とすると,から蒸留されたNNモデルの出力は以下の式で表される.
where, : a multi-layered NN model with input and parameter
GBDTモデルの出力は,個のグループの出力の和で表される.
計算コストについて
ある木 の構造関数をとする.
はサンプルについてoutput leaf indexを返す.
についてののone-hot表現をとする.
GBDTの複数木である場合,一つの木に1つのNNモデルを蒸留することを考えると, 個のNNモデルが必要となる.
そのため,計算コストは となる.
leaf embeddingの出力 は,よりとても小さくなることから,leaf embeddingの蒸留はの計算コストよりも小さくなる.
また,Tree Groupingによって よりも小さい計算コストとすることができる.
CatNN
一方,categorical dataにはCatNNが使われる.こちらにもembeddingが使われており,高次元なsparse dataをdense dataへと変換する
CatNNはFM componentとDeep componentの出力を足し合わせた値を出力する
where, : the output of CatNN model,
: the output of FM component,
: the output of Deep component
は特徴間の線形およびpair-wiseな相互作用を表す
where, : the number of features,
: the hyperparameters of linear part,
: the inner product operation
ここで,
であり, は に対応するembedding vectorを返す
は特徴間のより高次な相互作用を表す
DeepGBM
DeepGBMは, と を重み付き和して,sigmoid関数などで変換した値を出力する
where, : the trainable parameters,
: the output transformation, such as sigmoid for binary classification
以上のモデル特性は下記表に纏められている
学習について
DeepGBMはオフライン学習とオンライン学習で学習方法が異なる
オフライン学習 (End-to-end offline learning)
① GBDT,CatNNの学習
これらのモデルは通常の学習を行う
② leaf embeddingの学習
GBDTの複数木の場合,leaf embeddingの学習プロセスは下記で表される
where, : the parameters for mapping embedding to leaf values,
: the same loss function as used in tree learning,
: an one-layered fully connected network with parameter ,
: the concatenate operation,
: the predict leaf value of sample
leaf embeddingによってone-hotなleaf indexからdense embeddingへと変換される.
ここで,tree outputは,ある木のleaf valuesをと表し,をi番目のleafのleaf valueとすると, で表される.
また,
は,パラメータをもつ1層のFCNNであり,multiple one-hot leaf index vectorsが結合したmulti-hot vectorsをdense embedding へ変換する.
③ GBDT2NNの学習(構造の蒸留)
NNモデルの蒸留の標的として新たなembeddingを用いる.
学習プロセスは下記で表される.
where, : the used features in tree group .
は,多くの特徴量が含まれて,特徴選択の能力が損なわれる可能性があることから,feature importanceに従って重要度の高い特徴量のみ使うようにする.
④ DeepGBMの学習
オフラインデータを使って,まずGBDTとCatNNの学習を行なった後,end-to-endでDeepGBMの学習を行う.
ロス関数は以下で表される.
where, : the training target of sample ,
: the number of tree groups,
: hyper-parameters used for controlling the strength of end-to-end loss and embedding, respectively.
オンライン学習 (Online update)
オフライン学習されたGBDTを,embeddingをonliine updateで学習させるために使用することはオンラインでのリアルタイムの性能を損なうと考えられる.
そのため,online updateではロス関数に
実行方法
DeepGBMの実装はgithubに公開されているので,DLして試すことが出来ます.
experimentsディレクトリに移動して,README.mdに書かれている通りの手続きを踏めば良いです.