Deep Cox Mixtures for Survival Regression
Chirag Nagpal, Steve Yadlowsky, Negar Rostamzadeh, Katherine Heller

TL;DR
This paper introduces Deep Cox Mixtures, a survival regression model that fits a mixture of Cox regressions whose hazard ratios are parameterized by deep neural networks. It is reported to outperform existing methods in accuracy and calibration, with especially large gains for minority groups.
Contribution
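The paper proposes a mixture-of-Cox-regressions approach, fitted with an efficient approximate EM algorithm that makes hard assignments to mixture components and uses deep neural networks for the hazard ratios, with the aim of improving both the accuracy and the fairness of survival predictions.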
Findings
Outperforms classical and modern survival models in predictive accuracy.
Achieves better calibration in healthcare survival predictions.
Shows significant improvements for minority demographic groups.
Abstract
Survival analysis is a challenging variation of regression modeling because of the presence of censoring, where the outcome measurement is only partially known, due to, for example, loss to follow-up. Such problems come up frequently in medical applications, making survival analysis a key endeavor in biostatistics and machine learning for healthcare, with Cox regression models being amongst the most commonly employed models. We describe a new approach for survival analysis regression models, based on learning mixtures of Cox regressions to model individual survival distributions. We propose an approximation to the Expectation Maximization algorithm for this model that does hard assignments to mixture groups to make optimization efficient. In each group assignment, we fit the hazard ratios within each group using deep neural networks, and the baseline hazard for each mixture component…
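The hard-assignment EM procedure described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation and rests on several labeled assumptions: linear log-hazard ratios stand in for the deep neural networks, the mixture weights do not depend on the covariates, and each component's baseline hazard is a piecewise-linear interpolation of the Breslow estimate (the abstract is truncated before describing the paper's own baseline-hazard estimator). Event times are assumed strictly positive, and all function names are hypothetical.

```python
import numpy as np


def cox_grad(beta, X, T, E, l2=1e-2):
    """Gradient of the mean Cox partial log-likelihood (Breslow ties) minus an L2 penalty."""
    eta = X @ beta
    w = np.exp(eta - eta.max())                      # constant shift cancels in the ratio below
    at_risk = (T[None, :] >= T[:, None]).astype(float)  # at_risk[i, j] = 1 if j is at risk at T[i]
    denom = at_risk @ w
    risk_mean = (at_risk @ (w[:, None] * X)) / denom[:, None]
    grad = ((X - risk_mean) * E[:, None]).sum(axis=0)
    return grad / len(T) - l2 * beta


def fit_cox(X, T, E, steps=300, lr=0.5):
    """Assumption: linear log-hazard ratios fitted by plain gradient ascent (no deep network)."""
    beta = np.zeros(X.shape[1])
    for _ in range(steps):
        beta += lr * cox_grad(beta, X, T, E)
    return beta


def breslow(beta, X, T, E):
    """Breslow baseline cumulative hazard at each distinct event time."""
    w = np.exp(X @ beta)
    times = np.unique(T[E == 1])
    jumps = np.array([E[T == t].sum() / w[T >= t].sum() for t in times])
    return times, np.cumsum(jumps)


def component_loglik(X, T, E, beta, times, H):
    """Per-subject log-likelihood of (T, E) under one Cox component.

    Assumption: the baseline cumulative hazard is interpolated linearly between
    event times (H0(0) = 0) and extrapolated with the last slope afterwards.
    """
    grid_t = np.concatenate([[0.0], times])
    grid_H = np.concatenate([[0.0], H])
    if len(grid_t) < 2:                              # component has no events: zero hazard
        h0, H0 = np.zeros_like(T), np.zeros_like(T)
    else:
        slopes = np.diff(grid_H) / np.diff(grid_t)
        k = np.clip(np.searchsorted(grid_t, T, side="right") - 1, 0, len(slopes) - 1)
        h0 = slopes[k]
        H0 = grid_H[k] + h0 * (T - grid_t[k])
    eta = X @ beta
    # events contribute log h(T) - H(T); censored subjects contribute -H(T) = log S(T)
    return E * (np.log(h0 + 1e-12) + eta) - H0 * np.exp(eta)


def hard_em_cox_mixture(X, T, E, k=2, iters=10, seed=0):
    rng = np.random.default_rng(seed)
    z = rng.integers(k, size=len(T))                 # random initial hard assignments
    for _ in range(iters):
        # M-step: refit each component's Cox model and baseline hazard on its members
        params = []
        for c in range(k):
            m = z == c
            beta = fit_cox(X[m], T[m], E[m]) if m.any() else np.zeros(X.shape[1])
            times, H = breslow(beta, X[m], T[m], E[m])
            params.append((beta, times, H))
        pi = np.bincount(z, minlength=k) / len(z)    # covariate-independent weights (assumption)
        # Hard E-step: move each subject to its highest-scoring component
        ll = np.stack([np.log(pi[c] + 1e-12) + component_loglik(X, T, E, *params[c])
                       for c in range(k)], axis=1)
        z_new = ll.argmax(axis=1)
        if np.array_equal(z_new, z):
            break
        z = z_new
    return z, params
```

Each iteration alternates an M-step, which refits every component on the subjects currently assigned to it, with a hard E-step, which reassigns each subject to the component with the largest log weight plus log-likelihood of its observed time and censoring status. The loop stops once the assignments no longer change.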
Taxonomy
Topics: Statistical Methods and Inference · Insurance, Mortality, Demography, Risk Management · Machine Learning in Healthcare
