Applications of (Bayesian) variational inference?
Statistical Modeling, Causal Inference, and Social Science 2024-12-17
I’m curious about whether anyone’s using variational inference, and more specifically, using variational approximations to estimate posterior expectations for applied work. And if so, what kinds of reactions have you gotten from readers or reviewers?
I see a lot of talk in papers about how variational inference (VI) scales better than MCMC at the cost of only approximating the posterior. MCMC, which is often characterized as “approximate”, is technically asymptotically exact. At least in cases where MCMC can sample the posterior, its approximation error is a matter of not getting very many decimal places of accuracy rather than a matter of bias.
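A minimal sketch of what "asymptotically exact" means in practice, assuming a hypothetical standard normal target and a plain random-walk Metropolis sampler (not NUTS): the estimate of E[θ²] = 1 tightens as the chain gets longer, so more iterations buy more decimal places instead of leaving a fixed bias behind.

```python
import math
import random


def log_density(theta: float) -> float:
    """Unnormalized log density of a standard normal target (an illustrative assumption)."""
    return -0.5 * theta * theta


def random_walk_metropolis(n_iter: int, step: float = 2.4, seed: int = 1) -> list[float]:
    """Random-walk Metropolis: propose theta' = theta + step * N(0, 1) and
    accept with probability min(1, p(theta') / p(theta))."""
    rng = random.Random(seed)
    theta = 0.0
    lp = log_density(theta)
    draws = []
    for _ in range(n_iter):
        proposal = theta + step * rng.gauss(0.0, 1.0)
        lp_prop = log_density(proposal)
        # Compare on the log scale; 1 - U lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < lp_prop - lp:
            theta, lp = proposal, lp_prop
        draws.append(theta)
    return draws


# Estimate E[theta^2] = 1 with increasing chain lengths; the error shrinks roughly
# like 1 / sqrt(effective sample size) rather than settling at a fixed bias.
for n in (1_000, 10_000, 100_000):
    draws = random_walk_metropolis(n)
    print(n, round(sum(t * t for t in draws) / n, 3))
```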
But I don’t recall ever seeing anyone use VI for inference in applied statistics. In particular, I’m curious if there are any Bayesian applications of VI, by which I mean applications where the variational approximation is used to estimate Bayesian posterior expectations in the usual way for an applied statistics problem of interest. That is, I’m wondering if anyone uses a variational approximation $q(\theta \mid \phi)$, where $\phi$ is fixed as usual, to approximate a Bayesian posterior $p(\theta \mid y)$ and uses it to estimate expectations as follows:

$$\mathbb{E}[f(\theta) \mid y] \approx \int f(\theta) \, q(\theta \mid \phi) \, \mathrm{d}\theta.$$

This integral could be computed with Monte Carlo when it is possible to sample from $q(\theta \mid \phi)$.
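A minimal sketch of that Monte Carlo computation, assuming $q(\theta \mid \phi)$ is a mean-field normal with $\phi = (\mu, \sigma)$; the values of `mu` and `sigma` below are hypothetical stand-ins for a fitted variational approximation. Averaging $f$ over draws from $q$ estimates the integral under $q$, so the result is only as good as $q$ is as an approximation to the posterior.

```python
import random
from typing import Callable, Sequence


def sample_mean_field_normal(mu: Sequence[float], sigma: Sequence[float],
                             rng: random.Random) -> list[float]:
    """Draw theta ~ q(theta | phi), phi = (mu, sigma), with diagonal covariance."""
    return [rng.gauss(m, s) for m, s in zip(mu, sigma)]


def vi_expectation(f: Callable[[list[float]], float], mu: Sequence[float],
                   sigma: Sequence[float], n_draws: int = 10_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the integral of f(theta) q(theta | phi) d.theta,
    used as a plug-in approximation to E[f(theta) | y]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_draws):
        total += f(sample_mean_field_normal(mu, sigma, rng))
    return total / n_draws


mu = [0.5, -1.0]    # hypothetical fitted variational means
sigma = [0.3, 0.8]  # hypothetical fitted variational scales
est = vi_expectation(lambda th: sum(t * t for t in th), mu, sigma)
# Exact value under q: sum(mu_i^2 + sigma_i^2) = 0.25 + 1.0 + 0.09 + 0.64 = 1.98
print(round(est, 2))
```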
I’m using our Pathfinder variational inference system (now in Stan) to initialize MCMC, but I wouldn’t trust inference based on Pathfinder because of its very restrictive variational family (i.e., multivariate normal with low-rank plus diagonal covariance). Similarly, most of the theoretical results I’ve been seeing around VI are for normal approximating families, particularly of the mean-field (diagonal covariance) variety. Mean-field approximations are easy to manipulate theoretically and computationally, but they seem to make poor candidates for predictive inference, where there is often substantial posterior correlation and non-Gaussianity.
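The problem with correlation can be made concrete with a standard result: if the target is multivariate normal with covariance Σ and a mean-field normal q is fit by minimizing KL(q ‖ p), each factor's variance is 1 / (Σ⁻¹)ᵢᵢ. For a bivariate target with unit variances and correlation ρ, that is 1 − ρ². A minimal sketch, assuming a Gaussian target (so it isolates correlation and ignores non-Gaussianity entirely):

```python
def mean_field_variances_2d(
    cov: tuple[tuple[float, float], tuple[float, float]],
) -> tuple[float, float]:
    """Variances of the mean-field normal q minimizing KL(q || p) for a
    bivariate normal target p with covariance `cov`: 1 / (cov^{-1})_{ii}."""
    (a, b), (c, d) = cov
    det = a * d - b * c
    precision_diag = (d / det, a / det)  # diagonal of the 2x2 inverse
    return (1.0 / precision_diag[0], 1.0 / precision_diag[1])


for rho in (0.0, 0.5, 0.9, 0.99):
    cov = ((1.0, rho), (rho, 1.0))
    mf_var = mean_field_variances_2d(cov)[0]
    # The true marginal variance is 1.0; the mean-field variance is 1 - rho^2.
    print(f"rho={rho:.2f}  true var=1.000  mean-field var={mf_var:.3f}")
```

The mean-field approximation also reports zero correlation, so any prediction that depends on a linear combination of the correlated parameters inherits both errors.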
I know that there are ML applications of variational inference to autoencoding, but I’m specifically asking about applied statistics applications that would be published in an applied journal, not a stats methodology or ML journal. I’ve seen some applications that use point estimates from VI to “fit” latent Dirichlet allocation (LDA) models, but the ones I’ve seen don’t compute any expectations; they report only point estimates of parameters from one local mode among combinatorially many.
I’m curious about applications using ML techniques like normalizing flows as the variational family. I would expect those to be of more practical interest to applied statisticians than all the VI that has come before. I’ve seen cases from Abhinav Agrawal and Justin Domke where VI outperforms NUTS using a 10-layer-deep, roughly 20-neuron-wide real non-volume-preserving (realNVP) flow touched up with importance sampling; their summary paper is still under review and Abhinav’s thesis is being revised. But it requires a lot of compute, which isn’t cheap these days. The cases where realNVP outperforms NUTS include funnels, multimodal targets, bananas, and other models with varying curvature (such as an IRT 2PL posterior). I suspect the cost of the equivalent of an NVIDIA H100 GPU will drop, and its accessibility improve, to the point where everyone will be able to use these methods within 10 years. It’s what I’m spending at least half my time on these days: JAX is fun, and ChatGPT can (help) translate Stan programs pretty much line for line into JAX.
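"Touched up with importance sampling" can be illustrated with self-normalized importance sampling, which reweights draws from the variational approximation by the ratio of the unnormalized posterior to q. A minimal sketch, assuming a plain normal q in place of a realNVP flow (a flow would only change how draws and log q are computed) and a hypothetical non-Gaussian target with density proportional to exp(−θ⁴/4); this is a generic illustration, not Agrawal and Domke's implementation.

```python
import math
import random


def log_target(theta: float) -> float:
    """Hypothetical unnormalized log posterior: a non-Gaussian quartic density."""
    return -0.25 * theta ** 4


def log_q(theta: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Log density of a normal variational approximation q(theta | phi), phi = (mu, sigma)."""
    z = (theta - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def snis_expectation(f, n_draws: int = 20_000, mu: float = 0.0,
                     sigma: float = 1.0, seed: int = 0) -> tuple[float, float]:
    """Self-normalized importance sampling with q as the proposal.
    Returns (importance-corrected estimate, naive plug-in VI estimate)."""
    rng = random.Random(seed)
    draws = [rng.gauss(mu, sigma) for _ in range(n_draws)]
    log_w = [log_target(t) - log_q(t, mu, sigma) for t in draws]
    max_lw = max(log_w)  # subtract the max before exponentiating for stability
    w = [math.exp(lw - max_lw) for lw in log_w]
    is_est = sum(wi * f(t) for wi, t in zip(w, draws)) / sum(w)
    naive_est = sum(f(t) for t in draws) / n_draws
    return is_est, naive_est


is_est, naive_est = snis_expectation(lambda t: t * t)
exact = 2 * math.gamma(0.75) / math.gamma(0.25)  # E[theta^2] under the quartic target
print(f"naive VI: {naive_est:.3f}  importance-corrected: {is_est:.3f}  exact: {exact:.3f}")
```

Here the naive estimate is stuck near E_q[θ²] = 1 however many draws are taken, while the reweighted estimate approaches the target's value of about 0.676. The correction works well in this example because q has heavier tails than the target, which keeps the weights bounded; with a q that is too narrow, a few huge weights would dominate.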