Learning Tree-Based Models with Gradient Descent
Sascha Marton

TL;DR
This thesis introduces a gradient descent-based method for learning decision trees. It optimizes all tree parameters jointly and fits into standard gradient-based ML workflows, with reported gains on tabular, multimodal, and reinforcement learning tasks.
Contribution
The thesis trains hard, axis-aligned decision trees with gradient descent instead of greedy search, so all parameters are optimized jointly and the trees can be combined with existing gradient-based ML methods.
Findings
Achieved state-of-the-art results on small tabular datasets.
Enabled interpretable reinforcement learning without information loss.
Improved performance on complex tabular data and multimodal tasks.
Abstract
Tree-based models are widely recognized for their interpretability and have proven effective in various application domains, particularly in high-stakes domains. However, learning decision trees (DTs) poses a significant challenge due to their combinatorial complexity and discrete, non-differentiable nature. As a result, traditional methods such as CART, which rely on greedy search procedures, remain the most widely used approaches. These methods make locally optimal decisions at each node, constraining the search space and often leading to suboptimal tree structures. Additionally, their demand for custom training methods precludes a seamless integration into modern machine learning (ML) approaches. In this thesis, we propose a novel method for learning hard, axis-aligned DTs through gradient descent. Our approach utilizes backpropagation with a straight-through operator on a dense DT…
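The abstract is cut off before it describes the tree representation, so the following is only a minimal sketch of the straight-through idea it names, not the author's implementation. It assumes a single split on one feature, a sigmoid as the soft surrogate, fixed leaf values of 0 and 1, and a squared-error loss; all function names are hypothetical. The forward pass routes each sample with a hard 0/1 decision, while the backward pass treats the rounding step as the identity so that the threshold still receives a gradient.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def split(x, threshold):
    """Return (soft, hard) routing for one sample; hard == 1.0 means right leaf."""
    soft = sigmoid(x - threshold)
    hard = 1.0 if soft >= 0.5 else 0.0  # non-differentiable step used in the forward pass
    return soft, hard


def predict(x, threshold, leaf_left=0.0, leaf_right=1.0):
    """Hard prediction: always exactly one of the two leaf values."""
    _, hard = split(x, threshold)
    return hard * leaf_right + (1.0 - hard) * leaf_left


def threshold_grad(x, y, threshold, leaf_left=0.0, leaf_right=1.0):
    """Gradient of the squared error w.r.t. the threshold via a straight-through step."""
    soft, hard = split(x, threshold)
    pred = hard * leaf_right + (1.0 - hard) * leaf_left
    d_pred = 2.0 * (pred - y)                # d loss / d pred
    d_hard = d_pred * (leaf_right - leaf_left)
    d_soft = d_hard                          # straight-through: treat d hard / d soft as 1
    return d_soft * (-soft * (1.0 - soft))   # d soft / d threshold for the sigmoid


def fit_threshold(xs, ys, threshold=0.5, lr=0.5, steps=50):
    """Plain gradient descent on the threshold only (leaf values stay fixed)."""
    for _ in range(steps):
        grad = sum(threshold_grad(x, y, threshold) for x, y in zip(xs, ys)) / len(xs)
        threshold -= lr * grad
    return threshold


# Toy data (an assumption for illustration): the label flips between x = 1 and x = 2.
XS = [0.0, 1.0, 2.0, 3.0]
YS = [0.0, 0.0, 1.0, 1.0]

if __name__ == "__main__":
    learned = fit_threshold(XS, YS)
    print(learned, [predict(x, learned) for x in XS])
```

Starting from a threshold of 0.5, the sample at x = 1 is routed to the wrong leaf and is the only one contributing a gradient; once the threshold moves past 1, every sample is routed correctly and the gradient is exactly zero. The predictions stay hard throughout, which is the property the abstract emphasizes over soft, probabilistic routing.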
Taxonomy
Topics: Explainable Artificial Intelligence (XAI) · Imbalanced Data Classification Techniques · Adversarial Robustness in Machine Learning
