# =================================================================
# Ref: Cook RJ, Bergeron P-J, Boher J-M, Liu Y (2009). Two-stage
# design of clinical trials involving recurrent events. Statistics 
# in Medicine, 28: 2617-2638
# =================================================================


# =================================================================
# Code to calculate the sample size per group based on equation (3)
# =================================================================

# Sample size per group from the closed-form expression of equation (3).
#
# Arguments:
#   beta0  log-scale intercept; enters only through exp(beta0) and
#          exp(beta0 + beta1)
#   beta1  log-scale effect of the second arm; also used as the effect
#          size in the denominator of the formula
#   phi    the parameter phi of the formula (presumably the dispersion of
#          the count model -- confirm against the reference)
#   alp1   one-sided type 1 error rate
#   alp2   one minus the desired power
#
# Value:
#   The per-group sample size, rounded up to the next whole number.
ss.f <- function(beta0, beta1, phi, alp1, alp2) {
  # Per-arm term (1 + phi*mu) / mu evaluated at mu = exp(log.mu).
  arm.term <- function(log.mu) {
    ( 1 + (phi*exp(log.mu)) ) / exp(log.mu)
  }
  arm.sum <- arm.term(beta0 + beta1) + arm.term(beta0)

  # Standard normal quantiles (qnorm defaults: mean 0, sd 1, lower tail)
  # for the one-sided test and for the desired power.
  z.alpha <- qnorm(1 - alp1)
  z.power <- qnorm(1 - alp2)

  # Sum of the two quantiles relative to the effect size beta1.
  z.ratio <- (z.alpha + z.power) / beta1

  ceiling( arm.sum*z.ratio*z.ratio )
}


# =================================================================
# EM algorithm for interim estimation with complete blinding
# Ref: Section 3.2 
# =================================================================

# -----------------------------------------------------------------
# Computing the expectation terms
# -----------------------------------------------------------------

# Expectation terms of the EM algorithm, evaluated for every subject under
# the arm indicator supplied in x.
#
# Arguments:
#   ni     vector of counts, one per subject
#   x      vector of arm indicators (0 or 1), same length as ni
#   q      P(X = 1)
#   beta0  current value of beta[0]
#   beta1  current value of beta[1]
#   phi    current value of phi
#
# Value:
#   Data frame with one row per subject and columns
#     nu     (1 + phi*ni) / (1 + phi*exp(beta0 + beta1*x))
#     gam    digamma(ni + 1/phi) + log(phi / (1 + phi*exp(beta0 + beta1*x)))
#     zetai  probability weight for the arm given in x, built from q and
#            the two arm-specific terms; for x == 0 it is one minus the
#            x == 1 weight
Eterm.unknown.f <- function(ni, x, q, beta0, beta1, phi) {
  inv.phi <- 1/phi

  # Denominator shared by nu and gam for the arm in x.
  denom <- 1 + ( phi*exp(beta0 + (beta1*x)) )
  nu  <- ( 1 + (phi*ni) ) / denom
  gam <- digamma( ni + inv.phi ) + log( phi / denom )

  # Ratio of the x = 0 to the x = 1 term, raised to -(ni + 1/phi).
  arm.ratio <- ( 1 + ( phi*exp(beta0) ) ) / ( 1 + (phi*exp(beta0 + beta1)) )
  invterm   <- arm.ratio^( -(ni + inv.phi) )

  # Weight for x = 1; rows with x == 0 receive the complement.
  prob.one <- 1 / ( 1 + ( (1 - q)/q )*invterm*exp( -ni*beta1 ) )
  zetai    <- ifelse(x == 0, 1 - prob.one, prob.one)

  data.frame(nu, gam, zetai)
}


# -----------------------------------------------------------------
# Evaluating Q[1] w.r.t. beta[0] and beta[1]
# -----------------------------------------------------------------

# M-step for beta[0] and beta[1]: a weighted Poisson regression on stacked
# data in which every subject appears twice, once with x = 0 and once with
# x = 1.
#
# Arguments:
#   ni      vector of counts, one per subject
#   Eterm0  data frame with columns nu and zetai computed for x = 0
#   Eterm1  data frame with columns nu and zetai computed for x = 1
#
# Value:
#   One-row data frame with columns beta0 (intercept) and beta1
#   (coefficient of the arm indicator).
Q1.unknown.f <- function(ni, Eterm0, Eterm1) {
  n.subj <- length(ni)

  # Stack the x = 0 copy of the data on top of the x = 1 copy.
  events  <- rep(ni, times = 2)
  arm     <- rep(c(0, 1), each = n.subj)
  log.nu  <- log( c(Eterm0$nu, Eterm1$nu) )
  post.wt <- c(Eterm0$zetai, Eterm1$zetai)

  # log(nu) enters as an offset; zetai weights each stacked row.
  mstep <- glm(events ~ offset(log.nu) + arm, family = poisson,
               weights = post.wt)

  est <- mstep$coefficients
  data.frame(beta0 = est[1], beta1 = est[2])
}


# -----------------------------------------------------------------
# Evaluating Q[2] w.r.t. phi
# -----------------------------------------------------------------

# M-step for phi: numerically maximises Q[2] over phi by minimising its
# negative with nlm(). The search runs on p = log(phi), so the returned
# value is always positive.
#
# Arguments:
#   ini.phi  starting value for phi (the optimiser starts at log(ini.phi))
#   Eterm0   data frame with columns nu, gam and zetai computed for x = 0
#   Eterm1   data frame with columns nu, gam and zetai computed for x = 1
#
# Value:
#   The value of phi at the optimum reported by nlm().
Q2.unknown.f <- function(ini.phi, Eterm0, Eterm1) {
  # Per-subject sums of gam and nu over the two arms, weighted by zetai.
  # They do not depend on phi, so they are formed once.
  w.gam <- (Eterm0$gam*Eterm0$zetai) + (Eterm1$gam*Eterm1$zetai)
  w.nu  <- (Eterm0$nu*Eterm0$zetai)  + (Eterm1$nu*Eterm1$zetai)

  # Negative of Q[2], summed over subjects, as a function of log(phi).
  neg.Q2 <- function(log.phi) {
    phi <- exp(log.phi)
    per.subj <- ( ((1/phi) - 1)*w.gam ) - ( w.nu/phi ) - lgamma( 1/phi ) - ( log(phi)/phi )
    -sum(per.subj)
  }

  opt <- nlm(neg.Q2, p = log(ini.phi),
             gradtol = 1e-12, steptol = 1e-06, iterlim = 500)
  exp( opt$estimate )
}


# -----------------------------------------------------------------
# Performing the EM Algorithm
# -----------------------------------------------------------------

# EM algorithm for beta[0], beta[1] and phi.
#
# Each iteration computes the expectation terms for every subject under
# x = 0 and under x = 1 at the previous parameter values, then updates
# (beta0, beta1) with Q1.unknown.f and phi with Q2.unknown.f. Iteration
# stops when the largest absolute parameter change is no longer above
# 1e-06, or after 10000 iterations.
#
# Arguments:
#   indata     data frame; only the column n (counts) and the number of
#              rows are used here
#   q          P(X = 1)
#   ini.beta0  starting value for beta[0]
#   ini.beta1  starting value for beta[1]
#   ini.phi    starting value for phi; also the starting point handed to
#              the phi optimiser in every iteration
#
# Value:
#   One-row data frame with columns
#     tol   largest absolute parameter change in the last iteration
#     iter  number of iterations performed
#     b0    estimate of beta[0]
#     b1    estimate of beta[1]
#     ph    estimate of phi
em.unknown.f <- function(indata, q, ini.beta0, ini.beta1, ini.phi) {
  conv.tol <- 1e-06
  max.iter <- 10000

  # Quantities that stay fixed across iterations.
  counts <- indata$n
  n.subj <- nrow(indata)
  x.zero <- rep(0, n.subj)
  x.one  <- rep(1, n.subj)

  cur.beta0 <- ini.beta0
  cur.beta1 <- ini.beta1
  cur.phi   <- ini.phi

  iter <- 0
  repeat {
    iter <- iter + 1

    pre.beta0 <- cur.beta0
    pre.beta1 <- cur.beta1
    pre.phi   <- cur.phi

    # E-step: expectation terms for both possible arm assignments.
    Eterm0 <- Eterm.unknown.f(ni = counts, x = x.zero, q = q,
                              beta0 = pre.beta0, beta1 = pre.beta1,
                              phi = pre.phi)
    Eterm1 <- Eterm.unknown.f(ni = counts, x = x.one, q = q,
                              beta0 = pre.beta0, beta1 = pre.beta1,
                              phi = pre.phi)

    # M-step: regression coefficients first, then phi.
    beta.fit  <- Q1.unknown.f(ni = counts, Eterm0 = Eterm0, Eterm1 = Eterm1)
    cur.beta0 <- beta.fit$beta0
    cur.beta1 <- beta.fit$beta1
    cur.phi   <- Q2.unknown.f(ini.phi = ini.phi, Eterm0 = Eterm0,
                              Eterm1 = Eterm1)

    # Largest absolute change across the three parameters.
    tol <- max( abs( c(cur.beta0 - pre.beta0,
                       cur.beta1 - pre.beta1,
                       cur.phi - pre.phi) ) )

    # The iteration cap is checked before the convergence criterion.
    if ( iter >= max.iter || !(tol > conv.tol) ) { break }
  }

  data.frame(tol = tol, iter = iter, b0 = cur.beta0, b1 = cur.beta1,
             ph = cur.phi)
}


# =================================================================
# Script to run each function
# =================================================================

# Per-group sample size for beta0 = log(2*0.5), beta1 = log(0.75) and
# phi = 0.6, with a one-sided type 1 error rate of 0.025 and alp2 = 0.2
# (i.e. the normal quantile at 1 - 0.2 = 0.8 is used for the power).
m <- ss.f(beta0=log(2*0.5), beta1=log(0.75), phi=0.6, alp1=0.025, alp2=0.2)
# ss.f returns the size of one group, so the total printed is 2*m.
cat("Total Sample size, 2*m = ", 2*m, "\n")


# Input Parameters for em.unknown.f function
# indata     Data frame with two columns: id and n
#            (em.unknown.f itself reads only column n and the row count)
# q          P(X = 1) 
# ini.beta0  Starting value for beta[0]
# ini.beta1  Starting value for beta[1]
# ini.phi    Starting value for phi
#
# For example,
#   indata <- data.frame(id=c(1,2,3,...), n=c(10,0,2,...))
#   q <- 0.5
#   ini.beta0 <- log(2*0.5)
#   ini.beta1 <- log(0.75)
#   ini.phi   <- 0.6
#
# NOTE: none of these five objects is assigned in the code above, so they
# must already exist in the workspace when the next line runs; otherwise
# the call stops with an "object not found" error. Be sure to assign q
# explicitly: if it is left unassigned, the name q resolves to base R's
# q() function rather than to a probability.

# Runs the EM algorithm; the result is a one-row data frame with columns
# tol, iter, b0, b1 and ph.
em <- em.unknown.f(indata, q, ini.beta0, ini.beta1, ini.phi)