Paper Title

NBDT: Neural-Backed Decision Trees

Authors

Alvin Wan, Lisa Dunlap, Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E. Gonzalez

Abstract

Machine learning applications such as finance and medicine demand accurate and justifiable predictions, barring most deep learning methods from use. In response, previous work combines decision trees with deep learning, yielding models that (1) sacrifice interpretability for accuracy or (2) sacrifice accuracy for interpretability. We forgo this dilemma by jointly improving accuracy and interpretability using Neural-Backed Decision Trees (NBDTs). NBDTs replace a neural network's final linear layer with a differentiable sequence of decisions and a surrogate loss. This forces the model to learn high-level concepts and lessens reliance on highly-uncertain decisions, yielding (1) accuracy: NBDTs match or outperform modern neural networks on CIFAR and ImageNet, and generalize better to unseen classes by up to 16%. Furthermore, our surrogate loss improves the original model's accuracy by up to 2%. NBDTs also afford (2) interpretability: improving human trust by clearly identifying model mistakes and assisting in dataset debugging. Code and pretrained NBDTs are at https://github.com/alvinwan/neural-backed-decision-trees.
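The abstract's core mechanism — replacing the network's final linear layer with a differentiable sequence of decisions over a class hierarchy — can be illustrated with a toy sketch. The code below is a simplified assumption, not the paper's implementation: `build_hierarchy` and the greedy nearest-pair merging rule stand in for the paper's induced hierarchy, and `soft_inference` sketches the idea of multiplying per-node softmax probabilities along root-to-leaf paths, where each node's representative vector is derived from the final linear layer's class weight rows.

```python
import numpy as np

def build_hierarchy(leaf_vectors):
    # Greedily merge the two closest nodes; an inner node's
    # representative is the mean of its children's vectors.
    # (A hypothetical, simplified stand-in for the induced hierarchy.)
    nodes = [{"vec": v, "cls": i, "children": []}
             for i, v in enumerate(leaf_vectors)]
    while len(nodes) > 1:
        pairs = [(np.linalg.norm(nodes[a]["vec"] - nodes[b]["vec"]), a, b)
                 for a in range(len(nodes)) for b in range(a + 1, len(nodes))]
        _, a, b = min(pairs)
        merged = {"vec": (nodes[a]["vec"] + nodes[b]["vec"]) / 2,
                  "cls": None, "children": [nodes[a], nodes[b]]}
        nodes = [n for i, n in enumerate(nodes) if i not in (a, b)] + [merged]
    return nodes[0]

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def soft_inference(node, x, prob=1.0, out=None):
    # Multiply per-node softmax probabilities along each root-to-leaf
    # path; the class at the most probable leaf is the prediction.
    if out is None:
        out = {}
    if not node["children"]:
        out[node["cls"]] = prob
        return out
    scores = np.array([c["vec"] @ x for c in node["children"]])
    for p, c in zip(softmax(scores), node["children"]):
        soft_inference(c, x, prob * p, out)
    return out

# Toy example: 4 classes with identity weight rows for clarity.
W = np.eye(4)                     # rows = final linear layer weights
root = build_hierarchy(W)
x = 5.0 * W[2]                    # feature strongly aligned with class 2
path_probs = soft_inference(root, x)
pred = max(path_probs, key=path_probs.get)
```

Because the per-node decisions are softmaxes over inner products, every step is differentiable, so the tree can be trained end-to-end with the backbone network — the property the abstract relies on.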
