The James-Stein Estimator

\[ \newcommand{\cB}{\mathcal{B}} \newcommand{\cF}{\mathcal{F}} \newcommand{\cN}{\mathcal{N}} \newcommand{\cP}{\mathcal{P}} \newcommand{\cX}{\mathcal{X}} \newcommand{\EE}{\mathbb{E}} \newcommand{\PP}{\mathbb{P}} \newcommand{\RR}{\mathbb{R}} \newcommand{\ZZ}{\mathbb{Z}} \newcommand{\td}{\,\textrm{d}} \newcommand{\simiid}{\stackrel{\textrm{i.i.d.}}{\sim}} \newcommand{\simind}{\stackrel{\textrm{ind.}}{\sim}} \newcommand{\eqas}{\stackrel{\textrm{a.s.}}{=}} \newcommand{\eqPas}{\stackrel{\cP\textrm{-a.s.}}{=}} \newcommand{\eqmuas}{\stackrel{\mu\textrm{-a.s.}}{=}} \newcommand{\eqD}{\stackrel{D}{=}} \newcommand{\indep}{\perp\!\!\!\!\perp} \DeclareMathOperator*{\minz}{minimize\;} \DeclareMathOperator*{\maxz}{minimize\;} \DeclareMathOperator*{\argmin}{argmin\;} \DeclareMathOperator*{\argmax}{argmax\;} \newcommand{\Var}{\textnormal{Var}} \newcommand{\Cov}{\textnormal{Cov}} \newcommand{\Corr}{\textnormal{Corr}} \]

1 Gaussian sequence model

Recall that we have discussed a variety of estimators for \(\theta \in \RR^d\) in the Gaussian sequence model

\[X \sim N_d(\theta, I_d)\]

Note that this model is somewhat more general than it appears. If \(X_1,\ldots,X_n \simiid N_d(\theta, \sigma^2 I_d)\) for known \(\sigma^2> 0\), we could make a sufficiency reduction to obtain

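\[Z = \frac{1}{\sigma\sqrt{n}} \sum_i X_i \sim N_d(\mu, I_d), \quad \text{ where } \mu = \sqrt{n}\,\theta/\sigma.\] Estimating \(\mu\) is equivalent to estimating \(\theta\), since the two differ only by a known scale factor. For simplicity we will discuss the “vanilla” version here, but we can always translate our results to the more general setting via this transformation.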

We’ll generally assume in what follows that the loss we care about is the squared error loss, summed over the coordinates: \[L(\theta, \delta(X)) = \|\delta(X) - \theta\|^2 = \sum_j (\delta_j(X) - \theta_j)^2.\]

The most obvious estimator is \(\delta_0(X) = X\) itself, which we could justify in a variety of ways: we’ve shown that it is the UMVU estimator for \(\theta\) and also the objective Bayes estimator, since the flat prior on \(\theta\) coincides with the Jeffreys prior (as it does for any location model). It also happens to be the maximum likelihood estimator (MLE), which we’ll discuss later in the course.

1.1 Bayes estimators

If we introduce the Bayesian prior \(\theta_i \simiid N(0,\tau^2)\) then we have seen that we arrive at the Bayes estimator \(\frac{\tau^2}{1+\tau^2}X\).

We can think of this as a special case of a generic linear shrinkage estimator \[\delta_\zeta(X) = (1-\zeta)X,\] where \(\zeta \in [0,1]\) is a tuning parameter that we will call the shrinkage parameter. Taking \(\zeta = 0\) corresponds to using \(X\) as our estimator for \(\theta\), and taking \(\zeta = 1/(1+\tau^2)\) corresponds to the Bayes estimator where \(\tau^2\) is known.

If we aren’t sure which \(\zeta\) to use, for example because we have some a priori uncertainty about \(\tau^2\), we can try to estimate it from the data using hierarchical Bayes, which we’ve seen would give the final estimator \[\delta(X) = (1 - \EE[\zeta \mid X]) X = \delta_{\hat\zeta_{\text{Bayes}}(X)}(X),\] so we are in effect estimating \(\zeta\) from the whole data set and then plugging it in as a data-adaptive tuning parameter.

The hierarchical Bayes estimator uses a Bayes estimator for \(\zeta\), but if we take an empirical Bayes approach we could try other estimators, such as the MLE or UMVU. If \(d \geq 3\) then the UMVU estimator for \(\zeta\) is \[\hat{\zeta}_{\text{UMVU}}(X) = \frac{d-2}{\|X\|^2},\] which we can verify using the identity \[\EE[1/Y] = \frac{1}{d-2}, \quad \text{ if } Y \sim \chi_d^2 = \text{Gamma}(d/2, 2), \text{ for } d > 2,\] which is proved in the handwritten notes. Plugging in \(\hat\zeta_\text{UMVU}\) results in an estimator called the James-Stein estimator, \[ \delta_{\text{JS}}(X)= \left(1 - \frac{d-2}{\|X\|^2}\right)X = \delta_{\hat\zeta_{\text{UMVU}}}(X) \]
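As a concrete illustration, the James-Stein formula above can be sketched in a few lines of NumPy. The function name `james_stein` and the optional `sigma2` argument (which rescales the numerator when the known noise variance is not 1) are our own choices for this sketch, not part of the notes.

```python
import numpy as np

def james_stein(x, sigma2=1.0):
    """James-Stein estimate (1 - (d-2)*sigma2 / ||x||^2) * x for one observed vector x.

    Hypothetical helper for illustration; assumes x ~ N_d(theta, sigma2 * I_d)
    with sigma2 known and d >= 3.
    """
    x = np.asarray(x, dtype=float)
    d = x.size
    if d < 3:
        raise ValueError("James-Stein shrinkage requires d >= 3")
    # Plug-in shrinkage parameter zeta_hat = (d - 2) * sigma2 / ||x||^2.
    # (||x||^2 = 0 has probability zero under the Gaussian model.)
    zeta_hat = (d - 2) * sigma2 / np.sum(x ** 2)
    return (1.0 - zeta_hat) * x

# Example: d = 5 and ||x||^2 = 11.5, so zeta_hat = 3 / 11.5.
x = np.array([2.0, -1.0, 0.5, 1.5, -2.0])
estimate = james_stein(x)
```

Every coordinate is multiplied by the same data-dependent factor, so the estimate always points in the same direction as \(X\) (or exactly opposite, if the factor is negative).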

1.2 James-Stein Paradox

While the James-Stein estimator can be motivated as an empirical Bayes estimator, it is surprisingly good even without making any Bayesian assumptions at all.

For \(d \geq 3\), the estimator \(X\) is actually inadmissible as an estimator of \(\theta\) under squared error loss:

\[\text{MSE}(\theta, \delta_{JS}) < \text{MSE}(\theta, X) \quad \text{ for all } \theta \in \mathbb{R}^d.\]

It is not surprising for a Bayes estimator to beat the UMVU estimator on average with respect to some prior, but this result holds for every fixed value of the parameter \(\theta\).

In fact, there is nothing special about shrinking toward \(0\). We could use a version of the estimator that shrinks toward any other \(\theta_0 \in \RR^d\), i.e. \[ \tilde\delta(X) = \theta_0 + \left(1 - \frac{d-2}{\|X - \theta_0\|^2}\right) (X - \theta_0).\] This also dominates \(\delta_0\) because it is just the James-Stein estimator we’d get if we made the substitution \[Y = X - \theta_0 \sim N_d(\mu, I_d), \quad \text{ for } \mu = \theta - \theta_0.\] The translation-invariance of the Gaussian location model means that the James-Stein estimator for \(\mu\) using \(Y\) dominates the estimator \(\hat\mu_0(Y) = Y\), which corresponds to the estimator \(\delta_0(X) = \hat\mu_0(Y) + \theta_0 = X\) for \(\theta\).

This result was received as a shock in the 1950s when it first came out. It was regarded for a long time as a curiosity, but it was eventually understood to carry the deep implication that shrinkage makes sense, especially in higher-dimensional problems, even when we don’t have a Bayes justification for it.

1.3 Linear shrinkage estimators

Even without introducing a Bayesian prior for \(\theta\), we can motivate our linear shrinkage estimator purely from the perspective of trading a bit of bias for a reduction in variance.

We can start by calculating the MSE (considered as a purely frequentist risk function) for a single coordinate, using the bias-variance decomposition: \[ \begin{aligned} \EE_\theta[(\theta_i - \delta_i(X))^2] &= (\theta_i - \EE_\theta[(1-\zeta)X_i])^2 + \text{Var}_\theta((1-\zeta)X_i)\\ &= (\zeta\theta_i)^2 + (1-\zeta)^2. \end{aligned} \] Summing over the \(d\) coordinates gives \[\text{MSE}(\theta; \delta_\zeta) = \zeta^2\|\theta\|^2 + d(1-\zeta)^2,\] where the first term represents the squared bias and the second is the variance.

Note that the risk is a quadratic in \(\zeta\) with positive second derivative, so we can minimize it by setting \[0 = \frac{d}{d\zeta}\text{MSE}(\theta) = 2\zeta\|\theta\|^2 - 2(1-\zeta)d,\] leading to \[\zeta^*(\theta) = \frac{d}{d+\|\theta\|^2} = \frac{1}{1+\|\theta\|^2/d},\] which looks remarkably similar to \(\frac{1}{1+\tau^2}\), which is the Bayes-optimal \(\zeta\) under the Gaussian prior from the last section.
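The risk formula and its minimizer are easy to check numerically. The sketch below uses hypothetical helper names (`linear_shrinkage_risk`, `oracle_zeta`) and an arbitrary example \(\theta\); it only evaluates the closed-form expressions above and does not simulate anything.

```python
import numpy as np

def linear_shrinkage_risk(zeta, theta):
    """Exact MSE of (1 - zeta) * X when X ~ N_d(theta, I_d): squared bias + variance."""
    theta = np.asarray(theta, dtype=float)
    d = theta.size
    return zeta ** 2 * np.sum(theta ** 2) + d * (1.0 - zeta) ** 2

def oracle_zeta(theta):
    """Risk-minimizing shrinkage zeta* = d / (d + ||theta||^2); needs the unknown theta."""
    theta = np.asarray(theta, dtype=float)
    d = theta.size
    return d / (d + np.sum(theta ** 2))

theta = np.array([1.0, -2.0, 0.5, 0.0, 3.0])   # ||theta||^2 = 14.25, d = 5
zeta_star = oracle_zeta(theta)                  # 5 / 19.25
risk_star = linear_shrinkage_risk(zeta_star, theta)
risk_no_shrink = linear_shrinkage_risk(0.0, theta)   # equals d
```

Substituting \(\zeta^*\) back into the risk gives \(d\|\theta\|^2/(d + \|\theta\|^2)\), which is always below \(d\); the catch is that \(\zeta^*\) depends on the unknown \(\|\theta\|^2\).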

One thing to notice is that \(\zeta^*(\theta) > 0\), so a small amount of shrinkage always helps. But the correct amount of shrinkage depends on \(\|\theta\|^2\): as \(\|\theta\|^2 \to \infty\), the correct amount of shrinkage goes to \(0\), so any fixed \(\zeta > 0\) would overshoot for \(\theta\) of sufficiently large norm.

It turns out the James-Stein estimator manages to estimate the correct amount of shrinkage from the data, in such a way that we avoid overshooting most of the time, and thereby improve on the MSE for any \(\theta\).

To understand why, we need a general way to calculate the MSE for an estimator with an adaptive \(\hat\zeta(X)\). Stein’s unbiased risk estimator will give us that.

2 Stein’s Unbiased Risk Estimator

2.1 Stein’s Lemma

The first ingredient in finding the MSE of \(\delta_\text{JS}\) is a lemma called Stein’s Lemma:

Theorem (Stein’s lemma, univariate): Suppose \(X \sim N(\theta, \sigma^2)\), and that \(h:\; \RR \to \RR\) is differentiable, with \(\EE|\dot{h}(X)| < \infty\). Then we have \[\Cov(X, h(X)) = \EE[(X-\theta)h(X)] = \sigma^2\EE[\dot{h}(X)].\]

Proof:

First, we will do the calculation for \(\theta = 0\) and \(\sigma^2 = 1\). Then \[\EE[Xh(X)] = \int_{-\infty}^\infty xh(x)\phi(x)\,dx = -\int_{-\infty}^\infty h(x)\dot{\phi}(x)\,dx = \int_{-\infty}^\infty \dot{h}(x)\phi(x)\,dx,\] where we’ve used the identity \[\dot{\phi}(x) = \frac{d}{dx} \frac{1}{\sqrt{2\pi}}e^{-x^2/2} = -x \frac{1}{\sqrt{2\pi}}e^{-x^2/2} = -x\phi(x)\] followed by integration by parts. (A slightly more rigorous version is in the handwritten notes: we can assume wlog \(h(0) = 0\) because otherwise we could center it by using \(k(X) = h(X) - h(0)\), which has the same covariance with \(X\). Then we can break up the integral into an integral from \(0\) to \(\infty\) and another from \(-\infty\) to \(0\), and do integration by parts a bit more carefully for each.)

For more general \(\theta\), we can write \(X = \theta + \sigma Z\), where \(Z \sim N(0,1)\). Then applying the result for \(k(z) = h(\theta + \sigma z)\), we have

\[\EE[(X-\theta) h(X)] = \sigma\EE[Zh(\theta + \sigma Z)] = \sigma^2\EE[\dot{h}(\theta + \sigma Z)] = \sigma^2\EE[\dot{h}(X)],\] giving the general result.
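As a sanity check on the univariate lemma, we can compare both sides numerically for a specific test function. The sketch below assumes \(h(x) = \sin x\) (our choice, not from the notes) and evaluates the Gaussian expectations with Gauss-Hermite quadrature via NumPy’s `hermegauss`; the helper name `gaussian_expectation` and the values of `theta` and `sigma` are illustrative.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

def gaussian_expectation(f, theta, sigma, n_nodes=40):
    """Approximate E[f(X)] for X ~ N(theta, sigma^2) by Gauss-Hermite quadrature."""
    z, w = hermegauss(n_nodes)   # nodes and weights for the weight function exp(-z^2 / 2)
    return np.sum(w * f(theta + sigma * z)) / np.sqrt(2.0 * np.pi)

theta, sigma = 0.7, 1.5
h = np.sin        # test function (an assumption for this example)
h_dot = np.cos    # its derivative

# Left side of Stein's lemma: E[(X - theta) h(X)].
lhs = gaussian_expectation(lambda x: (x - theta) * h(x), theta, sigma)
# Right side of Stein's lemma: sigma^2 * E[h'(X)].
rhs = sigma ** 2 * gaussian_expectation(h_dot, theta, sigma)
```

For this particular \(h\) both sides also have the closed form \(\sigma^2 \cos(\theta)\, e^{-\sigma^2/2}\), which gives a second check on the quadrature.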

We will need the multivariate version of Stein’s lemma. For a function \(h:\; \RR^d \to \RR^d\), define the Jacobian matrix \(Dh \in \RR^{d\times d}\) by \[(Dh(x))_{ij} = \frac{\partial h_i}{\partial x_j}(x).\]

Define the Frobenius norm of a matrix \(A \in \RR^{d\times d}\) as \[\|A\|_F = \left(\sum_{ij} A_{ij}^2\right)^{1/2}.\]

Now we can state our theorem:

Theorem (Stein’s lemma, multivariate): Assume \(X \sim N_d(\theta, \sigma^2 I_d)\), and \(h:\;\RR^d \to \RR^d\) is differentiable with \(\EE\|Dh(X)\|_F < \infty\). Then \[ \EE[(X-\theta)'h(X)] = \sigma^2 \EE\, \text{tr}(Dh(X)) = \sigma^2 \sum_i \EE \frac{\partial h_i}{\partial x_i} (X).\]

Proof: The proof follows from the univariate version if we observe that, because the coordinates are independent, the distribution of \(X_i\) conditional on the other coordinates \(X_{-i}\) is \(N(\theta_i, \sigma^2)\). Applying the univariate lemma to the function \(x_i \mapsto h_i(x)\) with \(X_{-i}\) held fixed, we have \[ \EE\left[(X_i - \theta_i)h_i(X) \mid X_{-i}\right] = \sigma^2 \EE\left[\frac{\partial h_i}{\partial x_i}(X) \mid X_{-i} \right].\] Taking expectations gives \[ \EE\left[(X_i - \theta_i)h_i(X)\right] = \sigma^2 \EE\left[\frac{\partial h_i}{\partial x_i}(X) \right],\] and summing over \(i\) gives the result.

2.2 Stein’s unbiased risk estimator (SURE)

We can obtain an unbiased estimator of the MSE for almost any differentiable estimator \(\delta(X)\) in the Gaussian sequence model, if we apply Stein’s lemma to the function \(h(x) = x - \delta(x)\). We only need the derivative of \(h\) to satisfy the condition of Stein’s lemma, which it does for most differentiable estimators.

Using the identity \(\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2a'b\), we have for \(h(x) = x - \delta(x)\), \[ \begin{aligned} \text{MSE}(\theta; \delta) &= \EE_\theta\|\delta(X) - \theta\|^2\\ &= \EE_\theta\|X - h(X) - \theta\|^2\\ &= \EE_\theta\|X - \theta\|^2 + \EE_\theta\|h(X)\|^2 - 2\EE_\theta \left[(X - \theta)'h(X)\right]\\ &= \sigma^2 d + \EE_\theta\|h(X)\|^2 - 2\sigma^2 \EE_\theta \text{tr}(Dh(X)). \end{aligned} \] Thus, if \(\sigma^2\) is known, we obtain the unbiased estimator \[ \widehat{\text{MSE}}(X) = \sigma^2 d + \|h(X)\|^2 - 2 \sigma^2 \text{tr}(Dh(X)).\]
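The SURE formula can be turned into a small generic routine. The sketch below is one possible implementation, not a canonical one: the function name `sure` is hypothetical, and it approximates \(\text{tr}(Dh(x))\) by central finite differences (step `eps`) rather than using an analytic derivative, so it is only suitable for smooth estimators.

```python
import numpy as np

def sure(x, delta, sigma2=1.0, eps=1e-5):
    """Stein's unbiased risk estimate for an estimator delta at the observed vector x.

    Computes sigma2 * d + ||h(x)||^2 - 2 * sigma2 * tr(Dh(x)) with h(x) = x - delta(x).
    The trace is approximated numerically, one coordinate at a time.
    """
    x = np.asarray(x, dtype=float)
    d = x.size
    h = lambda v: v - delta(v)
    trace = 0.0
    for i in range(d):
        step = np.zeros(d)
        step[i] = eps
        # Central difference for the diagonal entry dh_i / dx_i.
        trace += (h(x + step)[i] - h(x - step)[i]) / (2.0 * eps)
    return sigma2 * d + np.sum(h(x) ** 2) - 2.0 * sigma2 * trace

x = np.array([1.0, 2.0, 3.0, 4.0])
# Linear shrinkage with zeta = 0.25: SURE should equal d + zeta^2 ||x||^2 - 2 zeta d.
sure_linear = sure(x, lambda v: 0.75 * v)
# No shrinkage: h = 0, so SURE reduces to sigma2 * d.
sure_identity = sure(x, lambda v: v)
```

For the linear shrinkage estimator, taking the expectation of \(d + \zeta^2\|X\|^2 - 2\zeta d\) recovers \(\zeta^2\|\theta\|^2 + d(1-\zeta)^2\), matching the bias-variance calculation for \(\delta_\zeta\).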

2.3 Example: shrinking toward \(\overline{X}\)

As an example, we can estimate the MSE of an estimator that shrinks \(X_i\) partway toward the average estimate across the \(d\) coordinates, \(\overline{X} = \frac{1}{d}\sum_i X_i\): \[ \delta_i^\gamma(X) = (1-\gamma) X_i + \gamma \overline{X}. \] We can think of this as making a bet that most of the \(\theta_i\) values are close to \(\bar{\theta} = \frac{1}{d}\sum_i \theta_i\). Then \[ h(X) = X - \delta^\gamma(X) = \gamma(X - \overline{X} 1_d), \] and \[ Dh(X)_{ii} = \frac{\partial}{\partial X_i} \gamma(X_i - \overline{X}) = \gamma (1-1/d) \Rightarrow \text{tr}(Dh(X)) = (d-1)\gamma \] An unbiased estimator for the MSE of \(\delta^\gamma\) is then \[ \widehat{\text{MSE}}^\gamma(X) = \sigma^2 d + \gamma^2(d-1)V^2 - 2(d-1) \gamma \sigma^2, \] where \(V^2 = \frac{1}{d-1}\sum_i (X_i-\overline{X})^2\) is the sample variance. Take note that the \(X_i\) values are not assumed to be i.i.d. here, since their means are generically different.

We can use this estimator in two different ways. The simplest way would be to calculate the actual MSE by taking the estimator’s expectation, which we know is the actual MSE of \(\delta^\gamma\). The only random variable is \(V^2\). Write \(X_i = \theta_i + Z_i\) where \(Z_i \sim N(0,\sigma^2)\). Then we have \[ \frac{1}{d-1}\sum_i\EE_\theta(X_i-\overline{X})^2 = \frac{1}{d-1}\sum_i (\theta_i - \bar{\theta})^2 + \frac{1}{d-1}\sum_i \EE(Z_i - \overline{Z})^2 = \beta^2 + \sigma^2, \] where \(\beta^2 = \frac{1}{d-1}\sum_i(\theta_i - \bar{\theta})^2\) is the sample variance of the \(\theta_i\) values. Plugging in this expectation and collecting terms, we obtain the MSE \[ \text{MSE}^\gamma(\theta) = \EE_\theta\left[\widehat{\text{MSE}}^\gamma(X)\right] = \sigma^2 + (d-1)(1-\gamma)^2\sigma^2 + (d-1)\gamma^2\beta^2. \] We can solve for the optimal value \(\gamma^*(\beta) = \frac{\sigma^2}{\beta^2+\sigma^2}\), which unsurprisingly depends on \(\beta^2\). If \(\beta^2 = 0\) then all \(\theta_i\) values are equal so we should set \(\gamma = 1\) (shrink fully to the sample mean), but if \(\beta^2 \gg \sigma^2\) we should take \(\gamma \to 0\) (shrink very little).
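To make the trade-off concrete, the exact MSE and the optimal \(\gamma\) can be evaluated for a given \(\theta\). In the sketch below the function names and the example \(\theta\) are our own; `np.var(..., ddof=1)` computes \(\beta^2\) with the \(1/(d-1)\) normalization used above.

```python
import numpy as np

def mse_shrink_to_mean(gamma, theta, sigma2=1.0):
    """Exact MSE of (1 - gamma) * X_i + gamma * Xbar when X ~ N_d(theta, sigma2 * I_d)."""
    theta = np.asarray(theta, dtype=float)
    d = theta.size
    beta2 = np.var(theta, ddof=1)   # sample variance of the theta_i values
    return sigma2 + (d - 1) * (1.0 - gamma) ** 2 * sigma2 + (d - 1) * gamma ** 2 * beta2

def gamma_star(theta, sigma2=1.0):
    """Oracle shrinkage sigma2 / (beta2 + sigma2); depends on the unknown theta."""
    beta2 = np.var(np.asarray(theta, dtype=float), ddof=1)
    return sigma2 / (beta2 + sigma2)

theta = np.array([0.0, 1.0, 2.0, 3.0, 4.0])   # beta^2 = 2.5
g = gamma_star(theta)                          # 1 / 3.5
best = mse_shrink_to_mean(g, theta)
no_shrink = mse_shrink_to_mean(0.0, theta)     # equals d * sigma2
full_shrink = mse_shrink_to_mean(1.0, theta)   # sigma2 + (d - 1) * beta2
```

In this example full shrinkage (\(\gamma = 1\)) is worse than no shrinkage at all, because the \(\theta_i\) are spread out relative to the noise level, while the oracle \(\gamma^*\) improves on both.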

We could also choose \(\gamma\) adaptively to minimize this estimator. That is, we could take \[ \hat\gamma(X) = \argmin_\gamma \widehat{\text{MSE}}^\gamma(X) = \sigma^2/V^2, \] which we could think of as an estimator of \(\gamma^*(\beta)\) since \(V^2\) is unbiased for \(\beta^2+\sigma^2\). If we plug in \(\hat\gamma(X)\) we get a new adaptive shrinkage estimator which is not the same as \(\delta^\gamma\) for any fixed \(\gamma\), and we could use the same idea to calculate its MSE, if we wanted to.
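A minimal sketch of this adaptive procedure, assuming known \(\sigma^2\), is given below. The function names are hypothetical, and no clipping is applied, so \(\hat\gamma\) can exceed \(1\) whenever \(V^2 < \sigma^2\).

```python
import numpy as np

def sure_shrink_to_mean(x, gamma, sigma2=1.0):
    """Unbiased MSE estimate for shrinking each X_i a fraction gamma toward Xbar."""
    x = np.asarray(x, dtype=float)
    d = x.size
    v2 = np.var(x, ddof=1)   # sample variance V^2
    return sigma2 * d + gamma ** 2 * (d - 1) * v2 - 2.0 * (d - 1) * gamma * sigma2

def adaptive_shrink_to_mean(x, sigma2=1.0):
    """Plug in gamma_hat = sigma2 / V^2, the minimizer of the SURE criterion above."""
    x = np.asarray(x, dtype=float)
    gamma_hat = sigma2 / np.var(x, ddof=1)
    return (1.0 - gamma_hat) * x + gamma_hat * x.mean(), gamma_hat

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])   # Xbar = 3, V^2 = 2.5
estimate, gamma_hat = adaptive_shrink_to_mean(x)   # gamma_hat = 0.4
sure_at_gamma_hat = sure_shrink_to_mean(x, gamma_hat)
```

Here each observation moves 40% of the way toward the sample mean. Note that `sure_at_gamma_hat` is the minimized criterion, not an unbiased estimate of the adaptive estimator’s risk, since \(\hat\gamma\) was chosen using the same data.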

3 Risk of the James-Stein estimator

We are now ready to calculate the risk of the James–Stein estimator \(\delta_{JS}(X) = \left(1 - \frac{d-2}{\|X\|^2}\right)X\). We work under the assumption \(\sigma^2 = 1\) (for general \(\sigma^2\), we should replace the numerator \(d-2\) with \((d-2)\sigma^2\)).

Proceeding as before, we have \[ h(X) = \frac{d-2}{\|X\|^2}X \Rightarrow \|h(X)\|^2 = \frac{(d-2)^2}{\|X\|^2}. \] Applying the quotient rule we have \[Dh(X)_{ii} = (d-2)\frac{\partial}{\partial X_i} \frac{X_i}{\sum_j X_j^2} = (d-2)\frac{\|X\|^2 - 2X_i^2}{\|X\|^4},\] and summing over the coordinates gives \[\text{tr}(Dh(X)) = (d-2) \frac{d\|X\|^2 - 2\|X\|^2}{\|X\|^4} = \frac{(d-2)^2}{\|X\|^2}.\] We thereby obtain the estimator \[ \widehat{\text{MSE}}(X) = d + \frac{(d-2)^2}{\|X\|^2} - 2\frac{(d-2)^2}{\|X\|^2} = d - \frac{(d-2)^2}{\|X\|^2}. \] Taking expectations, we obtain \[ \text{MSE}(\theta; \delta_{\text{JS}}) = d - (d-2)^2\EE_\theta\left[\frac{1}{\|X\|^2}\right]. \] Note this is always strictly less than \(d\), which is the MSE of \(\delta_0(X) = X\).

If \(\theta = 0\) then \(\|X\|^2 \sim \chi_d^2\) and we can apply our previous result to obtain \[ \text{MSE}(0; \delta_{\text{JS}}) = d - (d-2)^2\frac{1}{d-2} = 2. \] Thus, even though we are estimating \(d\) parameters, our total MSE does not rise with \(d\), because we will shrink harder and harder toward zero the larger \(d\) gets. This is fairly remarkable.
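These risk calculations can be checked by simulation. The following Monte Carlo sketch assumes \(\sigma^2 = 1\) and uses hypothetical helper names (`js_rows`, `mc_risk`); the dimension, the two choices of \(\theta\), the number of replications, and the seed are arbitrary.

```python
import numpy as np

def js_rows(x):
    """Apply James-Stein shrinkage (sigma^2 = 1) separately to each row of x."""
    d = x.shape[1]
    return (1.0 - (d - 2) / np.sum(x ** 2, axis=1, keepdims=True)) * x

def mc_risk(estimator, theta, n_reps=100_000, seed=0):
    """Monte Carlo estimate of E||estimator(X) - theta||^2 for X ~ N_d(theta, I_d)."""
    theta = np.asarray(theta, dtype=float)
    rng = np.random.default_rng(seed)
    # Each row is one independent draw of X; reusing the seed reuses the draws.
    x = theta + rng.standard_normal((n_reps, theta.size))
    return np.mean(np.sum((estimator(x) - theta) ** 2, axis=1))

d = 10
risk_mle_origin = mc_risk(lambda x: x, np.zeros(d))    # theory: d
risk_js_origin = mc_risk(js_rows, np.zeros(d))         # theory: 2
risk_mle_far = mc_risk(lambda x: x, np.full(d, 5.0))   # theory: d
risk_js_far = mc_risk(js_rows, np.full(d, 5.0))        # theory: slightly below d
```

Because both estimators are evaluated on the same simulated draws, the comparison between them is less noisy than either risk estimate on its own. Up to Monte Carlo error, the James-Stein risk should be close to \(2\) at the origin and only slightly below \(d\) when \(\|\theta\|^2\) is large.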

On the other hand, suppose \(\|\theta\|^2 \to \infty\). Then \(\|X\|^2\) grows without bound, so \(\EE_\theta\left[1/\|X\|^2\right] \to 0\) and the improvement \((d-2)^2\,\EE_\theta\left[1/\|X\|^2\right]\) is driven to \(0\).

Note that, for more general \(\sigma^2\), the James–Stein estimator is \(\left(1-\frac{(d-2)\sigma^2}{\|X\|^2}\right)X\). Then we have \[h(X) = \sigma^2\frac{d-2}{\|X\|^2}X \Rightarrow \|h(X)\|^2 = \sigma^4 \frac{(d-2)^2}{\|X\|^2}, \quad \text{tr}(Dh(X)) = \sigma^2 \frac{(d-2)^2}{\|X\|^2},\] leading to the estimator \(\widehat{\text{MSE}}(X) = \sigma^2 d - \sigma^4\frac{(d-2)^2}{\|X\|^2}\), and plugging in \(\EE_0\left[1/\|X\|^2\right] = \frac{1}{(d-2)\sigma^2}\), the MSE at \(\theta=0\) is \(2\sigma^2\).

3.1 Final thoughts

A few more notes: first, \(\delta_{JS}(X)\) is itself inadmissible: the positive-part estimator \(\delta_{+}(X) = \left(1 - \frac{d-2}{\|X\|^2}\right)_+ X\), where \((a)_+ = \max(a, 0)\), is strictly better, since we never benefit from using a shrinkage parameter \(\zeta > 1\).
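The difference between the two estimators only shows up when \(\|X\|^2 < d-2\), as the following sketch illustrates. The function name `james_stein_plus` and the example vectors are our own, and \(\sigma^2 = 1\) is assumed.

```python
import numpy as np

def james_stein_plus(x):
    """Positive-part James-Stein: clip the shrinkage factor at zero (sigma^2 = 1)."""
    x = np.asarray(x, dtype=float)
    d = x.size
    factor = 1.0 - (d - 2) / np.sum(x ** 2)
    return max(factor, 0.0) * x

# ||x||^2 = 1.25 < d - 2 = 3: the plain factor is 1 - 3/1.25 = -1.4, which would
# flip the sign of every coordinate; the positive part returns the zero vector.
x_small = np.array([0.5, -0.5, 0.5, 0.5, -0.5])
est_small = james_stein_plus(x_small)

# ||x||^2 = 25 > d - 2: the factor 1 - 3/25 = 0.88 is positive, so both versions agree.
x_large = np.array([3.0, 0.0, 0.0, 0.0, 4.0])
est_large = james_stein_plus(x_large)
```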

A practically more useful version of James–Stein shrinks toward the central value \(\overline{X}\):

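\[\delta_{JS+, i}(X) = \overline{X} + \left(1 - \frac{d-3}{\|X - \overline{X} 1_d\|^2}\right)_+ (X_i - \overline{X}),\] where \(\|X - \overline{X} 1_d\|^2 = \sum_i (X_i - \overline{X})^2 = (d-1)V^2\) and we again take \(\sigma^2 = 1\). This estimator dominates \(\delta(X) = X\) for \(d \geq 4\).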

Taken to its logical extreme, the James–Stein estimator seems to imply absurd things: should all scientists at Berkeley, across all different fields, pool their estimates to calculate a James–Stein estimator even if their different estimation problems have nothing to do with each other, as long as they are estimating Gaussian location parameters? The overall MSE for all of the estimates really would be better.

To avoid this conclusion we might note that even when the overall MSE is improved, the MSE for a single coordinate can certainly get worse. For example, if \(\theta_1 = 10\) but \(\theta_2=\theta_3=\cdots=\theta_{1000}=0\), it’s likely that we’ll overshrink our estimate of \(\theta_1\) toward \(0\) in order to do well on the other coordinates. So people who believe their estimands are larger than average wouldn’t want to participate in this scheme.
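The example in the previous paragraph can be simulated directly. The sketch below assumes \(\sigma^2 = 1\) and \(d = 1000\) as in the text; the number of replications and the random seed are arbitrary choices, and the variable names are our own.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_reps = 1000, 2000
theta = np.zeros(d)
theta[0] = 10.0   # one large coordinate; all the others are zero

# Each row of x is one draw of X ~ N_d(theta, I_d).
x = theta + rng.standard_normal((n_reps, d))
shrink = 1.0 - (d - 2) / np.sum(x ** 2, axis=1, keepdims=True)
js = shrink * x

# MSE for the first coordinate alone, with and without shrinkage.
mse_first_mle = np.mean((x[:, 0] - theta[0]) ** 2)
mse_first_js = np.mean((js[:, 0] - theta[0]) ** 2)

# Total MSE summed over all d coordinates.
total_mle = np.mean(np.sum((x - theta) ** 2, axis=1))
total_js = np.mean(np.sum((js - theta) ** 2, axis=1))
```

In a simulation like this we would expect the total MSE of the James–Stein estimator to be far below that of \(X\), while its MSE for the first coordinate alone is much larger, because the estimate of \(\theta_1\) is pulled most of the way toward \(0\).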