Estimating Generalization Error with the PRESS statistic
Win-Vector Blog 2014-09-25
As we’ve mentioned on previous occasions, one of the defining characteristics of data science is the emphasis on the availability of “large” data sets, which we define as “enough data that statistical efficiency is not a concern” (note that a “large” data set need not be “big data,” however you choose to define it). In particular, we advocate the use of hold-out data to evaluate the performance of models.
There is one caveat: if you are evaluating a series of models to pick the best (and you usually are), then a single hold-out set is, strictly speaking, not enough. Hastie et al. say it best:
Ideally, the test set should be kept in a “vault,” and be brought out only at the end of the data analysis. Suppose instead that we use the test-set repeatedly, choosing the model with smallest test-set error. Then the test set error of the final chosen model will underestimate the true test error, sometimes substantially.
The ideal way to select a model from a set of candidates (or set parameters for a model, for example the regularization constant) is to use a training set to train the model(s), a calibration set to select the model or choose parameters, and a test set to estimate the generalization error of the final model.
In many situations, breaking your data into three sets may not be practical: you may not have very much data, or the phenomena you’re interested in may be rare enough that you need a lot of data to detect them. In those cases, you will need more statistically efficient estimates of generalization error or goodness-of-fit. In this article, we look at the PRESS statistic, and how to use it to estimate generalization error and choose between models.
The PRESS Statistic
You can think of the PRESS statistic as an “adjusted sum of squared error (SSE).” It is calculated as

$$\text{PRESS} = \sum_{i=1}^{n} (y_i - f_i)^2$$

where $n$ is the number of data points in the training set, $y_i$ is the outcome of the $i$th data point, and $f_i$ is the prediction for $y_i$ from a model that is trained using all the data except the $i$th data point. In other words, the PRESS statistic is the SSE from hold-one-out cross-validation; it tries to estimate how the model will perform on hold-out data, using only in-sample data.
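As a small worked illustration (our own toy numbers, not from the experiment later in this article), take the simplest possible “model,” one that predicts the mean of its training outcomes, and four points $y = (1, 2, 3, 6)$. Holding out each point in turn gives the predictions $f_i = \text{mean}(y_{-i})$, and the resulting PRESS can be compared with the ordinary in-sample SSE around the full mean of 3:

$$
\begin{aligned}
f &= \left(\tfrac{11}{3},\ \tfrac{10}{3},\ 3,\ 2\right) \\
\text{PRESS} &= \left(1-\tfrac{11}{3}\right)^2 + \left(2-\tfrac{10}{3}\right)^2 + (3-3)^2 + (6-2)^2 = \tfrac{64}{9} + \tfrac{16}{9} + 0 + 16 = \tfrac{224}{9} \approx 24.9 \\
\text{SSE} &= (1-3)^2 + (2-3)^2 + (3-3)^2 + (6-3)^2 = 14
\end{aligned}
$$

Each held-out point is predicted from data that never saw it, so the errors grow; that gap between SSE and PRESS is the optimism the statistic is meant to expose.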
For example, if you wanted to calculate the PRESS statistic for linear regression models in R, you could do it this way (though I wouldn’t recommend it):
```r
# For explanation purposes only.
# DO NOT implement PRESS this way
brutePRESS.lm = function(fmla, dframe, outcome) {
  npts = dim(dframe)[1]
  ssdev = 0
  for(i in 1:npts) {
    # a data frame with all but the ith row
    d = dframe[-i,]
    # build a model using all but pt i
    m = lm(fmla, data=d)
    # then predict outcome[i]
    pred = predict(m, newdata=dframe[i,])
    # sum the squared deviations
    ssdev = ssdev + (pred - outcome[i])^2
  }
  ssdev
}
```
We have implemented a couple of helper functions to calculate the PRESS statistic (and related measures) for linear regression models more efficiently. You can find the code here. The function hold1OutLMPreds(fmla, dframe) returns the vector f, where f[i] is the prediction on the ith row of dframe when fitting the linear regression model described by fmla on dframe[-i,]. The function hold1OutMeans(y) returns a vector g where g[i] = mean(y[-i]). With these functions, you can efficiently calculate the PRESS statistic for a linear regression model:
```r
hopreds = hold1OutLMPreds(fmla, dframe)
devs = y-hopreds
PRESS = sum(devs^2)
```
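For ordinary least squares there is a standard identity that avoids refitting entirely: the hold-one-out residual equals $e_i/(1 - h_{ii})$, where $e_i$ is the in-sample residual and $h_{ii}$ is the $i$th leverage (returned by R's hatvalues()). The sketch below uses that identity, plus a one-pass formula for the hold-one-out means. The names looPredsLM() and looMeans() are hypothetical stand-ins, not necessarily how hold1OutLMPreds() and hold1OutMeans() are implemented, and the sketch assumes a full-rank fit with no rows dropped for missing values.

```r
# Hypothetical helpers (illustrative names, not the Win-Vector code).
# Assumes a full-rank lm fit with no rows dropped for missing values.
looPredsLM = function(fmla, dframe) {
  model = lm(fmla, data = dframe)
  e = residuals(model)          # in-sample residuals y_i - yhat_i
  h = hatvalues(model)          # leverages h_ii
  # y_i - f_i = e_i / (1 - h_ii), so f_i = yhat_i - h_ii * e_i / (1 - h_ii)
  unname(fitted(model) - h * e / (1 - h))
}

# mean(y[-i]) for every i in one pass: (sum(y) - y_i) / (n - 1)
looMeans = function(y) (sum(y) - y) / (length(y) - 1)

# Check the closed form against brute-force refitting on a small example
set.seed(2014)
d = data.frame(x1 = rnorm(30), x2 = rnorm(30))
d$y = 1 + 2 * d$x1 - d$x2 + rnorm(30)
fast = looPredsLM(y ~ x1 + x2, d)
slow = sapply(seq_len(nrow(d)), function(i) {
  m = lm(y ~ x1 + x2, data = d[-i, ])
  unname(predict(m, newdata = d[i, ]))
})
all.equal(fast, slow)     # TRUE (up to floating-point tolerance)
PRESS = sum((d$y - fast)^2)
```

The brute-force check refits the model once per row; the closed form needs only the single fit, which is what makes PRESS cheap for linear regression.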
One disadvantage of the SSE (and the PRESS) is that they are dependent on the data size; you can’t compare a single model’s performance across data sets of different size. You can remove that dependency by going to the root mean squared error (RMSE): rmse = sqrt(sse/n), where n is the size of the data set. You can also calculate an equivalent “root mean PRESS” statistic:
```r
n = length(y)
hopreds = hold1OutLMPreds(fmla, dframe)
devs = y-hopreds
rmPRESS = sqrt(mean(devs^2))
```
And you can also define a “PRESS R-squared”:
```r
n = length(y)
hopreds = hold1OutLMPreds(fmla, dframe)
homeans = hold1OutMeans(y)
devs = y-hopreds
dely = y-homeans
PRESS = sum(devs^2)
PRESS.r2 = 1 - (PRESS/sum(dely^2))
```
The “PRESS R-squared” is one minus the ratio of the model’s PRESS over the “PRESS of y’s mean value”; it adjusts the estimate of how much variation the model explains by using hold-one-out (that is, n-fold) cross-validation, rather than by adjusting for the model’s degrees of freedom (as the more standard adjusted R-squared does).
You might also consider defining a PRESS R-squared using the in-sample total error (y-mean(y)) instead of the hold-one-out mean; we decided on the latter in an “apples-to-apples” spirit. Note also that PRESS R-squared can be negative if the model is very poor.
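The choice of denominator matters less than it might appear. Writing $\bar{y}$ for the in-sample mean and $\bar{y}_{-i}$ for the mean with point $i$ held out, a short derivation (our addition, not from the original helpers) shows that the two totals differ only by a constant factor:

$$
\begin{aligned}
y_i - \bar{y}_{-i} &= y_i - \frac{n\bar{y} - y_i}{n-1} = \frac{n}{n-1}\,(y_i - \bar{y}) \\
\sum_{i=1}^{n} (y_i - \bar{y}_{-i})^2 &= \left(\frac{n}{n-1}\right)^2 \sum_{i=1}^{n} (y_i - \bar{y})^2
\end{aligned}
$$

So the hold-one-out version of the PRESS R-squared, $1 - \text{PRESS} / \sum_i (y_i - \bar{y}_{-i})^2$, is always slightly larger than the in-sample version; for $n = 1000$ the factor is $(1000/999)^2 \approx 1.002$.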
An Example
Let’s imagine a situation where we want to predict a quantity y, and we have a great many potential inputs to use in our prediction. Some of these inputs are truly correlated with y; some of them are not. Of course, we don’t know which are which. We have some training data with which to build models, and we will get (but don’t yet have) hold-out data to evaluate the final model. How might we proceed?
First, let’s create a process to simulate this situation:
```r
# build a data frame with pure noise columns
# and columns weakly correlated with y
buildExample1
```

This function will produce a data set of nRows rows with 20 columns that are weakly correlated with y (called cor_1, cor_2, ...) and 300 columns (noise_1, noise_2, ...) that are independent of y. The process is designed so that the noise columns and the correlated columns have similar magnitudes and variances. The outcome can be expressed as a linear combination of the correlated inputs, so a linear regression model should give reasonable predictions.

Let's suppose we have two candidate models: one which uses all the variables, and one which magically uses only the intentionally correlated variables.
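As a hypothetical stand-in for buildExample1(), here is a generator with the same shape: 20 cor_ columns, 300 noise_ columns of the same scale, and an outcome y built as a linear combination of the cor_ columns plus independent noise. The name buildExampleSketch(), the standard-normal inputs, the equal coefficients, and the noise level are all our assumptions, so results computed from it will not match the numbers reported below.

```r
# Hypothetical stand-in for buildExample1(); not the original implementation.
# Assumptions: all inputs are standard normal, y = sum(cor_j)/sqrt(nCor) + noise.
buildExampleSketch = function(nRows, nCor = 20, nNoise = 300, noiseSd = 0.55) {
  corMat = matrix(rnorm(nRows * nCor), nrow = nRows,
                  dimnames = list(NULL, paste0("cor_", seq_len(nCor))))
  noiseMat = matrix(rnorm(nRows * nNoise), nrow = nRows,
                    dimnames = list(NULL, paste0("noise_", seq_len(nNoise))))
  # each cor_j contributes equally, so any single one is only weakly correlated with y
  y = rowSums(corMat) / sqrt(nCor) + rnorm(nRows, sd = noiseSd)
  # noise_ columns have the same scale as cor_ columns but never enter y
  data.frame(corMat, noiseMat, y = y)
}

set.seed(1)
d = buildExampleSketch(2000)
dim(d)                      # 2000 rows, 321 columns
cor(d$cor_1, d$y)           # weak; expected value is about 0.2 under these assumptions
max(abs(cor(d[, grep("^noise_", names(d))], d$y)))  # small: noise columns are independent of y
```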
```r
set.seed(22525)
train = buildExample1(1000)

output = "y"
inputs = setdiff(colnames(train), output)
truein = inputs[grepl("^cor",inputs)]

# all variables, including noise
# (noisy model)
fmla1 = paste(output, "~", paste(inputs, collapse="+"))
mod1 = lm(fmla1, data=train)

# only true inputs
# (clean model)
fmla2 = paste(output, "~", paste(truein, collapse="+"))
mod2 = lm(fmla2, data=train)
```

We can extract all the model coefficients that lm() deemed significant at the 0.05 level:

```r
# return the coefficients with p-value below pmax;
# 0.05 = "*" in the model summary
sigCoeffs = function(model, pmax=0.05) {
  cmat = summary(model)$coefficients
  pvals = cmat[,4]
  # assumed completion (the rest of the original listing was lost):
  # keep the names of coefficients whose p-value is below pmax
  plo = names(pvals)[pvals < pmax]
  plo
}
```

For the noisy model, several of the noise inputs show up as significant: they appear to be correlated with the output in the training data, just by chance. This means that the noisy model has overfit the data. Can we detect that? Let's look at the SSE and the PRESS:
```
##          name   sse PRESS
## 1 noisy model 203.3 448.6
## 2 clean model 285.8 306.8
```

Looking at the in-sample SSE, the noisy model looks better than the clean model; the PRESS says otherwise. We can see the same thing if we look at the R-squared style measures:
```
##          name     R2  R2adj PRESSr2
## 1 noisy model 0.7931 0.6956  0.5442
## 2 clean model 0.7091 0.7031  0.6884
```

Again, R-squared makes the noisy model look better than the clean model. The adjusted R-squared correctly indicates that the additional variables in the noisy model do not improve the fit, and slightly prefers the clean model. The PRESS R-squared identifies the clean model as the better model, with a much larger margin of difference than the adjusted R-squared.
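To produce a table like these for any lm() fit, one option is a small helper that combines the closed-form hold-one-out residuals $e_i/(1 - h_{ii})$ with the values R's summary() already reports. The helper name pressMetrics() and the small simulated data set are our own illustrative assumptions; the numbers it prints will differ from the tables above.

```r
# Hypothetical helper: in-sample and PRESS-based metrics for a fitted lm model.
# Assumes a full-rank fit with no dropped (NA) rows.
pressMetrics = function(model) {
  y = model.response(model.frame(model))
  n = length(y)
  e = residuals(model)
  looResid = e / (1 - hatvalues(model))       # y_i - f_i, without refitting
  looMeanResid = y - (sum(y) - y) / (n - 1)   # y_i - mean(y[-i])
  press = sum(looResid^2)
  s = summary(model)
  data.frame(sse = sum(e^2), PRESS = press,
             R2 = s$r.squared, R2adj = s$adj.r.squared,
             PRESSr2 = 1 - press / sum(looMeanResid^2))
}

# Small illustration: 2 real inputs plus 58 pure-noise inputs
set.seed(7)
n = 200
d = data.frame(matrix(rnorm(n * 60), nrow = n))   # columns X1..X60
d$y = d$X1 + d$X2 + rnorm(n)
noisy = lm(y ~ ., data = d)           # all 60 inputs
clean = lm(y ~ X1 + X2, data = d)     # only the real inputs
metrics = rbind(pressMetrics(noisy), pressMetrics(clean))
metrics$name = c("noisy model", "clean model")
print(metrics)
```

The R2adj column is summary.lm()'s adj.r.squared, which for a model with an intercept and $p$ inputs fit to $n$ points is $1 - (1 - R^2)(n-1)/(n-p-1)$. Plugging in the table above ($n = 1000$; $p = 320$ for the noisy model) gives $1 - 0.2069 \times 999/679 \approx 0.6956$, matching the reported value, and for the clean model ($p = 20$) $1 - 0.2909 \times 999/979 \approx 0.703$. In effect, adjusted R-squared inflates the noisy model's SSE by a factor of about 1.47, while its PRESS is about 2.2 times its SSE (448.6 vs. 203.3), which is consistent with the wider margin the PRESS R-squared shows.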
The PRESS statistic versus Hold-out Data
Of course, while the PRESS statistic is statistically efficient, it is not always computationally efficient, especially with modeling techniques other than linear regression. The calculation of the adjusted R-squared is not computationally demanding, and it also identified the better model in our experiment. One could ask, why not just use adjusted R-squared?
One reason is that the PRESS statistic is attempting to directly model future predictive performance. Our experiment suggests that it shows clearer distinctions between the models than the adjusted R-squared. But how well does the PRESS statistic estimate the "true" generalization error of a model?
To test this, we will hold the ground truth (that is, the data generation process) and the training set fixed. We will then repeatedly generate test sets, measure the RMSE of the models' predictions on each test set, and compare these to the training RMSE and the root mean PRESS. This is akin to a situation where the training data and model fitting are accomplished facts, and we are hypothesizing possible future applications of the model.
Specifically, we used buildExample1() to generate one hundred test sets of size 100 (10% of the size of the training set) and one hundred test sets of size 1000 (the size of the training set). We then evaluated both the clean model and the noisy model against all the test sets and compared the distributions of the hold-out root mean squared error (RMSE) against the in-sample RMSE and the root mean PRESS. The results are shown below.

For each plot, the solid black vertical line is the mean of the distribution of test RMSE; we can assume that the observed mean is a good approximation to the "true" expected RMSE of the model. Not surprisingly, a smaller test set size leads to more variance in the observed RMSE, but after 100 trials, both the n=100 and n=1000 hold-out sets lead to similar estimates of the expected RMSE (just under 0.7 for the noisy model, just under 0.6 for the clean model).
The dashed red lines give the root mean PRESS of both models on the training data, and the dashed blue lines give each model's training set RMSE. For both the noisy and clean models, the root mean PRESS gives a better estimate of the models' expected RMSE than the training set RMSE, dramatically so with the noisy, overfit model.
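The experiment can be sketched in a few lines. The sketch uses its own hypothetical generator, genExample(), in place of buildExample1() (standard-normal inputs, equal weights on the cor_ columns, and noise sd 0.55 are our assumptions), fits the noisy and clean models once, and then scores them on 100 freshly generated test sets of size 1000. Its numbers are not the ones behind the plots.

```r
# Hypothetical generator standing in for buildExample1() (our assumptions:
# standard-normal inputs, equal weights on the cor_ columns, noise sd 0.55).
genExample = function(nRows, nCor = 20, nNoise = 300, noiseSd = 0.55) {
  corMat = matrix(rnorm(nRows * nCor), nrow = nRows,
                  dimnames = list(NULL, paste0("cor_", seq_len(nCor))))
  noiseMat = matrix(rnorm(nRows * nNoise), nrow = nRows,
                    dimnames = list(NULL, paste0("noise_", seq_len(nNoise))))
  y = rowSums(corMat) / sqrt(nCor) + rnorm(nRows, sd = noiseSd)
  data.frame(corMat, noiseMat, y = y)
}

rmse = function(y, pred) sqrt(mean((y - pred)^2))
# root mean PRESS via the closed-form hold-one-out residuals e_i / (1 - h_ii)
rootMeanPRESS = function(model) sqrt(mean((residuals(model) / (1 - hatvalues(model)))^2))

set.seed(1)
train = genExample(1000)                      # fixed training set
inputs = setdiff(colnames(train), "y")
noisy = lm(reformulate(inputs, response = "y"), data = train)
clean = lm(reformulate(grep("^cor_", inputs, value = TRUE), response = "y"), data = train)

# ground truth and fitted models stay fixed; only the test sets change
testRMSE = t(replicate(100, {
  test = genExample(1000)
  c(noisy = rmse(test$y, predict(noisy, newdata = test)),
    clean = rmse(test$y, predict(clean, newdata = test)))
}))

results = data.frame(
  model         = c("noisy", "clean"),
  trainRMSE     = c(rmse(train$y, fitted(noisy)), rmse(train$y, fitted(clean))),
  rootMeanPRESS = c(rootMeanPRESS(noisy), rootMeanPRESS(clean)),
  meanTestRMSE  = unname(colMeans(testRMSE)[c("noisy", "clean")])
)
print(results)
```

Under these assumptions the noisy model fits 321 coefficients to 1000 points, so its training RMSE should sit well below its mean test RMSE, while its root mean PRESS should land much closer. For the clean model all three numbers are close, and which in-sample estimate comes out nearer can vary from seed to seed.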
Note, however, that in this experiment, a single hold-out set reliably preferred the clean model to the noisy one (that is, the hold-out SSE was always greater for the noisy model than the clean one when both models were applied to the same test data). The moral of the story: use hold-out data (both calibration and test sets) when that is feasible. When data is at a premium, try more statistically efficient metrics like the PRESS statistic to "stretch" the data that you have.