
Vanilla Gradient Descent for Oblique Decision Trees

About

Decision Trees (DTs) constitute one of the major classes of highly non-linear AI models, valued, e.g., for their efficiency on tabular data. Learning accurate DTs is, however, complicated, especially for oblique DTs, and requires significant training time. Further, DTs suffer from overfitting; e.g., they proverbially "do not generalize" in regression tasks. Recently, some works proposed ways to make (oblique) DTs differentiable. This enables highly efficient gradient-descent algorithms to be used to learn DTs. It also enables generalization by learning regressors at the leaves simultaneously with the decisions in the tree. Prior approaches to making DTs differentiable rely either on probabilistic approximations at the tree's internal nodes (soft DTs) or on approximations in gradient computation at the internal nodes (quantized gradient descent). In this work, we propose DTSemNet, a novel semantically equivalent and invertible encoding of (hard, oblique) DTs as Neural Networks (NNs) that is trained with standard vanilla gradient descent. Experiments across various classification and regression benchmarks show that oblique DTs learned using DTSemNet are more accurate than oblique DTs of similar size learned using state-of-the-art techniques. Further, DT training time is significantly reduced. We also experimentally demonstrate that DTSemNet can learn DT policies as efficiently as NN policies in the Reinforcement Learning (RL) setup with physical inputs (dimensions $\leq 32$). The code is available at https://github.com/CPS-research-group/dtsemnet.
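To make the distinction between hard and soft oblique DTs concrete, here is a minimal, generic sketch in plain Python. It is not DTSemNet's NN encoding and does not come from the linked repository; the class and function names (`Leaf`, `Node`, `hard_predict`, `soft_predict`) and the routing convention "go right iff w · x + b > 0" are illustrative assumptions. Each internal node splits on a hyperplane over all features (the "oblique" part), and each leaf holds a linear regressor, as in the leaf-regressor setting the abstract mentions.

```python
import math
from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class Leaf:
    # Hypothetical leaf: a linear regressor; an empty coef list gives a constant leaf.
    intercept: float
    coef: List[float] = field(default_factory=list)

    def predict(self, x: List[float]) -> float:
        return self.intercept + dot(self.coef, x)


@dataclass
class Node:
    # Oblique split: a hyperplane w . x + b over *all* features,
    # rather than a threshold on a single feature (axis-aligned DT).
    w: List[float]
    b: float
    left: "Tree"
    right: "Tree"


Tree = Union[Leaf, Node]


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * c for a, c in zip(u, v))


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def hard_predict(tree: Tree, x: List[float]) -> float:
    # Hard DT: follow exactly one root-to-leaf path (assumed: right iff w . x + b > 0).
    while isinstance(tree, Node):
        tree = tree.right if dot(tree.w, x) + tree.b > 0 else tree.left
    return tree.predict(x)


def soft_predict(tree: Tree, x: List[float], temperature: float = 1.0) -> float:
    # Soft DT: each split routes right with probability sigmoid((w . x + b) / T),
    # so the output is a path-probability-weighted mix of all leaf predictions.
    if isinstance(tree, Leaf):
        return tree.predict(x)
    p_right = sigmoid((dot(tree.w, x) + tree.b) / temperature)
    return (p_right * soft_predict(tree.right, x, temperature)
            + (1.0 - p_right) * soft_predict(tree.left, x, temperature))


# Hypothetical depth-2 oblique tree over two features.
root = Node(w=[1.0, -1.0], b=0.0,
            left=Leaf(0.0),
            right=Node(w=[0.0, 1.0], b=-0.5,
                       left=Leaf(1.0),
                       right=Leaf(0.5, coef=[0.75, 0.0])))

print(hard_predict(root, [2.0, 1.0]))  # 2.0
print(hard_predict(root, [1.0, 0.0]))  # 1.0
print(soft_predict(root, [2.0, 1.0]))  # blends leaves, so lies strictly between 1.0 and 2.0 here
```

The hard prediction is piecewise constant in the split parameters `w` and `b`, so its gradient with respect to them is zero almost everywhere. Soft DTs relax each split with a sigmoid (and approach the hard tree as the temperature shrinks), while quantized gradient descent instead approximates the gradient at the internal nodes; per the abstract, DTSemNet avoids both approximations through its equivalent NN encoding.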

Subrat Prasad Panda, Blaise Genest, Arvind Easwaran, Ponnuthurai Nagaratnam Suganthan (2024)

Related benchmarks

| Task | Dataset | Metric | Result | Rank |
|---|---|---|---|---|
| Reinforcement Learning | Acrobot v1 | Mean Return | 78.98 | 42 |
| Reinforcement Learning | LunarLander v2 | Final Return | 257.6 | 30 |
| Reinforcement Learning | CartPole v1 | Return | 326.8 | 16 |
| Reinforcement Learning | MountainCar v0 | Cumulative Reward | 200 | 7 |
| Reinforcement Learning | HalfCheetah Hurdle | Cumulative Reward | 4.79e+3 | 7 |
| Reinforcement Learning | Ant CrossMaze | Reward | 1.01e+3 | 7 |
| Continuous Control | Pusher2D | Goal Distance | 0.14 | 7 |
| Continuous Control | Ant RandomGoal | Goal Distance | 2.23 | 7 |
| Reinforcement Learning | Ant RandomGoal | Reward | 313.5 | 7 |
| Reinforcement Learning | Pusher2D | Reward | 79.13 | 7 |
Showing 10 of 13 rows
