ASH-based wavelet shrinkage for heteroskedastic models

2018-05-21

In this vignette, we assume a simple parametric model of the form

\[ y_i = \mu_i + \epsilon_i, \]

where \(\epsilon_i \sim N(0,\sigma_i^2)\), and both \(\mu\) and \(\sigma\) are unknown. We use the “smash” procedure (SMoothing via Adaptive SHrinkage) to estimate both the mean and the variances. This vignette gives a brief demonstration of the method.

We first look at mean estimation, which is our primary focus. A sample mean function is presented, as well as a couple of different variance functions. Our method is compared against a few other simple methods.

Set up environment

We begin by loading the MASS, smashr, EbayesThresh and wavethresh packages.

library(MASS)
library(smashr)
library(EbayesThresh)
library(wavethresh)

Preparations

First, we define the mean function used to simulate the data.

# Mean function used to simulate the data: a sum of five Gaussian-shaped
# spikes of the form height * exp(-sharpness * (x - centre)^2).
# Vectorised over x; returns a numeric vector the same length as x.
spike.f <- function(x) {
  # One entry per spike; a larger sharpness gives a narrower spike.
  height    <- c(0.75, 1.5, 3, 2.25, 0.5)
  sharpness <- c(500, 2000, 8000, 16000, 32000)
  centre    <- c(0.23, 0.33, 0.47, 0.69, 0.83)

  # Accumulate the spikes from left to right.
  total <- 0
  for (k in seq_along(height)) {
    total <- total + height[k] * exp(-sharpness[k] * (x - centre[k])^2)
  }
  total
}
# Evaluate the mean function on an equally spaced grid 1/n, 2/n, ..., 1.
n    <- 1024
# This numeric `t` shares its name with base::t(); calls such as t(x) still
# reach the transpose function, because R skips non-function bindings when
# it looks up a name in call position.
t    <- seq_len(n) / n
mu.s <- spike.f(t)

Define a few other functions which will be used in the code chunks below.

# Mean squared difference between x and y (element-wise, with the usual R
# recycling rules).
mse <- function(x, y) {
  err <- x - y
  mean(err^2)
}
# Sum of squared entries of x. Despite the name this is the *squared*
# Euclidean norm: no square root is taken.
l2norm <- function(x) {
  squares <- x^2
  sum(squares)
}
# Scaled mean integrated squared error of a set of estimates.
#   x: matrix of estimates with r rows (one estimate per row) and
#      length(y) columns.
#   y: the reference curve.
#   r: number of rows of x.
# For each row, the squared error against y is divided by the squared norm
# of y; the r ratios are averaged and multiplied by 10000.
mise <- function(x, y, r) {
  # r stacked copies of y, one per row, to line up with the rows of x.
  truth  <- outer(rep(1, r), y)
  sq.err <- apply(x - truth, 1, l2norm)
  10000 * mean(sq.err / l2norm(y))
}
# Estimate a single (constant) noise standard deviation from the first n
# entries of x using second-order differences.
#   x: numeric vector of observations.
#   n: number of observations to use.
# For each interior point the pseudo-residual
#   x[i]/2 - x[i + 1] + x[i + 2]/2
# is formed (it vanishes when x is locally linear); the estimate is
#   sqrt(2 / (3 * (n - 2)) * sum(pseudo-residual^2)).
sig.est.func <- function(x, n) {
  left <- 1:(n - 2)
  pseudo.resid <- 1/2 * x[left] - x[left + 1] + 1/2 * x[left + 2]
  sqrt(2 / (3 * (n - 2)) * sum(pseudo.resid^2))
}

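Define a function for the baseline wavelet thresholding method: translation-invariant hard thresholding with the universal threshold \(\hat{\sigma}\sqrt{2 \log n}\), where \(\hat{\sigma}\) is a single (constant) noise level.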

# Hard thresholding of the wavelet coefficients of x with a single
# threshold, followed by basis averaging.
#   x:             numeric vector of observations.
#   filter.number,
#   family:        wavelet filter, passed on to wavethresh::wd().
#   min.level:     coarsest resolution level that is thresholded.
#   noise.level:   noise standard deviation (required; no default).
# Returns the reconstructed signal.
waveti.u <- function(x, filter.number = 10, family = "DaubLeAsymm",
                     min.level = 3, noise.level) {
  # Universal threshold: noise level times sqrt(2 * log(sample size)).
  cutoff <- noise.level * sqrt(2 * log(length(x)))

  # type = "station" requests the non-decimated transform in wavethresh.
  coefs <- wavethresh::wd(x, filter.number, family, type = "station")

  # Threshold levels min.level up to the finest level; coarser levels are
  # left untouched.
  finest <- coefs$nlevels - 1
  kept <- threshold(coefs, levels = min.level:finest, policy = "manual",
                    value = cutoff, type = "hard")

  # Convert to the packet-ordered form expected by AvBasis() and average
  # over bases to obtain the estimate.
  AvBasis(convert(kept))
}

Define another function for the default EbayesThresh method.

# Level-by-level empirical Bayes shrinkage of the wavelet coefficients of
# x via ebayesthresh(), followed by basis averaging.
#   x:             numeric vector of observations; log2(length(x)) is used
#                  as the number of resolution levels, so the length is
#                  presumably a power of two -- confirm against wd().
#   filter.number,
#   family:        wavelet filter, passed on to wd().
#   min.level:     coarsest resolution level that is shrunk.
#   noise.level:   noise standard deviation handed to ebayesthresh() as
#                  sdev (required; no default).
# Returns the reconstructed signal.
waveti.ebayes <- function(x, filter.number = 10, family = "DaubLeAsymm",
                          min.level = 3, noise.level) {
  n.levels <- log2(length(x))
  coefs <- wd(x, filter.number, family, type = "station")

  # Replace the detail coefficients at each level by their shrunken
  # versions; levels below min.level are left as they are.
  for (lev in min.level:(n.levels - 1)) {
    shrunk <- ebayesthresh(accessD(coefs, lev), sdev = noise.level)
    coefs <- putD(coefs, lev, shrunk)
  }

  AvBasis(convert(coefs))
}

For the demonstrations below, define the mean function, the two variance functions (a constant one, var1, and a non-constant one, var2), and the signal-to-noise ratio.

# Mean function used in both examples: the spike signal shifted by 1 and
# divided by 5.
mu.t <- (1 + mu.s) / 5

# Signal-to-noise setting; it enters the noise scale below as rsnr^2.
rsnr <- sqrt(1)

# Variance function 1: constant.
var1 <- rep(1, n)

# Variance function 2: three Gaussian-shaped bumps centred at 0.2, 0.5 and
# 0.8 on top of a small positive floor (1e-04).
var2 <- local({
  bumps <- exp(-550 * (t - 0.2)^2) + exp(-200 * (t - 0.5)^2) +
    exp(-950 * (t - 0.8)^2)
  (1e-04 + 4 * bumps) / 1.35
})

Constant variance example

We first look at the case of constant variance.

set.seed(327)

# Rescale the standard-deviation profile to have mean 1, then scale it by
# sd(mu.t) / rsnr^2.
# NOTE(review): rsnr enters squared; with rsnr = sqrt(1) this makes no
# difference, but confirm the intended scaling before changing rsnr.
sigma.ini <- sqrt(var1)
sigma.t   <- sigma.ini / mean(sigma.ini) * sd(mu.t) / rsnr^2

# Ten replicate data sets, one per row. rnorm() recycles mu.t and sigma.t
# every n draws, and byrow = TRUE places each run of n draws in its own row.
X.s <- matrix(rnorm(10 * n, mean = mu.t, sd = sigma.t), nrow = 10,
              byrow = TRUE)

# Row-wise fits. apply() returns one *column* per replicate here, so these
# three results are n x 10.
mu.est           <- apply(X.s, 1, smash.gaus)
mu.est.tivar.ash <- apply(X.s, 1, ti.thresh, method = "smash")
mu.est.tivar.mad <- apply(X.s, 1, ti.thresh, method = "rmad")

# Constant-variance competitors, stored one replicate per row (10 x n).
mu.est.ti        <- matrix(0, nrow = 10, ncol = n)
mu.est.ti.ebayes <- matrix(0, nrow = 10, ncol = n)
for (i in seq_len(nrow(X.s))) {
  # Both methods share the same difference-based noise-level estimate.
  sig.est <- sig.est.func(X.s[i, ], n)
  mu.est.ti[i, ] <- waveti.u(X.s[i, ], noise.level = sig.est)
  mu.est.ti.ebayes[i, ] <- waveti.ebayes(X.s[i, ], noise.level = sig.est)
}

Assess the accuracy of the results.

# Print the scaled MISE (see mise()) of each estimator over the 10
# replicates. The apply()-based fits hold one replicate per column and are
# transposed so that every matrix has one replicate per row.
local({
  fits <- list(
    "SMASH:" =
      t(mu.est),
    "TI thresholding with variance estimated from smash:" =
      t(mu.est.tivar.ash),
    "TI thresholding with variance estimated from running MAD:" =
      t(mu.est.tivar.mad),
    "TI thresholding with constant variance (estimated):" =
      mu.est.ti,
    "EBayes with constant variance (estimated):" =
      mu.est.ti.ebayes
  )
  for (label in names(fits)) {
    cat(label, mise(fits[[label]], mu.t, 10), "\n")
  }
})
# SMASH: 59.77 
# TI thresholding with variance estimated from smash: 78.47 
# TI thresholding with variance estimated from running MAD: 118.9 
# TI thresholding with constant variance (estimated): 87.7 
# EBayes with constant variance (estimated): 75.02

Plot the estimated mean functions against the ground-truth function (in black).

# Overlay the fits for the first replicate on the true mean (black).
local({
  labels  <- c("smash", "ti_rmad", "ti_homo", "ebayes_homo")
  colours <- c(2, 3, 4, 6)
  # First replicate of each fit: a column for the apply()-based results,
  # a row for the loop-filled matrices.
  curves <- list(mu.est[, 1], mu.est.tivar.mad[, 1],
                 mu.est.ti[1, ], mu.est.ti.ebayes[1, ])

  plot(mu.t, xlab = "", ylab = "", type = "l")
  for (k in seq_along(curves)) {
    lines(curves[[k]], col = colours[k])
  }
  legend("topright", legend = labels, fill = colours)
})
&nbsp;

 

Non-constant variance example

Generate the data for this example.

# Same pipeline as the constant-variance example, now with the spatially
# varying variance function var2.
# Rescale the standard-deviation profile to have mean 1, then scale it by
# sd(mu.t) / rsnr^2.
sigma.ini <- sqrt(var2)
sigma.t   <- sigma.ini / mean(sigma.ini) * sd(mu.t) / rsnr^2

set.seed(327)

# Ten replicate data sets, one per row. rnorm() recycles mu.t and sigma.t
# every n draws, and byrow = TRUE places each run of n draws in its own row.
X.s <- matrix(rnorm(10 * n, mean = mu.t, sd = sigma.t), nrow = 10,
              byrow = TRUE)

# Row-wise fits. apply() returns one *column* per replicate here, so these
# three results are n x 10.
mu.est           <- apply(X.s, 1, smash.gaus)
mu.est.tivar.ash <- apply(X.s, 1, ti.thresh, method = "smash")
mu.est.tivar.mad <- apply(X.s, 1, ti.thresh, method = "rmad")

# Constant-variance competitors, stored one replicate per row (10 x n).
mu.est.ti        <- matrix(0, nrow = 10, ncol = n)
mu.est.ti.ebayes <- matrix(0, nrow = 10, ncol = n)
for (i in seq_len(nrow(X.s))) {
  # Both methods are given a single noise level even though the true
  # variance is not constant in this example.
  sig.est <- sig.est.func(X.s[i, ], n)
  mu.est.ti[i, ] <- waveti.u(X.s[i, ], noise.level = sig.est)
  mu.est.ti.ebayes[i, ] <- waveti.ebayes(X.s[i, ], noise.level = sig.est)
}

Assess the accuracy of the results.

# Print the scaled MISE (see mise()) of each estimator over the 10
# replicates. The apply()-based fits hold one replicate per column and are
# transposed so that every matrix has one replicate per row.
local({
  fits <- list(
    "SMASH:" =
      t(mu.est),
    "TI thresholding with variance estimated from smash:" =
      t(mu.est.tivar.ash),
    "TI thresholding with variance estimated from running MAD:" =
      t(mu.est.tivar.mad),
    "TI thresholding with constant variance (estimated):" =
      mu.est.ti,
    "EBayes with constant variance (estimated):" =
      mu.est.ti.ebayes
  )
  for (label in names(fits)) {
    cat(label, mise(fits[[label]], mu.t, 10), "\n")
  }
})
# SMASH: 106.9 
# TI thresholding with variance estimated from smash: 141.9 
# TI thresholding with variance estimated from running MAD: 215.7 
# TI thresholding with constant variance (estimated): 400.4 
# EBayes with constant variance (estimated): 431.2

Plot the estimated mean functions against the ground-truth function (in black).

# Overlay the fits for the first replicate on the true mean (black).
local({
  labels  <- c("smash", "ti_rmad", "ti_homo", "ebayes_homo")
  colours <- c(2, 3, 4, 6)
  # First replicate of each fit: a column for the apply()-based results,
  # a row for the loop-filled matrices.
  curves <- list(mu.est[, 1], mu.est.tivar.mad[, 1],
                 mu.est.ti[1, ], mu.est.ti.ebayes[1, ])

  plot(mu.t, xlab = "", ylab = "", type = "l")
  for (k in seq_along(curves)) {
    lines(curves[[k]], col = colours[k])
  }
  legend("topright", legend = labels, fill = colours)
})