Differentiable Filtering for Learning Hidden Markov Models
Reginald Zhiyan Chen, Heng-Sheng Chang, Prashant G. Mehta

TL;DR
This paper introduces Belief Net, a differentiable filtering framework that learns HMM parameters efficiently via neural network optimization, outperforming classical and spectral methods in convergence and recovery accuracy.
Contribution
It proposes a novel neural network-based approach for learning HMMs that is interpretable and, compared with traditional algorithms, faster and more effective in complex settings.
Findings
Belief Net converges faster than Baum-Welch on synthetic data.
It successfully recovers parameters in overcomplete HMMs where spectral methods fail.
On language data, it compares favorably with transformer models.
Abstract
Hidden Markov Models (HMMs) are fundamental for modeling sequential data, yet learning their parameters from observations remains challenging. Classical methods like the Baum-Welch algorithm are computationally intensive and prone to local optima, while modern spectral algorithms offer provable guarantees but may produce probability outputs outside valid ranges. This work introduces Belief Net, a differentiable filtering framework that learns HMM parameters by formulating the forward filter as a structured neural network and optimizing it with stochastic gradient descent. This architecture recursively updates the belief state, which represents the posterior probability distribution over hidden states based on the observation history. Unlike black-box transformer models, Belief Net's learnable weights are explicitly the logits of the initial distribution, transition matrix, and emission…
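The recursion the abstract describes can be illustrated with a minimal sketch of a forward (belief) filter whose parameters are logits. It assumes discrete observations and maps each logit vector to probabilities with a softmax. The abstract names logits for the initial distribution and the transition matrix and is truncated at the emission parameters, so treating those as a d-by-m emission matrix is our assumption. The function names `softmax` and `belief_filter` are hypothetical, and this is not the authors' implementation. Only the forward pass and the sequence log-likelihood are shown. Training by stochastic gradient descent would additionally require differentiating the negative log-likelihood, for example with an autodiff library, which is omitted here.

```python
import math


def softmax(logits):
    """Map a list of real-valued logits to a probability vector (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def belief_filter(init_logits, trans_logits, emit_logits, observations):
    """Run a forward (belief) filter on a sequence of discrete observations.

    Assumed (hypothetical) parameterization:
      init_logits:  length-d logits of the initial hidden-state distribution
      trans_logits: d x d logits; row i gives P(next state | current state i)
      emit_logits:  d x m logits; row i gives P(observation | state i)
    Returns the belief vector after each observation and the sequence log-likelihood.
    """
    pi = softmax(init_logits)
    A = [softmax(row) for row in trans_logits]  # row-stochastic transition matrix
    C = [softmax(row) for row in emit_logits]   # row-stochastic emission matrix
    d = len(pi)

    beliefs = []
    log_lik = 0.0
    prior = pi  # predicted state distribution before seeing the current observation
    for y in observations:
        # Correction: weight the prior by the likelihood of observation y in each state.
        unnorm = [prior[i] * C[i][y] for i in range(d)]
        norm = sum(unnorm)  # P(y_t | y_1, ..., y_{t-1})
        log_lik += math.log(norm)
        belief = [u / norm for u in unnorm]  # posterior over hidden states
        beliefs.append(belief)
        # Prediction: propagate the belief through the transition matrix.
        prior = [sum(belief[i] * A[i][j] for i in range(d)) for j in range(d)]
    return beliefs, log_lik


# Usage: two hidden states, two observation symbols, uniform (all-zero) logits.
beliefs, ll = belief_filter([0.0, 0.0], [[0.0, 0.0], [0.0, 0.0]],
                            [[0.0, 0.0], [0.0, 0.0]], [0, 1, 0])
print(beliefs[-1], ll)
```

Because every parameter enters through a softmax, any real-valued logits yield valid probability distributions, which is one way to sidestep the out-of-range outputs the abstract attributes to spectral methods.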
