library(rjags)
library(dplyr)
library(tidyr)
library(tidybayes)
library(ggplot2)

# Simulate Data ----
# Fixed seed so the simulated data and the missingness pattern are reproducible
set.seed(2024)
n <- 1000
nmiss <- 10
x <- rnorm(n, mean = 0, sd = 1)
# Simulate y as a linear function of x (intercept 0, slope 1, residual sd 1)
y <- rnorm(n, mean = 1 * x, sd = 1)

d <- data.frame(x = x, y = y)
# Keep the true x so the imputed values can be compared against it later
d$xtrue <- d$x
# Drop nmiss x values at random to create cases of missing data
d[sample(seq_len(nrow(d)), size = nmiss, replace = FALSE), "x"] <- NA
# Sort by x with NAs last, so the missing values occupy rows
# (n - nmiss + 1):n -- the layout the JAGS model's loops are written around
d <- d[order(d$x, na.last = TRUE), ]
rownames(d) <- NULL

# Specify model ----
# NOTE: JAGS parameterises dnorm(mean, tau) by PRECISION (tau = 1 / sd^2),
# not by standard deviation. sigma and sdx are given priors on the sd scale,
# so they are converted to precisions (tau.y, tau.x) before use in dnorm().
m_string <- "
  model {

  # Priors on intercept and slope (precision 1, i.e. sd 1)
  intercept ~ dnorm(0, 1)
  slope ~ dnorm(0, 1)

  # Prior on residual sd of y, converted to a precision for dnorm()
  sigma ~ dunif(0, 5)
  tau.y <- pow(sigma, -2)

  # Priors on mean and sd of x; sd converted to a precision for dnorm()
  mux ~ dnorm(0, 1)
  sdx ~ dunif(0, 5)
  tau.x <- pow(sdx, -2)

  # x[1:(n-nmiss)] and y[1:(n-nmiss)] are observed:
  # they inform mux, sdx, intercept, slope and sigma
  for (i in 1:(n-nmiss)) {
    x[i] ~ dnorm(mux, tau.x)
    mu[i] <- intercept + slope * x[i]
    y[i] ~ dnorm(mu[i], tau.y)
  }

  # x[(n-nmiss+1):n] are missing (NA) while y[(n-nmiss+1):n] are observed:
  # JAGS treats the NA x values as unknown nodes and samples (imputes) them
  for (i in (n-nmiss+1):n) {
    x[i] ~ dnorm(mux, tau.x)
    mu[i] <- intercept + slope * x[i]
    y[i] ~ dnorm(mu[i], tau.y)
  }
}
"

# JAGS Data ----
# n and nmiss are derived from d itself (not from the globals used to simulate
# it), so the loop bounds in the model always agree with the data passed in.
jags_data <- local({
  n_rows <- nrow(d)
  n_missing <- sum(is.na(d$x))
  # The model's first loop treats x[1:(n - nmiss)] as observed; make sure all
  # missing x values really sit at the end of d.
  stopifnot(!anyNA(d$x[seq_len(n_rows - n_missing)]))
  list(
    y = d$y,
    x = d$x,
    n = n_rows,
    nmiss = n_missing
  )
})

# Model ----
# jags.model() needs the model text via a connection; close it once the model
# has been compiled so the connection is not leaked.
model_con <- textConnection(m_string)
m <- jags.model(
  file = model_con,
  data = jags_data,
  # Fixed JAGS RNG so the MCMC draws are reproducible
  inits = list(.RNG.name = "base::Mersenne-Twister", .RNG.seed = 1),
  n.chains = 1,
  n.adapt = 1000
)
close(model_con)

# Burn-in: advance the chain before collecting draws
update(m, n.iter = 1000)

# Get draws of every x[i]; observed x[i] are fixed data, missing ones are imputed
draws <- coda.samples(m, variable.names = "x", n.iter = 4000, thin = 1)

# Compute HPDIs ----
# For each x[i]: posterior mean and 95% highest-density continuous interval.
# as.matrix() stacks all chains, so this also works if n.chains > 1.
draws <- as.matrix(draws) %>%
  as.data.frame() %>%
  pivot_longer(everything(), names_to = "obs", values_to = "ximpute") %>%
  # Column names look like "x[12]"; keep only the numeric index
  mutate(obs = as.numeric(gsub("[^0-9]", "", obs))) %>%
  group_by(obs) %>%
  mean_hdci(.width = 0.95) %>%
  ungroup() %>%
  # bind_cols() below matches rows by position, so order by index explicitly
  arrange(obs) %>%
  select(ximpute, .lower, .upper)

# Add to data: row i of d corresponds to x[i] in the model
stopifnot(nrow(draws) == nrow(d))
d <- d %>% bind_cols(draws)

# Plot imputed values against true values ----
# Observed rows sit on the dashed identity line (their draws are the data);
# rows whose x was missing are shown in red with their 95% interval.
p_impute <- d %>%
  ggplot(aes(x = xtrue, y = ximpute)) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed", colour = "grey50") +
  geom_point(colour = "black") +
  geom_pointrange(
    data = d %>% filter(is.na(x)),
    aes(ymin = .lower, ymax = .upper),
    colour = "red"
  )
# Explicit print() so the plot also renders when the script is source()d
print(p_impute)

# Plot y against x with imputed x values ----
# Only observed x in the black layer (avoids "Removed rows" warnings for NA x);
# coord_flip() puts x on the horizontal axis with horizontal intervals.
p_xy <- d %>%
  filter(!is.na(x)) %>%
  ggplot(aes(x = y, y = x)) +
  geom_point() +
  geom_pointrange(
    data = d %>% filter(is.na(x)),
    aes(x = y, y = ximpute, ymin = .lower, ymax = .upper),
    colour = "red"
  ) +
  coord_flip()
print(p_xy)
