Sampling from the Wishart distribution

R
Statistical simulations
Wishart distribution
An approach to sampling from the Wishart distribution.
Author

Yevgen Ryeznik

Published

February 24, 2020

Wishart distribution

The Wishart distribution is a family of probability distributions defined over symmetric, nonnegative-definite matrix-valued random variables (“random matrices”). These distributions are of great importance in the estimation of covariance matrices in multivariate statistics.

Suppose \(\boldsymbol{X}_{p\times n} = \left(\boldsymbol{X}^{(1)}, \ldots, \boldsymbol{X}^{(n)}\right)\) is a \(p\times n\) matrix whose columns are drawn independently from a \(p\)-variate normal distribution with zero mean, i.e.

\[ \boldsymbol{X}^{(j)} \sim N(0, \boldsymbol{V}), \quad j = 1, \ldots, n, \tag{1}\]

where \(\boldsymbol{V}\) is a \(p\times p\) covariance matrix. Then the random matrix

\[ \boldsymbol{S} = \boldsymbol{X}\cdot \boldsymbol{X}' \tag{2}\]

follows the Wishart distribution \(W_p(\boldsymbol{V}, n)\) with \(n\) degrees of freedom. Its mean and variance are:

  • Expected value: \[ \text{E}\left[\boldsymbol{S}\right] = n\boldsymbol{V}. \tag{3}\]

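  • Variance (element-wise): \[ \text{Var}\left[\boldsymbol{S}\right] = n\left(\boldsymbol{V}^2 + \operatorname{diag}(\boldsymbol{V})\cdot \operatorname{diag}(\boldsymbol{V})'\right), \quad \operatorname{diag}(\boldsymbol{V}) = \left(\boldsymbol{V}_{11}, \ldots, \boldsymbol{V}_{pp}\right)', \tag{4}\]

where \(\boldsymbol{V}^2\) denotes the element-wise square of \(\boldsymbol{V}\), so that entry by entry \(\text{Var}\left[\boldsymbol{S}_{ij}\right] = n\left(\boldsymbol{V}_{ij}^2 + \boldsymbol{V}_{ii}\boldsymbol{V}_{jj}\right)\).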
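As a numerical cross-check of (3) and (4) for \(n > 1\), here is a minimal sketch that draws from \(W_p(\boldsymbol{V}, n)\) with base R's `stats::rWishart()`. This is a different sampler from the one developed in this post. Its first argument is the number of draws, `df` plays the role of \(n\), and it returns a \(p\times p\times\) draws array. The helpers `wishart_mean()` and `wishart_var()`, the \(2\times 2\) matrix, and the sample sizes are illustrative choices and not part of the original post.

```r
# Hypothetical helpers: formulas (3) and (4) written in R
wishart_mean <- function(V, n) n * V                                  # E[S] = n V
wishart_var  <- function(V, n) n * (V^2 + outer(diag(V), diag(V)))    # element-wise Var[S_ij]

set.seed(2020)
V <- matrix(c(1.0, 0.5,
              0.5, 2.0), nrow = 2)
n    <- 5       # degrees of freedom (rWishart requires df >= p)
nsmp <- 20000   # number of Monte Carlo draws

# nsmp draws from W_2(V, n), stored as a 2 x 2 x nsmp array
S <- stats::rWishart(nsmp, df = n, Sigma = V)

# Element-wise Monte Carlo mean and variance across the third dimension
mc_mean <- apply(S, c(1, 2), mean)
mc_var  <- apply(S, c(1, 2), var)

print(round(mc_mean, 2))
print(wishart_mean(V, n))
print(round(mc_var, 2))
print(wishart_var(V, n))
```

With `df = 5` the formulas give \(\text{E}[\boldsymbol{S}] = \begin{pmatrix}5 & 2.5\\ 2.5 & 10\end{pmatrix}\) and element-wise variances \(\begin{pmatrix}10 & 11.25\\ 11.25 & 40\end{pmatrix}\). The Monte Carlo estimates should land close to these values.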

We are interested in sampling random matrices whose mean equals \(\boldsymbol{V}\). By (3), this corresponds to \(n = 1\):

\[ \boldsymbol{S}_{1}, \boldsymbol{S}_{2}, \ldots \sim W_p(\boldsymbol{V}, 1). \tag{5}\]
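With \(n = 1\), definition (2) reduces to a single outer product \(\boldsymbol{X}^{(1)}\boldsymbol{X}^{(1)\prime}\). More generally, \(\boldsymbol{X}\cdot\boldsymbol{X}'\) equals the sum of the outer products of the \(n\) columns of \(\boldsymbol{X}\), so a draw from \(W_p(\boldsymbol{V}, n)\) is a sum of \(n\) independent draws from \(W_p(\boldsymbol{V}, 1)\). Each such term is a rank-one matrix, so for \(p > 1\) every \(W_p(\boldsymbol{V}, 1)\) draw is singular. Below is a small base-R sketch of this identity. It uses an arbitrary fixed \(3\times 4\) matrix in place of a normal sample, which is an assumption made only to keep the check deterministic.

```r
# An arbitrary p x n matrix standing in for the sample X (p = 3, n = 4)
X <- matrix(c( 1, -2, 0.5,
               3,  1, -1,
               0,  2,  4,
              -1,  1,  1), nrow = 3)

# Left-hand side of (2): S = X X'
S <- X %*% t(X)

# Right-hand side: sum over columns j of the outer products X^(j) X^(j)'
S_sum <- Reduce(`+`, lapply(seq_len(ncol(X)), function(j) X[, j] %o% X[, j]))

all.equal(S, S_sum)           # TRUE: both constructions give the same matrix
qr(X[, 1] %o% X[, 1])$rank    # 1: a single outer product has rank one
```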

A single draw from \(W_p(\boldsymbol{V}, 1)\) can be obtained with the following steps:

  1. Sample a \(p\)-variate random vector \(\boldsymbol{X}\sim N(0, \boldsymbol{V})\):

    • compute the Cholesky decomposition of \(\boldsymbol{V}\), i.e. \(\boldsymbol{V} = \boldsymbol{L}\cdot \boldsymbol{L}'\) with \(\boldsymbol{L}\) lower triangular;
    • sample a \(p\)-variate standard normal vector \(\boldsymbol{Z}\sim N(0, \boldsymbol{I})\);
    • set \(\boldsymbol{X} = \boldsymbol{L}\cdot \boldsymbol{Z}.\)
  2. Set \(\boldsymbol{S} = \boldsymbol{X}\cdot \boldsymbol{X}'.\)
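Step 1 relies on \(\boldsymbol{L}\boldsymbol{Z}\) having covariance \(\boldsymbol{L}\boldsymbol{I}\boldsymbol{L}' = \boldsymbol{V}\). One R-specific detail matters here. `chol()` returns the *upper*-triangular factor \(\boldsymbol{R}\) with \(\boldsymbol{V} = \boldsymbol{R}'\boldsymbol{R}\), so the lower-triangular \(\boldsymbol{L}\) is `t(chol(V))`. The sketch below checks both the factorization and the covariance of \(\boldsymbol{X} = \boldsymbol{L}\boldsymbol{Z}\). The example matrix and the number of vectors are arbitrary choices.

```r
set.seed(1)
V <- matrix(c(4,   2,   0.6,
              2,   3,   0.5,
              0.6, 0.5, 1), nrow = 3)
p <- nrow(V)

L <- t(chol(V))               # lower-triangular Cholesky factor
all.equal(L %*% t(L), V)      # TRUE: V = L L'

# Step 1 for many vectors at once: each column of Z is N(0, I)
N <- 1e5
Z <- matrix(rnorm(p * N), nrow = p)
X <- L %*% Z                  # each column is then N(0, V)

round(cov(t(X)), 2)           # should be close to V

# Step 2 for a single draw, using the first column as X
x <- X[, 1, drop = FALSE]
S <- x %*% t(x)               # one draw from W_3(V, 1)
```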

R code to perform the procedure

# load the pipe operator %>%
library(magrittr)

# load map() and reduce() from purrr
library(purrr)

# function to sample a single observation from the Wishart distribution, 
# given parameters V, n

r1wishart <- function(V, n = 1){
  p <- nrow(V)

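  # lower-triangular Cholesky factor: chol() returns the upper factor R
  # with V = R'R, so L = t(R) satisfies V = L L'
  L <- t(chol(V))
  
  # n independent columns X^(j) = L Z, Z ~ N(0, I), assembled into a p x n matrix
  X <- map(seq_len(n), ~ { 
    Z <- rnorm(p)
    L %*% Z
  }) %>%
    unlist() %>%
    matrix(ncol = n, byrow = FALSE)
  
  # S = X X'
  X %*% t(X) 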
}

# function to sample nsmp observations from the Wishart distribution, 
# given parameters V, n

sample_wishart <- function(nsmp, V, n = 1){
  map(seq_len(nsmp), ~ r1wishart(V, n)) 
}
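For comparison, the `map()`/`unlist()`/`matrix()` chain in `r1wishart()` can be replaced by a single call to `rnorm()` and base-R matrix algebra. The sketch below is a hypothetical variant that follows the same two steps. The names `r1wishart_base()` and `sample_wishart_base()` are introduced here and do not come from the post. `tcrossprod(X)` computes \(\boldsymbol{X}\boldsymbol{X}'\) in one call.

```r
# Hypothetical base-R variant of r1wishart(): one draw from W_p(V, n)
r1wishart_base <- function(V, n = 1) {
  p <- nrow(V)
  L <- t(chol(V))                        # lower Cholesky factor, V = L L'
  Z <- matrix(rnorm(p * n), nrow = p)    # p x n matrix of N(0, 1) draws
  X <- L %*% Z                           # columns are i.i.d. N(0, V)
  tcrossprod(X)                          # X %*% t(X)
}

# Hypothetical counterpart of sample_wishart(): a list of nsmp draws
sample_wishart_base <- function(nsmp, V, n = 1) {
  lapply(seq_len(nsmp), function(i) r1wishart_base(V, n))
}

# Example: three draws from W_2(I, 4)
set.seed(1)
draws <- sample_wishart_base(3, diag(2), n = 4)
```

Because `matrix()` fills column by column, the standard normals should be consumed in the same order as in `r1wishart()`. Under the same seed, the two versions are therefore expected to agree up to floating-point rounding. This is an expectation, not something verified in the original post. This variant also removes the dependence on purrr and magrittr.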

Test example

# variances
omega <- c(1, 2, 3)

# correlation matrix
rho <- rbind(
  c(1, 0.2, 0.7),
  c(0.2, 1, 0.45),
  c(0.7, 0.45, 1) 
)

# covariance matrix
V <- rho*sqrt(cbind(omega)%*%rbind(omega)) 

# number of sample points
nsmp <- 10000

# sampling
W <- sample_wishart(nsmp, V)

# calculating the sample mean. We expect that it is approximately equal to V
sample_mean <- reduce(W, `+`)/nsmp 

# printing out the sample mean
print(sample_mean) 
          [,1]      [,2]     [,3]
[1,] 0.9942158 0.3034728 1.213229
[2,] 0.3034728 2.0452006 1.148126
[3,] 1.2132293 1.1481260 3.000992
# printing out the true mean
print(V) 
          [,1]      [,2]     [,3]
[1,] 1.0000000 0.2828427 1.212436
[2,] 0.2828427 2.0000000 1.102270
[3,] 1.2124356 1.1022704 3.000000
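How close should "approximately equal" be? By (4) with \(n = 1\), each entry \(\boldsymbol{S}_{ij}\) has variance \(\boldsymbol{V}_{ij}^2 + \boldsymbol{V}_{ii}\boldsymbol{V}_{jj}\). The Monte Carlo standard error of the corresponding entry of the sample mean is therefore \(\sqrt{(\boldsymbol{V}_{ij}^2 + \boldsymbol{V}_{ii}\boldsymbol{V}_{jj})/n_{\text{smp}}}\). The short sketch below computes these standard errors for the test covariance matrix, which is rebuilt here so that the block runs on its own.

```r
omega <- c(1, 2, 3)
rho <- rbind(
  c(1, 0.2, 0.7),
  c(0.2, 1, 0.45),
  c(0.7, 0.45, 1)
)
V <- rho * sqrt(outer(omega, omega))   # same covariance matrix as above
nsmp <- 10000

# Monte Carlo standard error of each entry of the sample mean, from (4) with n = 1
mc_se <- sqrt((V^2 + outer(diag(V), diag(V))) / nsmp)
round(mc_se, 4)
```

Dividing the differences between the two printed matrices by these standard errors shows that the largest deviation, in entry (2,3), is about 1.7 standard errors. That is well within the range that ordinary Monte Carlo noise would produce.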
# calculating the sample variance. 
# We expect that it is approximately equal to V^2 + diag(V)%*%t(diag(V))

# element-wise mean(S^2) - mean(S)^2, i.e. the divide-by-nsmp (biased) estimator
sample_var <- map(W, ~ .^2) %>% 
  reduce(`+`) %>% 
  `/`(nsmp) %>% 
  `-`(sample_mean^2) 

# printing out the unbiased sample variance (rescaled by nsmp/(nsmp - 1))
print(round(nsmp/(nsmp-1)*sample_var, 2)) 
     [,1] [,2]  [,3]
[1,] 1.88 2.10  4.25
[2,] 2.10 8.10  7.21
[3,] 4.25 7.21 17.73
# printing out the true variance
print(V^2 + cbind(diag(V))%*%rbind(diag(V)))
     [,1]  [,2]   [,3]
[1,] 2.00 2.080  4.470
[2,] 2.08 8.000  7.215
[3,] 4.47 7.215 18.000
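The manual correction `nsmp/(nsmp-1)` is exactly the one that `var()` applies. If the draws are stored as a \(p\times p\times n_{\text{smp}}\) array instead of a list, the whole check can be written with `apply()`. The sketch below builds that array in base R with `vapply()`, following the same two steps as the post. It does not use `stats::rWishart()`, which, as far as I know, stops with an error when `df` is smaller than the dimension \(p\), as it is here. The seed is an arbitrary choice.

```r
omega <- c(1, 2, 3)
rho <- rbind(
  c(1, 0.2, 0.7),
  c(0.2, 1, 0.45),
  c(0.7, 0.45, 1)
)
V <- rho * sqrt(outer(omega, omega))
p <- nrow(V)
nsmp <- 10000

set.seed(2020)
# step 1: nsmp independent N(0, V) vectors, one per column
X <- t(chol(V)) %*% matrix(rnorm(p * nsmp), nrow = p)

# step 2: nsmp draws from W_p(V, 1), stored as a p x p x nsmp array
S <- vapply(seq_len(nsmp), function(k) tcrossprod(X[, k]), matrix(0, p, p))

# element-wise sample mean and unbiased (n - 1 denominator) sample variance
sample_mean <- apply(S, c(1, 2), mean)
sample_var  <- apply(S, c(1, 2), var)

round(sample_mean, 3)
round(sample_var, 2)
round(V^2 + outer(diag(V), diag(V)), 3)   # theoretical variance, (4) with n = 1
```

The list `W` produced by `sample_wishart()` could be turned into an array of the same shape with `simplify2array(W)` and passed to the same `apply()` calls.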