DeepGBMの論文まとめ

DeepGBMとは

KDD 2019において発表された,LightGBMからNNへの蒸留とCategorical NNを用いたGBDT+NNのモデル

GBDTとNNのいいとこ取りしたモデルとなっており,Kaggle competitionで使われたdatasetを用いた予測問題では既存手法(LightGBM)よりスコアが改善したとのこと

authorにはLightGBMの開発陣が加わっており,LightGBM+Pytorchを用いたGithub repositoryが提供されている.

github.com

本記事ではKDD 2019で発表された論文の内容を説明する.

(下記リンクのdownloadを押せば入手できます)

www.kdd.org

モデルについて

GBDT

pros

  • dense data (tabular numerical features)に強い

sparse dataに弱い反面,dense dataには強い傾向にある.

cons

  • sparse data (categorical features)に弱い

gainベースで分岐を行うため,0-1のカテゴリカルデータは過学習しやすい.

kaggler界隈ではsparseなデータを特徴抽出してdenseなデータにしたり,stacking等の手法を使って特徴選択することが多い.

  • オンライン学習に不利

GBDTなどの加法モデルは学習データが増えてくると,都度新たなtreeを追加してロス関数を下げていく手法である.

そのためオンライン学習では新しいデータに過学習しやすくなり,バッチ処理でのオフライン学習に比べて精度が劣る.

DNN

pros

  • categorical dataに強い

embedding層を用いることでcategorical dataに対応できる

  • オンライン学習に有利

back propagationによってロス関数を下げる手法であることから,例え1つのデータでも過学習は起きにくい

cons

  • dense dataに弱い

dense dataにはFCNN(fully connected neural network)が用いられるが,複雑な最適化となることから局所解に陥りやすい

DeepGBM

上記のGBDTとDNNのpros&consを踏まえて,DeepGBMではdense dataにはGBDT,sparse dataにはNN(CatNN)を用いる.

GBDTはNNへと蒸留され,このモデルはGBDT2NNと呼ばれる.

更に,蒸留によって全体の構造がNNとなることで,オンライン学習が可能となる.

f:id:behemoth03:20191109031617j:plain

GBDT2NN

学習済みGBDTの木構造に近づけるようにFCNNへと蒸留が行われる.

そして,ここが肝なのだが,GBDTのleaf indexがはじめにleaf embeddingへの蒸留を行い,one-hotなleaf indexをdenseなembeddingへと変換する.

次にleaf embeddingは維持したまま,GBDTの木構造に近づけるためにNN(FCNN)を学習する

この蒸留によって得られたモデルはGBDT2NNと呼ばれる.最終的にはNN+embeddingの構成となる.

GBDTの複数木である場合,一つの木からでなく複数個の木を纏めたグループからNNモデルへ蒸留を行う. これをTree Groupingという.


\begin{aligned}
m
\end{aligned}
個の木を 
\begin{aligned}
k\end{aligned}
個のグループに分ける場合を考えると,各グループには 
\begin{aligned}
s = \lceil m / k \rceil
\end{aligned}
個の木が含まれる.

今,入力データ
\begin{aligned}
\boldsymbol{x}
\end{aligned}
について,tree group 
\begin{aligned}
\mathbb{T}
\end{aligned}
に使われる特徴量indicesを 
\begin{aligned}
\mathbb{I}
\end{aligned}
とすると,
\begin{aligned}
\mathbb{T}
\end{aligned}
から蒸留されたNNモデルの出力
\begin{aligned}
y_{\mathbb{T}}(\boldsymbol{x})
\end{aligned}
は以下の式で表される.


\begin{aligned}
y_{\mathbb{T}}(\boldsymbol{x}) &= 
 \boldsymbol{w}^{T} \times \boldsymbol{N} 
 \left( \boldsymbol{x} \left.[ \mathbb{I}^{\mathbb{T}} \right.] ; \theta^{\mathbb{T}} \right) 
+ w_{0} \\
\end{aligned}
where, 
\begin{aligned}
\boldsymbol{N} (\boldsymbol{x} ; \boldsymbol{\theta})
\end{aligned}
: a multi-layered NN model with input 
\begin{aligned}
\boldsymbol{x}
\end{aligned}
and parameter 
\begin{aligned}
\boldsymbol{\theta}
\end{aligned}

GBDTモデルの出力
\begin{aligned}
y_{GBDT2NN}(\boldsymbol{x})
\end{aligned}
は,
\begin{aligned}
k
\end{aligned}
個のグループの出力の和で表される.


\begin{aligned}
y_{GBDT2NN}(\boldsymbol{x}) &= \sum_{j=1}^{k} y_{\mathbb{T}_{j}}(\boldsymbol{x})
\end{aligned}

f:id:behemoth03:20191109031650j:plain

計算コストについて

ある木 
\begin{aligned}
t
\end{aligned}
の構造関数を
\begin{aligned}
C^{t} (\boldsymbol{x})
\end{aligned}
とする.


\begin{aligned}
C^{t} (\boldsymbol{x})
\end{aligned}
はサンプル
\begin{aligned}
\boldsymbol{x}
\end{aligned}
についてoutput leaf indexを返す.


\begin{aligned}
\boldsymbol{x}^{i}
\end{aligned}
についての
\begin{aligned}
C^{t} (\boldsymbol{x}^{i})
\end{aligned}
のone-hot表現を
\begin{aligned}
\boldsymbol{L}^{t,i}
\end{aligned}
とする.

GBDTの複数木である場合,一つの木に1つのNNモデルを蒸留することを考えると,
\begin{aligned}
\# NN = \# tree
\end{aligned}
個のNNモデルが必要となる.

そのため,計算コストは 
\begin{aligned}
O ( | L | \times \# NN)
\end{aligned}
となる.

leaf embeddingの出力 
\begin{aligned}
\boldsymbol{H}^{t,i}
\end{aligned}
は,
\begin{aligned}
\boldsymbol{L}^{t,i}
\end{aligned}
よりとても小さくなることから,leaf embeddingの蒸留は
\begin{aligned}
| L |
\end{aligned}
の計算コストよりも小さくなる.

また,Tree Groupingによって
\begin{aligned}
\# NN
\end{aligned}
よりも小さい計算コストとすることができる.

CatNN

一方,categorical dataにはCatNNが使われる.こちらにもembeddingが使われており,高次元なsparse dataをdense dataへと変換する

CatNNはFM componentとDeep componentの出力を足し合わせた値を出力する


\begin{aligned}
y_{Cat}(\boldsymbol{x}) &= y_{FM}(\boldsymbol{x}) + y_{Deep}(\boldsymbol{x}) \\
\end{aligned}
where, 
\begin{aligned}
y_{Cat}(\boldsymbol{x})
\end{aligned}
: the output of CatNN model,

\begin{aligned}
y_{FM}(\boldsymbol{x})
\end{aligned}
: the output of FM component,

\begin{aligned}
y_{Deep}(\boldsymbol{x})
\end{aligned}
: the output of Deep component


\begin{aligned}
y_{FM}(\boldsymbol{x})
\end{aligned}
は特徴間の線形およびpair-wiseな相互作用を表す


\begin{aligned}
y _ {FM}(\boldsymbol{x}) = w _ {0} + \langle \boldsymbol{w} , \boldsymbol{x} \rangle + \sum^{d}_{i=1} \sum^{d}_{j=i+1} \langle E_{ \boldsymbol{V}_{i}} (x_{i}), E_{ \boldsymbol{V}_{j}} (x_{j}) \rangle x_{i} x_{j}
\end{aligned}
where, 
\begin{aligned}
d
\end{aligned}
: the number of features,

\begin{aligned}
w_{0}, \boldsymbol{w}
\end{aligned}
: the hyperparameters of linear part,

\begin{aligned}
\langle \cdot , \cdot \rangle
\end{aligned}
: the inner product operation

ここで,


\begin{aligned}
E_{ \boldsymbol{V} _{i}} (x_{i}) = embedding \_ lookup( \boldsymbol{V}_{i}, x_{i} )
\end{aligned}
where, 
\begin{aligned}
x_{i}
\end{aligned}
: the value of i-th feature

\begin{aligned}
\boldsymbol{V} _ {i}
\end{aligned}
: all embeddings of the i-th feature

であり, 
\begin{aligned}
E _ { \boldsymbol{V} _ {i}} (x _ {i})
\end{aligned}

\begin{aligned}
x _ {i}
\end{aligned}
に対応するembedding vectorを返す


\begin{aligned}
y_{Deep}(\boldsymbol{x})
\end{aligned}
は特徴間のより高次な相互作用を表す


\begin{aligned}
y _ {Deep}(\boldsymbol{x}) = \boldsymbol{N} ( [ 
E _ {\boldsymbol{V} _ {1}} (x _ {1}) ^ {T}, 
E _ {\boldsymbol{V} _ {2}} (x _ {2}) ^ {T}, ... ,
E _ {\boldsymbol{V} _ {d}} (x _ {d}) ^ {T}
] ^ {T}
 ; \boldsymbol{\theta} )
\end{aligned}

DeepGBM

DeepGBMは, 
\begin{aligned}
y _ {GBDT2NN}\left(\boldsymbol{x}\right)
\end{aligned}

\begin{aligned}
y _ {Cat}(\boldsymbol{x})
\end{aligned}
を重み付き和して,sigmoid関数などで変換した値を出力する


\begin{aligned}
\hat{y} &= \sigma{\prime} (w_{1} \times y_{GBDT2NN}(\boldsymbol{x}) + w_{2} \times y_{Cat}(\boldsymbol{x})) \\
\end{aligned}
where, 
\begin{aligned}
w _ {1}, w _ {2}
\end{aligned}
: the trainable parameters,

\begin{aligned}
\sigma{\prime}
\end{aligned}
: the output transformation, such as sigmoid for binary classification

以上のモデル特性は下記表に纏められている

f:id:behemoth03:20191104030629p:plain
Table 1: Comparison over different models.

学習について

DeepGBMはオフライン学習とオンライン学習で学習方法が異なる

オフライン学習 (End-to-end offline learning)

① GBDT,CatNNの学習

これらのモデルは通常の学習を行う

leaf embeddingの学習

GBDTの複数木の場合,leaf embeddingの学習プロセスは下記で表される


\begin{aligned}
\min_{\boldsymbol{w}, w_{0}, \omega^{\mathbb{T}}}  \frac{1}{n} \sum^{n}_{i=1} \mathcal {L''} 
\left( \boldsymbol{w}^{T} \mathcal{H} ( \|_{t \in \mathbb{T}} (\boldsymbol{L}^{t,i}) ; \omega^{\mathbb{T}} ) + w_{0}, 
\sum_{t \in \mathbb{T}} p^{t,i} \right) 
\end{aligned}
where, 
\begin{aligned}
\boldsymbol{w}, w_{0}
\end{aligned}
: the parameters for mapping embedding to leaf values,

\begin{aligned}
\mathcal {L''}
\end{aligned}
: the same loss function as used in tree learning,

\begin{aligned}
\boldsymbol{H}^{t,i} = \mathcal{H} ( \boldsymbol{L}^{t,i} ; \omega^{t} )
\end{aligned}
: an one-layered fully connected network with parameter 
\begin{aligned}
\omega^{t}
\end{aligned}
,

\begin{aligned}
 \| ( \cdot )
\end{aligned}
: the concatenate operation,

\begin{aligned}
p^{t,i}
\end{aligned}
: the predict leaf value of sample
\begin{aligned}
\boldsymbol{x}^{i}
\end{aligned}

leaf embeddingによってone-hotなleaf index
\begin{aligned}
\boldsymbol{L}^{t,i}
\end{aligned}
からdense embedding
\begin{aligned}
\boldsymbol{H}^{t,i}
\end{aligned}
へと変換される.

ここで,tree output
\begin{aligned}
p^{t}
\end{aligned}
は,ある木
\begin{aligned}
t
\end{aligned}
leaf valuesを
\begin{aligned}
\boldsymbol{q}^{t}
\end{aligned}
と表し,
\begin{aligned}
q^{t}_{i}
\end{aligned}
をi番目のleafleaf valueとすると, 
\begin{aligned}
p^{t}=\boldsymbol{L}^{t} \times \boldsymbol{q}^{t}
\end{aligned}
で表される.

また,


\begin{aligned}
G^{\mathbb{T},i} = \mathcal{H} (\|_{t \in \mathbb{T}} (\boldsymbol{L}^{t,i}) ; \omega^{\mathbb{T}} )
\end{aligned}

は,パラメータ
\begin{aligned}
\omega^{\mathbb{T}}
\end{aligned}
をもつ1層のFCNNであり,multiple one-hot leaf index vectorsが結合したmulti-hot vectorsをdense embedding 
\begin{aligned}
G^{\mathbb{T},i}
\end{aligned}
へ変換する.

③ GBDT2NNの学習(構造の蒸留)

NNモデルの蒸留の標的として新たなembeddingを用いる.

学習プロセスは下記で表される.


\begin{aligned}
\mathcal {L}^{\mathbb{T}} = \min_{ \theta ^ { \mathbb{T} } } \frac{1}{n} \sum^{n}_{i=1} \mathcal {L} \left( N ( x^{i} [ \mathbb{I ^ {T}} ] ; \theta ^ {\mathbb{T}} ) , G ^ { \mathbb{T} , i} \right)
\end{aligned}
where, 
\begin{aligned}
\mathbb{I^{T}}
\end{aligned}
: the used features in tree group 
\begin{aligned}
\mathbb{T}
\end{aligned}
.


\begin{aligned}
\mathbb{I^{T}}
\end{aligned}
は,多くの特徴量が含まれて,特徴選択の能力が損なわれる可能性があることから,feature importanceに従って重要度の高い特徴量のみ使うようにする.

④ DeepGBMの学習

オフラインデータを使って,まずGBDTとCatNNの学習を行なった後,end-to-endでDeepGBMの学習を行う.

ロス関数は以下で表される.


\begin{aligned}
\mathcal {L}_{offline} = \alpha \mathcal {L''} ( \hat {y} (\boldsymbol {x} ) , y ) + \beta \sum^{k}_{j=1} \mathcal {L}^{\mathbb{T}_{j}} 
\end{aligned}
where, 
\begin{aligned}
y
\end{aligned}
: the training target of sample 
\begin{aligned}
\boldsymbol{x}
\end{aligned}
,

\begin{aligned}
k
\end{aligned}
: the number of tree groups,

\begin{aligned}
\alpha, \beta
\end{aligned}
: hyper-parameters used for controlling the strength of end-to-end loss and embedding, respectively.

オンライン学習 (Online update)

オフライン学習されたGBDTを,embeddingをonliine updateで学習させるために使用することはオンラインでのリアルタイムの性能を損なうと考えられる.

そのため,online updateではロス関数に
\begin{aligned}
\mathcal {L}^{\mathbb {T}}を含めない.
\end{aligned}


\begin{aligned}
\mathcal {L}_{online} = \mathcal {L''} ( \hat {y} (\boldsymbol {x} ) , y ) 
\end{aligned}

実行方法

DeepGBMの実装はgithubに公開されているので,DLして試すことが出来ます.

experimentsディレクトリに移動して,README.mdに書かれている通りの手続きを踏めば良いです.