Calibration for everyone and every decision problem, maybe

Statistical Modeling, Causal Inference, and Social Science 2024-11-15

This is Jessica. Continuing on the theme of calibration for decision-making I’ve been writing about, recall that a calibrated prediction model or algorithm is one for which, for all predictions that are b, the event realizes at a rate of b, and this is true for all values of b. In practice we settle for approximate calibration and do some binning.
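A minimal sketch of the binning step, assuming binary outcomes and equal-width bins on [0, 1]. The function names, the bin count, and the summary statistic (often called expected calibration error) are illustrative choices, not something defined in this post.

```python
from collections import defaultdict

def binned_calibration(preds, outcomes, n_bins=10):
    """Return {bin_index: (mean_prediction, event_rate, count)} for non-empty bins."""
    sums = defaultdict(lambda: [0.0, 0.0, 0])
    for p, y in zip(preds, outcomes):
        b = min(int(p * n_bins), n_bins - 1)  # a prediction of exactly 1.0 goes in the top bin
        sums[b][0] += p
        sums[b][1] += y
        sums[b][2] += 1
    return {b: (sp / n, sy / n, n) for b, (sp, sy, n) in sorted(sums.items())}

def expected_calibration_error(preds, outcomes, n_bins=10):
    """Count-weighted average of |mean prediction - event rate| across bins."""
    table = binned_calibration(preds, outcomes, n_bins)
    total = sum(n for _, _, n in table.values())
    return sum(n / total * abs(mp - rate) for mp, rate, n in table.values())

# Made-up data: predictions of 0.2 realize 1 time in 5, predictions of 0.8 realize 4 times in 5.
well_preds = [0.2] * 5 + [0.8] * 5
well_outcomes = [1, 0, 0, 0, 0] + [1, 1, 1, 1, 0]
print(binned_calibration(well_preds, well_outcomes))
print(expected_calibration_error(well_preds, well_outcomes))
```

In each bin the event rate is compared with the mean prediction, which is the approximate version of "predictions of b realize at rate b".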

As I mentioned in my last post on calibration for decision-making, an outcome indistinguishability definition of perfect calibration says that the true distribution of individual-outcome pairs (x, y*) is indistinguishable from the distribution of pairs (x, ytilde) generated by simulating a y using the true (empirical) probability, ytilde ∼ Ber(ptilde(x)). This makes clear that calibration can be subject to multiplicity, i.e., there may be many models that are calibrated over X overall but which make different predictions on specific instances.
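To make the multiplicity point concrete, here is a toy sketch. The two-point population and its probabilities are invented for illustration. It uses exact fractions so that "calibrated" can be checked with equality rather than a tolerance.

```python
from fractions import Fraction as F

# Hypothetical population: x is 0 or 1 with equal probability, and the true
# event probability is 1/5 when x = 0 and 4/5 when x = 1.
p_x = {0: F(1, 2), 1: F(1, 2)}
p_true = {0: F(1, 5), 1: F(4, 5)}

def event_rate_by_prediction(predictor):
    """Map each predicted value b to the true event rate among instances predicted b."""
    mass, events = {}, {}
    for x, w in p_x.items():
        b = predictor(x)
        mass[b] = mass.get(b, 0) + w
        events[b] = events.get(b, 0) + w * p_true[x]
    return {b: events[b] / mass[b] for b in mass}

def is_calibrated(predictor):
    """True if, for every predicted value b, the event rate equals b."""
    return all(rate == b for b, rate in event_rate_by_prediction(predictor).items())

def informative(x):
    return p_true[x]      # predicts the true probability for each x

def constant(x):
    return F(1, 2)        # always predicts the overall base rate

print(is_calibrated(informative), is_calibrated(constant))
print(informative(0), constant(0))
```

Both predictors pass the calibration check, yet they disagree on every individual x.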

This is bad for decision-making in a couple ways. First, being able to show that predictions are calibrated in aggregate doesn’t prevent disparities in error rates across subpopulations defined on X. One perspective on why this is bad is socially motivated. A fairness critique of the standard ML paradigm where we minimize empirical risk is that we may be trading off errors in ways that disadvantage some groups more than others. Calibration and certain notions of fairness, like equalized odds, can’t be achieved simultaneously. Going a step further, ideally we would like calibrated model predictions across subgroups without having to rely on the standard, very coarse notions of groups defined by a few variables like sex and race. Intersectionality theory, for example, is concerned with the limitations of the usual coarse demographic buckets for capturing people’s identities. Even if we put social concerns aside, from a technical standpoint, it seems clear that the finer the covariate patterns on which we can guarantee calibration, the better our decisions will be.

The second problem is that a predictor being calibrated doesn’t tell us how informative it is for a particular decision task. Imagine predicting a two-dimensional state. We have two predictors that are equally well calibrated. The first predicts perfectly on the first dimension but is uninformed on the second. The second predicts perfectly on the second dimension but is uninformed on the first. Which predictor we prefer depends on the specifics of our decision problem, like what the loss function is. But once we commit to a loss function in training a model, it is hard to go back and get a different loss-minimizing predictor from the one we got. And so, a mismatch between the loss function we train the model with and the loss function that we care about when making decisions down the road can mean leaving money on the table in the sense of producing a predictor that is not as informative as it could have been for our downstream problem. 

Theoretical computer science has produced a few intriguing concepts to address these problems (at least in theory). They turn out to be closely related to one another. 

Predictors that are calibrated over many possibly intersecting groups

Multicalibration, introduced in 2018 in this paper, guarantees that for any possibly intersecting set of groups that are supported by your data and can be defined by some function class (e.g., decision trees up to some max depth), your predictions are calibrated. Specifically, if C is a collection of subsets of X and α is a value in [0, 1], a predictor f is (C, α)-multicalibrated if for all S in C, f is α-calibrated with respect to S. Here α-calibrated means the expected difference between the “true” probability and the predicted probability is less than α.
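A rough sketch of how one might audit a predictor against this definition on a finite sample. Two assumptions go beyond the one-sentence summary above: the gap is measured within each prediction bin inside each group (as in the usual formal statement), and observed outcomes stand in for the "true" probabilities. The group names and data are made up.

```python
from collections import defaultdict

def multicalibration_gaps(preds, outcomes, groups, n_bins=10, min_count=1):
    """Return {(group_name, bin_index): |mean(outcome - prediction)|} for each cell."""
    cells = defaultdict(lambda: [0.0, 0])
    for name, members in groups.items():
        for p, y, in_group in zip(preds, outcomes, members):
            if in_group:
                b = min(int(p * n_bins), n_bins - 1)
                cells[(name, b)][0] += y - p
                cells[(name, b)][1] += 1
    return {cell: abs(total / n) for cell, (total, n) in cells.items() if n >= min_count}

def violations(preds, outcomes, groups, alpha, **kwargs):
    """Cells whose calibration gap exceeds alpha."""
    gaps = multicalibration_gaps(preds, outcomes, groups, **kwargs)
    return {cell: gap for cell, gap in gaps.items() if gap > alpha}

# Made-up example: a constant prediction that is calibrated overall but not within groups.
preds = [0.5] * 8
outcomes = [1, 1, 1, 0, 0, 0, 0, 1]
groups = {
    "all": [True] * 8,
    "a": [True] * 4 + [False] * 4,
    "b": [False] * 4 + [True] * 4,
}
print(violations(preds, outcomes, groups, alpha=0.1))
```

Groups are passed as membership masks, so they are free to overlap. With many groups and bins, small cells become noisy, which is one place the data requirements show up.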

This tackles the first problem above, that we may want calibrated predictions without sacrificing some subpopulations. It’s natural to think about calibration from a hierarchical perspective, and multicalibration is a way of interpolating between what is sometimes called “strong calibration”, i.e., calibration on every covariate pattern (which is often impossible) and calibration only in aggregate over a distribution. 

Predictors that simultaneously minimize loss for many downstream decision tasks 

Omniprediction, introduced here, starts with the slightly different goal of identifying predictors that can be used to optimize any loss in a family of loss functions relevant to downstream decision problems. This is useful because we often don’t know at the time when we are developing the predictive model what sort of utility functions the downstream decision-makers will be facing. 

An omnipredictor is defined with respect to a class of loss functions and (again) a class of functions or hypotheses (e.g., decision trees, neural nets). But this time the hypothesis class fixes a space of possible functions we can learn that will benchmark the performance of the (omni)predictor we get. Specifically, for a family L of loss functions, and a family C of hypotheses c : X → R, an (L,C)-omnipredictor is a predictor f with the property that for every loss function l in L, there is a post-processing function k_l such that the expected loss of the composition of k_l with f measured using l is almost as small as that of the best hypothesis c in C.
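Written as an inequality, the definition looks roughly like the following. The slack ε > 0, standing in for "almost as small", and the symbol ℓ for the loss written l above are notation added here.

```latex
\mathbb{E}_{(x,y)}\big[\ell\big(k_\ell(f(x)),\, y\big)\big]
\;\le\;
\min_{c \in C}\, \mathbb{E}_{(x,y)}\big[\ell\big(c(x),\, y\big)\big] + \varepsilon
\qquad \text{for every } \ell \in L .
```

The predictor f is fixed before the loss is known. Only the post-processing k_ℓ is allowed to depend on ℓ.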

Essentially an omnipredictor is extracting all the predictive power from C, so that we don’t need to worry about the specific loss function that characterizes some downstream decision problem. Once we have such a predictor, we post-process it by applying a simple transformation function to its predictions, chosen based on the loss function we care about, and we’re guaranteed to do well compared with any alternative predictor that can be defined within the class of functions we’ve specified.  
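A small sketch of what the post-processing step can look like, assuming a binary outcome and a finite set of actions (a simplification of the general setup). The transformation picks the action with the lowest expected loss if the outcome were drawn as Bernoulli(p). The two loss functions and their costs are hypothetical.

```python
def post_process(p, loss, actions):
    """Pick the action minimizing expected loss when y ~ Bernoulli(p)."""
    return min(actions, key=lambda a: p * loss(a, 1) + (1 - p) * loss(a, 0))

def cautious_loss(action, y):
    # Hypothetical costs: acting always costs 1; failing to act on an event costs 5.
    if action == "act":
        return 1.0
    return 5.0 * y

def lenient_loss(action, y):
    # Hypothetical costs: acting always costs 1; failing to act on an event costs 1.5.
    if action == "act":
        return 1.0
    return 1.5 * y

ACTIONS = ["act", "wait"]

# The same prediction leads to different decisions under different losses.
p = 0.3
print(post_process(p, cautious_loss, ACTIONS))
print(post_process(p, lenient_loss, ACTIONS))
```

The prediction p is computed once. Only the choice of loss function passed to `post_process` changes between decision-makers.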

Omnipredictors and multicalibration are related

Although the goal of omniprediction is slightly different from the fairness motivation for multicalibration, it turns out the two are closely related. Multicalibration can provide provable guarantees on the transferability of a model’s predictions to different loss functions. Multicalibrated predictors are omnipredictors for convex loss functions. 

One way to understand how they are related is through covariance. Say we partition X into different disjoint subsets whose union is X. Multicalibration implies that for an average state in the partition of X, for any hypothesis c in C, conditioning on the label does not change the expectation of c(x). No c has extra predictive power on y once we’re in one of these groups. We can think of multicalibration as predicting a model of nature that fools correlation tests from C. 
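For a binary label, the link between covariance and "conditioning on the label" can be spelled out with a standard identity. Here S denotes one cell of the partition and μ_S the event rate in that cell; both symbols are notation added here.

```latex
\begin{aligned}
\operatorname{Cov}(c(x), y \mid S)
  &= \mathbb{E}[c(x)\,y \mid S] - \mathbb{E}[c(x) \mid S]\,\mathbb{E}[y \mid S] \\
  &= \mu_S (1-\mu_S)\,\big(\mathbb{E}[c(x) \mid y=1, S] - \mathbb{E}[c(x) \mid y=0, S]\big),
\end{aligned}
\qquad \mu_S = \Pr[y = 1 \mid S].
```

So whenever 0 < μ_S < 1, zero covariance within S is the same as saying the expectation of c(x) is identical for y = 1 and y = 0.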

Finding such predictors (in theory) 

Naturally there is interest in how hard it is to find predictors that satisfy these definitions. For omnipredictors, a main result is that for any hypothesis class C, a weak agnostic learner for C is sufficient to efficiently compute an (L, C)-omnipredictor where L consists of all decomposable convex loss functions obeying mild Lipschitz conditions. Weak agnostic learning here means that we can expect to be able to efficiently find a hypothesis in C that is considerably better than random guessing. Both multicalibration and omniprediction are related to boosting: if our algorithm returns a predictor that is only slightly better than chance, we can train another model to predict the error, and repeat this until our “weak learners” together give us a strong learner, one that is well correlated with the true function.
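The boosting flavor can be illustrated with a heavily simplified patching loop. This is a sketch in the spirit of such algorithms, not the procedure from the papers. It assumes a small, explicitly listed set of groups in place of a weak learner searching over C, and it uses observed outcomes on a made-up sample. Each round finds the (group, prediction bin) cell with the largest remaining error and shifts predictions in that cell to remove it.

```python
def patch_until_multicalibrated(preds, outcomes, groups, alpha, n_bins=10, max_rounds=1000):
    """Repeatedly correct the worst (group, bin) cell until every gap is at most alpha."""
    preds = list(preds)
    for _ in range(max_rounds):
        worst = None  # (absolute gap, member indices, signed mean residual)
        for members in groups.values():
            cells = {}
            for i, in_group in enumerate(members):
                if in_group:
                    b = min(int(preds[i] * n_bins), n_bins - 1)
                    cells.setdefault(b, []).append(i)
            for idx in cells.values():
                resid = sum(outcomes[i] - preds[i] for i in idx) / len(idx)
                if abs(resid) > alpha and (worst is None or abs(resid) > worst[0]):
                    worst = (abs(resid), idx, resid)
        if worst is None:
            return preds  # no cell violates the alpha tolerance
        _, idx, resid = worst
        for i in idx:
            # Shift by the mean residual, keeping predictions inside [0, 1].
            preds[i] = min(1.0, max(0.0, preds[i] + resid))
    return preds  # gave up after max_rounds; may still have violations

# Made-up example: start from a constant prediction that ignores group structure.
outcomes = [1, 1, 1, 0, 0, 0, 0, 1]
groups = {
    "all": [True] * 8,
    "a": [True] * 4 + [False] * 4,
    "b": [False] * 4 + [True] * 4,
}
patched = patch_until_multicalibrated([0.5] * 8, outcomes, groups, alpha=0.1)
print(patched)
```

Each patch plays the role of a boosting round: as long as some group in the class can still "predict the error", there is a correction to apply.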

Somewhat surprisingly, the complexity of finding the omnipredictor turns out to be comparable to that of finding the best predictor for some single loss function. 

But can we achieve them in practice?  

If it wasn’t already obvious, work on these topics is heavily theoretical. Compared to other research related to quantifying prediction uncertainty like conformal prediction, there are far fewer empirical results. Part of the reason I took an interest in the first place was because I was hearing multiple theorists respond to discussions on how to communicate prediction uncertainty by proposing that this was easy because you could just use a multicalibrated model. Hmmm. 

It seems obvious that we need a lot of data to achieve multicalibration. How should researchers working on problems where the label spaces are very large (e.g., some medical diagnosis problems) or where the feature space is very high dimensional and hard to slice up think about these solutions? 

Here it’s interesting to contrast how more practically-oriented papers talk about calibration. For example, I came across some papers related to calibration for risk prediction in medicine where the authors seem to be arguing that multicalibration is mostly hopeless to expect in practice. E.g., this paper implies that “moderate calibration” (the definition of calibration at the start of this post) is all that is needed, and “strong calibration,” which requires that the event rate equals the predicted risk for every covariate pattern, is utopic and not worth wasting time on. The authors argue that moderate calibration guarantees non-harmful decision-making, sounding kind of like the claims about calibration being sufficient for good decisions that I was questioning in my last post. Another implies that the data demands for approaches like multicalibration are too extreme, but that it’s still useful to test for strong calibration to see whether there are specific groups that are poorly calibrated.

I have more to say on the practical aspects, since as usual it’s the gap left between the theory and practice that intrigues me. But this post is already quite long so I’ll follow up on this next week.