Softmax is on the log, not the logit scale
Statistical Modeling, Causal Inference, and Social Science 2024-12-26
Bad Stan naming
I realized recently that we followed the confusing terminological convention of ML in our description of Stan’s categorical_logit
function. In Stan, if there’s a suffix to a distribution, it describes the scale of one or more of the parameters. For example,
poisson_log(y | u) == poisson(y | exp(u))
.
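The suffix convention can be checked numerically. Here is a minimal sketch in plain Python, assuming the usual Poisson log mass function; the helper names `poisson_lpmf` and `poisson_log_lpmf` are made up for this illustration and are not Stan's implementation.

```python
import math

def poisson_lpmf(y, lam):
    """Log mass of Poisson(y | lam), with the rate lam on the natural scale."""
    return y * math.log(lam) - lam - math.lgamma(y + 1)

def poisson_log_lpmf(y, u):
    """Same log mass, but the rate is supplied on the log scale: lam = exp(u)."""
    return y * u - math.exp(u) - math.lgamma(y + 1)

u = 1.3  # arbitrary log rate, chosen only for illustration
for y in range(6):
    # The "_log" suffix describes the scale of the parameter, nothing more.
    assert math.isclose(poisson_log_lpmf(y, u),
                        poisson_lpmf(y, math.exp(u)), abs_tol=1e-12)
```

The only difference between the two functions is the scale on which the rate is passed in, which is what the suffix is meant to signal.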
So when we write categorical(y | p)
we take p
to be a simplex (sequence of finite, non-negative values that sum to 1). So it would make sense that categorical_logit(y | logit(p))
would be equivalent, where logit(p) = log(p / (1 - p))
. But that’s not how it works in Stan. Instead,
categorical_logit(y | u) = categorical(y | softmax(u))
.
We made the same mistake everyone in ML makes in their variable naming! We call the u
here “logits”, when in fact they’re (unnormalized [see below]) log probabilities. This is probably due to the fact that if u
is a regression, then the resulting system is called “multinomial logistic regression.”
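To make "unnormalized log probabilities" concrete, here is a small sketch in plain Python. The `softmax` and `log_sum_exp` helpers are written out by hand for this example, and the input values are arbitrary. It checks that u differs from log(softmax(u)) by a single constant, whereas the gap between u and logit(softmax(u)) varies by element.

```python
import math

def softmax(u):
    m = max(u)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in u]
    total = sum(exps)
    return [e / total for e in exps]

def log_sum_exp(u):
    m = max(u)
    return m + math.log(sum(math.exp(v - m) for v in u))

u = [2.0, -1.0, 0.5]  # arbitrary inputs to softmax
p = softmax(u)

# u minus the log probabilities is the same constant for every element.
log_gaps = [v - math.log(q) for v, q in zip(u, p)]
assert all(math.isclose(g, log_sum_exp(u)) for g in log_gaps)

# u minus the logits of the probabilities is not constant across elements.
logit_gaps = [v - math.log(q / (1 - q)) for v, q in zip(u, p)]
assert max(logit_gaps) - min(logit_gaps) > 1e-6
```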
Example
The softmax function is defined by softmax(u) = exp(u) / sum(exp(u))
. When used like this, the arguments to softmax are log probabilities, not logit probabilities. Here’s a little snippet of Python to illustrate (the style sheet is adding the extra space, not me, and I don’t want to fix it manually in this post with a hack because it’ll mess up the page if the style sheet is ever fixed).
>>> import numpy as np
>>> import scipy as sp
>>> import scipy.special
>>> p = np.asarray([0.2, 0.5, 0.3])
>>> def logit(p): return np.log(p / (1 - p))
...
>>> logit_p = logit(p)
>>> log_p = np.log(p)
>>> sp.special.softmax(logit_p)
array([0.14893617, 0.59574468, 0.25531915])
>>> sp.special.softmax(log_p)
array([0.2, 0.5, 0.3])
This shows that to round-trip probabilities through softmax, the appropriate operation is the natural logarithm, not the logit function.
Origin of the confusion
So where did this confusion come from? Let’s look at a standard binary logistic regression. There we take
p(y | alpha, beta, x) = bernoulli(y | inv_logit(alpha + beta * x))
where
inv_logit(v) = exp(v) / (1 + exp(v))
.
Writing inverse logit this way suggests how to write a logistic regression with a categorical distribution and softmax.
p(y | alpha, beta, x) = categorical(y | softmax([0, alpha + beta * x]))
That’s because
softmax([0, alpha + beta * x])
= [exp(0), exp(alpha + beta * x)] / (exp(0) + exp(alpha + beta * x))
= [1, exp(alpha + beta * x)] / (1 + exp(alpha + beta * x))
= [1 / (1 + exp(alpha + beta * x)), exp(alpha + beta * x) / (1 + exp(alpha + beta * x))]
= [1 - inv_logit(alpha + beta * x), inv_logit(alpha + beta * x)].
This derivation shows that the probability of the categorical in this formulation returning its second outcome (the one that corresponds to y = 1 in the Bernoulli version) is inv_logit(alpha + beta * x)
. But this connection falls apart in the multinomial case when there are more than two outcomes.
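The binary identity above is easy to confirm numerically. This is a minimal sketch in plain Python with hand-written `softmax` and `inv_logit` helpers; the coefficients and the x values are hypothetical, picked only to exercise the formula.

```python
import math

def softmax(u):
    m = max(u)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in u]
    total = sum(exps)
    return [e / total for e in exps]

def inv_logit(v):
    return 1.0 / (1.0 + math.exp(-v))

alpha, beta = -0.4, 1.7  # hypothetical regression coefficients
for x in [-2.0, 0.0, 0.5, 3.0]:
    v = alpha + beta * x
    p_first, p_second = softmax([0.0, v])
    # The second element matches the logistic regression probability.
    assert math.isclose(p_second, inv_logit(v))
    # The first element is its complement.
    assert math.isclose(p_first, 1.0 - inv_logit(v))
```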
In traditional frequentist K-outcome multinomial logistic regressions, the first input to softmax is pinned to 0 for identifiability, just as in the binary case.
softmax([0, u[2], ..., u[K]])
= [exp(0), exp(u[2]), ..., exp(u[K])] / (exp(0) + exp(u[2]) + ... + exp(u[K]))
This leads to asymmetry in the regression as we don’t have a regression for the first element. What it does do is make softmax one-to-one, so it has a proper inverse: log(p) - log(p[1]). If you reduce the choice to just the first category and some other category, then you get a standard binomial logistic regression again. But you still can’t round trip the multinomial case with logit, because
exp(u[2]) / (exp(0) + exp(u[2]) + ... + exp(u[K])) != inv_logit(u[2])
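The two-category reduction mentioned above can also be checked directly. The sketch below uses plain Python with hand-written helpers and arbitrary inputs whose first element is pinned to 0. Note that Python indexes from 0, so `u[0]` here plays the role of the pinned first element. Conditioning on the outcome being either the first category or category k gives inv_logit(u[k]), while the unconditional probability does not.

```python
import math

def softmax(u):
    m = max(u)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in u]
    total = sum(exps)
    return [e / total for e in exps]

def inv_logit(v):
    return 1.0 / (1.0 + math.exp(-v))

u = [0.0, 0.9, 0.4]  # arbitrary inputs, first element pinned to 0
p = softmax(u)

for k in (1, 2):
    # Probability of category k given that the outcome is the first category or k.
    conditional = p[k] / (p[0] + p[k])
    assert math.isclose(conditional, inv_logit(u[k]))
    # The unconditional probability is not the inverse logit of the input.
    assert not math.isclose(p[k], inv_logit(u[k]))
```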
To see that this is still not going to produce logits in the multinomial case, here’s some more Python.
>>> log_p
array([-1.60943791, -0.69314718, -1.2039728 ])
>>> log_p_zero = log_p - log_p[0]
>>> log_p_zero
array([0. , 0.91629073, 0.40546511])
>>> sp.special.softmax(log_p_zero)
array([0.2, 0.5, 0.3])
So as you can see, softmax isn’t identified without pinning one of the values—we can add or subtract a constant from each element of the input and get the same value. But this still doesn’t turn the inputs to softmax into logits.
>>> def inv_logit(v): return 1 / (1 + np.exp(-v))
...
>>> inv_logit(log_p_zero)
array([0.5       , 0.71428571, 0.6       ])
So you can see that the input 0.91629073 is not the logit of the corresponding probability 0.5 (applying inv_logit to it gives 0.71428571 instead), even when pinning a value to zero to identify.
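The non-identifiability noted above (adding a constant to every input leaves softmax unchanged) can be sketched the same way. This uses plain Python and a hand-written `softmax`; the shift values are arbitrary. Pinning the first element to zero is just one particular choice of shift.

```python
import math

def softmax(u):
    m = max(u)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in u]
    total = sum(exps)
    return [e / total for e in exps]

probs = [0.2, 0.5, 0.3]
log_p = [math.log(q) for q in probs]

# Any constant shift of the inputs gives back the same probabilities.
for c in (-5.0, 0.0, 2.5):
    shifted = [v + c for v in log_p]
    assert all(math.isclose(a, b) for a, b in zip(softmax(shifted), probs))

# Pinning the first input to zero corresponds to the shift c = -log_p[0].
pinned = [v - log_p[0] for v in log_p]
assert pinned[0] == 0.0
assert all(math.isclose(a, b) for a, b in zip(softmax(pinned), probs))
```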
P.S. I really miss being able to write math on the blog and really hate that all my old posts with math no longer render. Maybe if Andrew reminds us why it went away, someone will have a suggestion on how to fix.