Variational Inference of Joint Models using Multivariate Gaussian Convolution Processes

We present a non-parametric prognostic framework for individualized event prediction based on joint modeling of both longitudinal and time-to-event data. Our approach exploits a multivariate Gaussian convolution process (MGCP) to model the evolution of longitudinal signals and a Cox model to link time-to-event data with the longitudinal data modeled through the MGCP. Taking advantage of the unique structure imposed by convolved processes, we provide a variational inference framework to simultaneously estimate parameters in the joint MGCP-Cox model. This significantly reduces computational complexity and safeguards against model overfitting. Experiments on synthetic and real-world data show that the proposed framework outperforms state-of-the-art approaches built on two-stage inference and strong parametric assumptions.


Introduction
In recent years, the multivariate Gaussian process (MGP) has drawn significant attention as an efficient nonparametric approach to predict longitudinal signal trajectories (Dürichen et al., 2015; Moreno-Muñoz et al., 2018; Kontar et al., 2018b). The MGP draws its roots from multitask learning, where transfer of knowledge is achieved through a shared representation between training and testing signals. One neat approach that achieves this knowledge transfer employs convolution processes to construct the MGP. Specifically, each signal is expressed as a convolution of latent functions drawn from a Gaussian process (GP). Commonalities amongst training and testing signals are then captured by sharing these latent functions across the outputs (Álvarez et al., 2010; Álvarez & Lawrence, 2011). Consequently, the multiple signals can be expressed as a single output from a common multivariate Gaussian convolution process (MGCP). Indeed, many recent studies have demonstrated the MGCP's ability to account for non-trivial commonalities in the data and provide accurate predictive results (Zhao & Sun, 2016; Guarnizo & Alvarez, 2018; Cheng, 2018). In this article we exploit the MGCP to explore the following question: can we use time-to-event data (also known as survival data) together with longitudinal signals to obtain a reliable event prediction? This is illustrated in Figure 1. As shown in the figure, our goal is to utilize both survival data and longitudinal signals from training units to predict the survival probability of a partially observed testing unit.
Naturally, the aforementioned question is often encountered in a wide range of applications, including: disease prognosis in clinical trials, event prediction using vital health signals from monitored patients at risk, remaining useful life estimation of operational units/machines and failure prognosis in connected manufacturing systems (e.g., nuclear power plants) (Tsiatis et al., 1995;Bycott & Taylor, 1998;Gasmi et al., 2003;Pham et al., 2012;Gao et al., 2015;Soleimani et al., 2018).
In order to link survival and longitudinal data, state-of-the-art methods have focused on joint models. The seminal work of Rizopoulos (Rizopoulos, 2011; 2012) laid a foundation for joint models where a linear mixed effects model is used to model longitudinal signals. The coefficients of the mixed model are then used in a Cox model to compute the probability of event occurrence conditioned on the observed longitudinal signals. This idea provided the basis for many extensions and applications in the literature (Crowther et al., 2012; Wu et al., 2012; Zhu et al., 2012; Crowther et al., 2013; Proust-Lima et al., 2014; He et al., 2015; Rizopoulos et al., 2017; Mauff et al., 2018). It is important to note here that joint methods are in general built using a two-stage inference procedure. In two-stage inference, features from the longitudinal data are first learned; these estimated features are then inserted into a survival model to predict event probabilities. Indeed, many papers have shown that this two-stage procedure can produce competitive predictive results (Wulfsohn & Tsiatis, 1997; Yu et al., 2004; Zhou et al., 2014; Mauff et al., 2018). Nevertheless, the foregoing works are based on strong parametric assumptions where signals are assumed to follow a specific parametric form and all the signals (training and testing) exhibit that same functional form. In other words, signals behave according to a similar trend but at different rates (i.e., different parameter values). However, parametric methods are restrictive in many applications, and if the specified form is far from the truth, predictive results will be misleading. Furthermore, the assumption that all signals possess the same functional form may not hold in real-life applications. For instance, units operated under different environmental conditions may exhibit different signal evolution rates and trends (Yan et al., 2016; Kontar et al., 2018a).
Some recent efforts aimed to relax strong parametric assumptions using splines, continuous time Markov chains and the GP. Unfortunately, these methods still assume homogeneity across the population and focus on merely imputing the longitudinal data rather than predicting signal evolution within a time interval of interest (Dempsey et al., 2017; Soleimani et al., 2018). We note here that there have been some recent attempts at rebuilding the Cox model using a GP (Fernández et al., 2016; Kim & Pavlovic, 2018). However, these approaches are based only on survival data and do not handle joint modeling, which is the focus of this article.
To address the aforementioned challenges, we propose a flexible joint modeling approach denoted as MGCP-Cox. Our approach exploits the MGCP to model the evolution of longitudinal signals and a Cox model to link time-to-event data with the longitudinal data modeled through the MGCP. Event occurrence probability is then derived within any future interval ∆t as shown in Figure 1. We also propose a variational inference framework using pseudo-inputs (Snelson & Ghahramani, 2006; Damianou & Lawrence, 2013) to simultaneously estimate parameters in the joint MGCP-Cox model. This facilitates scalability to large data settings and safeguards against model overfitting. Finally, the advantageous features of the proposed method are demonstrated through numerical studies and a case study with real-world data on estimating the remaining useful lifetime of NASA aero-propulsion engines.
The rest of the paper is organized as follows. In section 2 we briefly review survival analysis. In section 3, we present our joint modeling framework and the variational inference algorithm. Numerical experiments using synthetic data and real-world data are provided in section 4. Finally, section 5 concludes the paper with a brief discussion. Detailed code and the real-world data used are available in the supplementary materials.

Background: Survival Analysis
In this section, we will briefly review survival analysis which will be used for event prediction in the joint model. Survival analysis is a branch of statistics for analyzing time-to-event data and predicting the probability of occurrence of an event.
For each individual unit i, the associated data is $D_i = \{V_i, \delta_i, Y_i, w_i\}$, where $V_i$ is the observed event time (the failure time $T_i$ or the censoring time, whichever occurs first), $\delta_i \in \{0, 1\}$ is an event indicator ($\delta_i = 1$ indicates that the unit has failed and $\delta_i = 0$ that it is censored), $Y_i$ are the noisy observed longitudinal data (e.g., vital signals collected from patients) corresponding to the underlying latent values $f_i$, and $w_i$ is a set of time-invariant features (e.g., patient's gender). Typically, the continuous random variable $T_i$ is characterized by a survival function $S(t) = P(T \ge t)$, which represents the probability of survival up to time t. Another important term is the hazard function $h(t) = \lim_{\Delta \to 0} \frac{1}{\Delta} P(t < T \le t + \Delta \mid T \ge t) = -\frac{d}{dt} \log S(t)$, which can be thought of as the instantaneous rate of occurrence of an event at time t. It is easy to show that $S(t) = \exp\{-\int_0^t h(u)\,du\}$. The term $\int_0^t h(u)\,du$ is called the cumulative hazard function and is denoted by $H(t)$. The basic scheme of survival analysis is to find suitable models to explain relationships between the hazard function $h_i(t)$ and the collected data $D_i$. These models are defined as survival models.
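The relation between the hazard, the cumulative hazard and the survival function can be checked numerically. The following is a minimal sketch, assuming an illustrative Weibull hazard whose parameter values (`lam`, `rho`) are chosen here for demonstration only and are not taken from the paper; the helper names are likewise hypothetical.

```python
import math

def cumulative_hazard(hazard, t, steps=2000):
    """Approximate H(t) = integral_0^t h(u) du with the midpoint rule."""
    if t <= 0.0:
        return 0.0
    width = t / steps
    return sum(hazard((k + 0.5) * width) for k in range(steps)) * width

def survival(hazard, t, steps=2000):
    """S(t) = exp(-H(t))."""
    return math.exp(-cumulative_hazard(hazard, t, steps))

# Illustrative Weibull hazard h(t) = lam * rho * t**(rho - 1); values are assumptions.
lam, rho = 0.01, 1.5
weibull_hazard = lambda t: lam * rho * t ** (rho - 1.0)

t = 10.0
numeric = survival(weibull_hazard, t)
closed_form = math.exp(-lam * t ** rho)  # for this hazard, H(t) = lam * t**rho
print(numeric, closed_form)
```

For this hazard the cumulative hazard is available in closed form, so the two printed values should agree closely; for a hazard without a closed form only the numerical route is available.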
Many survival models have been developed to analyze time-to-event data. They typically model the hazard function as a function of some time-varying and fixed features. One prevalent class of survival models is the Cox model (Cox, 1972), which has the form $h_i(t) = h_0(t) \exp\{\gamma^T w_i + \beta f_i(t)\}$, where $h_0(t)$ is a baseline hazard function shared by all individuals, typically modeled by the Weibull or a piecewise constant function, $\gamma$ is a vector of coefficients for the fixed covariates (features), $f_i(t)$ is the feature estimated by a longitudinal model (e.g., linear mixed model, Gaussian process), and $\beta$ is a scaling parameter for the time-varying covariates. Parameters in the Cox model are typically estimated by maximizing the log-likelihood $\sum_i [\delta_i \log h_i(V_i) - \int_0^{V_i} h_i(u)\,du]$ (1). For a comprehensive review of survival models, see Kalbfleisch & Prentice (2011). Given an estimate of parameters from the Cox model, we can then obtain the event (failure) probability within a future time interval ∆t given that the testing unit i has survived at least up to the current time instance $t^*$. This probability, denoted $\hat{P}_{\Delta t}$, is estimated as $\hat{P}_{\Delta t} = 1 - \exp\{-\int_{t^*}^{t^* + \Delta t} h_0(u) \exp\{\gamma^T w_i + \beta f_i(u)\}\,du\}$ (2), where $w_i$ and $f_i$ are features for a testing unit i. Note that in Figure 1 we show the survival curve, which is defined as $\hat{S}(t \mid t^*) = 1 - \hat{P}_{\Delta t}$, where $t = t^* + \Delta t$.
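As an illustration of how a fitted Cox model turns into an event probability over a future window, the sketch below assumes the standard conditional form $1 - \exp\{-\int_{t^*}^{t^*+\Delta t} h_i(u)\,du\}$ and evaluates the integral with a midpoint rule. The function names, the constant baseline hazard and all numeric values are hypothetical and serve only as an example.

```python
import math

def cox_hazard(t, w, f, gamma, beta, h0):
    """Cox hazard h_i(t) = h0(t) * exp(gamma . w + beta * f(t))."""
    linear = sum(g * x for g, x in zip(gamma, w)) + beta * f(t)
    return h0(t) * math.exp(linear)

def event_probability(t_star, delta_t, w, f, gamma, beta, h0, steps=1000):
    """P(T <= t* + dt | T >= t*) = 1 - exp(-integral of h_i over [t*, t* + dt])."""
    width = delta_t / steps
    integral = sum(
        cox_hazard(t_star + (k + 0.5) * width, w, f, gamma, beta, h0)
        for k in range(steps)
    ) * width
    return 1.0 - math.exp(-integral)

# Hypothetical inputs: constant baseline hazard and a slowly rising signal.
prob = event_probability(
    t_star=10.0, delta_t=12.0, w=[1.0],
    f=lambda t: 0.05 * t, gamma=[0.0], beta=0.5, h0=lambda t: 0.02,
)
print(prob)
```

In a joint model the callable `f` would be replaced by the predicted longitudinal trajectory of the testing unit, so the quality of that extrapolation directly drives the event probability.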

The multivariate Gaussian convolution process (MGCP)
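Assume data have been collected from N units and let $Y_i = [y_i(t_{i1}), \ldots, y_i(t_{il_i})]^T$ denote the observations from unit i, where $l_i$ represents the number of observations and $\{t_{ir} : r = 1, \ldots, l_i\}$ denotes the inputs. We decompose the longitudinal signal as $y_i(t) = f_i(t) + \epsilon_i(t)$, where $f_i(t)$ is a mean zero GP and $\epsilon_i(t)$ denotes additive noise with zero mean and variance $\sigma^2$.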
To obtain an accurate predictive result, we need to capture the intrinsic relatedness among the N signals. Particularly, we resort to the convolution process to model the latent function $f_i(t)$. We consider K independent latent functions $\{X_k(t)\}_{k=1}^K$ and $N \times K$ different smoothing kernels $\{G_{i,k}(t)\}$, $i = 1, \ldots, N$, $k = 1, \ldots, K$. The latent functions are assumed to be independent GPs with covariance $\mathrm{cov}[X_k(t), X_k(t')] = \kappa_k(t, t')$. We set $G_{i,k}(t) = \alpha_{i,k}\,\mathcal{N}(t; 0, \xi_{i,k}^2)$ to be scaled Gaussian kernels and $\kappa_k(t, t')$ to be squared exponential covariance functions (Álvarez & Lawrence, 2009).
The GP $f_i(t)$ is then constructed by convolving the shared latent functions with the smoothing kernels, $f_i(t) = \sum_{k=1}^K \int_{-\infty}^{\infty} G_{i,k}(t - u) X_k(u)\,du$ (4). This is the underlying principle of the MGCP, where the latent functions $\{X_k(t)\}_{k=1}^K$ are shared across different outputs through the corresponding kernels $G_{i,k}(t)$. Since convolution is a linear operator and the latent functions, which are GPs, are shared across multiple outputs, all outputs can be expressed as a jointly distributed GP, an MGCP. As shown in Figure 2, a key feature is that information is shared through different parameters encoded in the kernels $G_{i,k}(t)$. Outputs can then possess both shared and unique features, which accounts for heterogeneity in the longitudinal data.
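To make the convolution construction concrete, the following sketch approximates the cross-covariance between two outputs that share a single latent GP (K = 1) by a Riemann sum over the double convolution integral. The kernel parameters, length scale and integration grid are assumptions picked for illustration, and the helper names are hypothetical. For scaled Gaussian kernels with a squared exponential latent covariance the integral is also a Gaussian in t − t', which gives an independent check.

```python
import math

def gaussian_kernel(t, alpha, xi):
    """Scaled Gaussian smoothing kernel G(t) = alpha * N(t; 0, xi^2)."""
    return alpha * math.exp(-0.5 * (t / xi) ** 2) / (xi * math.sqrt(2.0 * math.pi))

def se_cov(u, v, length):
    """Squared exponential covariance of the latent GP."""
    return math.exp(-0.5 * ((u - v) / length) ** 2)

def conv_cov(t, t_prime, kern_i, kern_j, length, lo=-6.0, hi=8.0, n=281):
    """Riemann-sum approximation of cov[f_i(t), f_j(t')] with one shared latent GP."""
    step = (hi - lo) / (n - 1)
    grid = [lo + k * step for k in range(n)]
    gi = [gaussian_kernel(t - u, *kern_i) for u in grid]
    gj = [gaussian_kernel(t_prime - v, *kern_j) for v in grid]
    total = 0.0
    for a, u in zip(gi, grid):
        if a < 1e-12:  # skip negligible kernel weights to save time
            continue
        for b, v in zip(gj, grid):
            total += a * b * se_cov(u, v, length)
    return total * step * step

def closed_form_cov(t, t_prime, kern_i, kern_j, length):
    """Gaussian closed form of the same double integral."""
    (a_i, xi_i), (a_j, xi_j) = kern_i, kern_j
    total_var = length ** 2 + xi_i ** 2 + xi_j ** 2
    scale = a_i * a_j * length / math.sqrt(total_var)
    return scale * math.exp(-0.5 * (t - t_prime) ** 2 / total_var)

# Hypothetical (alpha, xi) pairs for two units sharing one latent function.
kern_1, kern_2 = (1.0, 0.5), (0.7, 0.8)
print(conv_cov(1.0, 1.5, kern_1, kern_2, 1.0))
print(closed_form_cov(1.0, 1.5, kern_1, kern_2, 1.0))
```

The two units have different kernel parameters yet remain correlated through the shared latent function, which is the mechanism that lets outputs carry both common and unit-specific features.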
Based on equation (4), the covariance function between $f_i$ and $f_j$, and the covariance function between $f_i$ and $X_k$, can be calculated as $\mathrm{cov}[f_i(t), f_j(t')] = \sum_{k=1}^K \int\!\!\int G_{i,k}(t - u)\, G_{j,k}(t' - u')\, \kappa_k(u, u')\,du\,du'$ and $\mathrm{cov}[f_i(t), X_k(t')] = \int G_{i,k}(t - u)\, \kappa_k(u, t')\,du$, respectively. To alleviate the computational burden, we introduce M pseudo-inputs from the latent functions, denoted as $Z = \{z_m\}_{m=1}^M$. Since the latent functions are GPs, any sample $X_k(Z)$ follows a multivariate Gaussian distribution. Conditioned on $X_k(Z)$, we next sample from the conditional prior $p(X_k(u) \mid X_k(Z))$. In equation (4), $X_k(u)$ can be approximated well by the expectation $E(X_k(u) \mid X_k(Z))$ as long as the latent functions are smooth (Álvarez & Lawrence, 2011).
The probability distribution of X can be expressed as $p(X \mid Z) = \mathcal{N}(0, K_{X,X})$, where $K_{X,X}$ is a block-diagonal matrix such that each block is associated with the covariance of $X_k$ in (3). By multivariate Gaussian identities, the probability distribution of f conditional on X, Z is $p(f \mid X, Z) = \mathcal{N}(K_{f,X} K_{X,X}^{-1} X,\; K_{f,f} - Q)$, where $Q = K_{f,X} K_{X,X}^{-1} K_{X,f}$. Therefore, $p(f)$ can be approximated by $p(f \mid Z)$, which is given as $p(f \mid Z) = \int p(f \mid X, Z)\, p(X \mid Z)\,dX$ (7). By equation (7), $p(Y)$ can then be approximated by $p(Y \mid Z) = \int\!\!\int p(Y \mid f)\, p(f \mid X, Z)\, p(X \mid Z)\,df\,dX$.

Joint Model and Variational Inference
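Now, following our convolution construction in (4), the hazard function at time t is given as $h_i(t) = h_0(t) \exp\{\gamma^T w_i + \beta \sum_{k=1}^K \int G_{i,k}(t - u) X_k(u)\,du\}$. This key equation links the MGCP to the Cox model. We begin with the log-likelihood of the joint model, $\log p(D)$. We would like to provide a good approximation of $\log p(D)$ by introducing an evidence lower bound (ELBO) $L$. This bound is obtained from the Kullback-Leibler (KL) divergence between the variational density $q(f, X \mid Z)$ and the true posterior density $p(f, X \mid D, Z)$. Specifically, $\log p(D) = L + \mathrm{KL}(q(f, X \mid Z) \,\|\, p(f, X \mid D, Z))$, so that $L \le \log p(D)$ because the KL divergence is non-negative. The variational density is assumed to be factorized as $q(f, X \mid Z) = p(f \mid X, Z)\, q(X)$. Maximizing the ELBO with respect to $q(X)$ and the hyperparameters of the MGCP-Cox model achieves variational inference and model selection simultaneously (Kim & Pavlovic, 2018). By equation (10), $L = \int\!\!\int p(f \mid X, Z)\, q(X) \log \frac{p(D \mid f)\, p(X \mid Z)}{q(X)}\,df\,dX$. Furthermore, we can decompose $\log p(D \mid f)$ into a longitudinal term $\log p(Y \mid f)$ and a survival term given by the Cox log-likelihood, leading to the ELBO in equation (12). Based on equation (12), the MGCP propagates uncertainties through the latent processes to the Cox model.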
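It is desirable to find a closed form of the ELBO in equation (12). Since $p(Y \mid f)$ and $p(f \mid X, Z)$ are both Gaussian, we can obtain $\int p(f \mid X, Z) \log p(Y \mid f)\,df = \log \mathcal{N}(Y \mid K_{f,X} K_{X,X}^{-1} X,\; \sigma^2 I) - \frac{1}{2\sigma^2} \mathrm{Tr}(K_{f,f} - Q)$, where $K_{f,X} K_{X,X}^{-1} X$ is the conditional mean of f given X and $\mathrm{Tr}(\cdot)$ is the trace operator. Therefore, the ELBO can be simplified as in equation (14). We compute the optimal upper bound of $L$ by reversing Jensen's inequality. This gives an optimal distribution $q_1^*(X)$ and the bound in equation (15), where $PE = -\frac{1}{2\sigma^2} \mathrm{Tr}(K_{f,f} - Q)$. $PE$ can be thought of as a penalization term that regularizes the estimation of the parameters. The first two terms in equation (15) are available in closed form; a solution for the last integration in equation (15) is presented in the following section.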
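The trace penalty comes from a standard Gaussian identity: for a latent value $f \sim \mathcal{N}(a, b)$ and a Gaussian likelihood with noise variance $\sigma^2$, the expected log-likelihood equals the log-likelihood at the mean minus $b / (2\sigma^2)$. The sketch below checks the scalar version of this identity by numerical integration; the scalar `b` plays the role of $K_{f,f} - Q$, and all numbers and helper names are illustrative assumptions.

```python
import math

def log_normal_pdf(y, mean, var):
    """Log density of N(y; mean, var)."""
    return -0.5 * math.log(2.0 * math.pi * var) - 0.5 * (y - mean) ** 2 / var

def expected_log_lik(y, a, b, noise_var, n=4001, width=10.0):
    """Numerically integrate N(f; a, b) * log N(y; f, noise_var) over f."""
    sd = math.sqrt(b)
    lo = a - width * sd
    step = 2.0 * width * sd / (n - 1)
    total = 0.0
    for k in range(n):
        f = lo + k * step
        total += math.exp(log_normal_pdf(f, a, b)) * log_normal_pdf(y, f, noise_var)
    return total * step

y, a, b, noise_var = 1.3, 0.4, 0.25, 0.5
numeric = expected_log_lik(y, a, b, noise_var)
# Closed form: log-likelihood at the mean plus the (negative) variance penalty.
penalty = -b / (2.0 * noise_var)
closed_form = log_normal_pdf(y, a, noise_var) + penalty
print(numeric, closed_form)
```

The penalty grows with the conditional variance that the pseudo-inputs fail to explain, which is why it acts as a regularizer on the placement of pseudo-inputs and on the kernel parameters.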

Variational Inference on Cox Model
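Parameters in the Cox model can be attained by maximizing the log-likelihood function $\sum_i [\delta_i \log h_i(V_i) - \int_0^{V_i} h_i(u)\,du]$, where $h_i(t)$ is the MGCP-based hazard. In equation (15), we obtain the optimal $q_1^*(X)$ to maximize the ELBO. In this section, we use $q_1^*(X)$ to approximate $q_2(X)$. Specifically, the optimal $q_1^*(X)$ is Gaussian, $q_1^*(X) = \mathcal{N}(m, s)$, with mean m and covariance s. It is easy to show that $q(f \mid Z)$ has a normal distribution with parameters $\mu$, $\Sigma$. Specifically, $\mu = K_{f,X} K_{X,X}^{-1} m$ and $\Sigma = K_{f,f} - K_{f,X} K_{X,X}^{-1} (I - s K_{X,X}^{-1}) K_{X,f}$. The last integration in equation (15) can then be simplified to the expectation of the Cox log-likelihood under $q(f \mid Z)$, namely $\sum_i \delta_i\, E_q[\log h_i(V_i)] - \sum_i E_q[\int_0^{V_i} h_i(u)\,du]$ (19).
The first term in equation (19) can be calculated analytically. For each unit i, $\delta_i\, E_q[\log h_i(V_i)] = \delta_i [\log h_0(V_i) + \gamma^T w_i + \beta \mu_i(V_i)]$, where in the last step Fubini's theorem is applied to interchange integrals. The second term in equation (19) can be estimated by the moment generating function (MGF) and numerical integration. For each unit i, $E_q[\int_0^{V_i} h_i(u)\,du] = \int_0^{V_i} h_0(u) \exp\{\gamma^T w_i + \beta \mu_i(u) + \frac{1}{2} \beta^2 \sigma_i^2(u)\}\,du$, where $\mu_i(u) := K_{f_i(u),X} K_{X,X}^{-1} m$ and $\sigma_i^2(u) := K_{f_i(u),f_i(u)} - K_{f_i(u),X} K_{X,X}^{-1} (I - s K_{X,X}^{-1}) K_{X,f_i(u)}$. We can assume $h_0(t)$ to be an exponential function of t, where $b$, $\psi$ are parameters to be learned, and $h_0(t) = 0$ when $t < \min\{V_i\}_{i=1}^N$ because units are not subject to risk before the first failure event. Note that the baseline hazard is non-decreasing with time and thus we add one constraint $\psi \in \mathbb{R}^+$. To obtain a good baseline hazard prediction given the estimated $\hat{b}$, $\hat{\psi}$, we can calculate the cumulative hazard at time point t as $H(t) = \sum_{u \in \mathcal{F}(t)} \hat{h}_0(u)$, $\forall t$, where the set $\mathcal{F}(t)$ collects the time points up to t over which the estimated hazard is accumulated. Then we fit a regularized smooth spline $\hat{H}(t)$ to $H(t)$ (Ruppert, 2002). The predicted baseline hazard at $u \in [t^*, t^* + \Delta t]$ can be estimated by $\frac{d\hat{H}(t)}{dt}\big|_{t=u}$ (Rosenberg, 1995).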
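The moment generating function step can be sketched as follows. For a Gaussian latent value $f_i(u) \sim \mathcal{N}(\mu_i(u), \sigma_i^2(u))$, the MGF gives $E[\exp\{\beta f_i(u)\}] = \exp\{\beta \mu_i(u) + \frac{1}{2}\beta^2 \sigma_i^2(u)\}$, so the expected cumulative hazard reduces to a one-dimensional integral that can be evaluated numerically. The function name, the midpoint rule and the example inputs are assumptions made for illustration, not the authors' implementation.

```python
import math

def expected_cumulative_hazard(v, h0, mu, var, beta, gamma_w=0.0, steps=2000):
    """E[ int_0^v h0(u) exp(gamma_w + beta * f(u)) du ] for f(u) ~ N(mu(u), var(u))."""
    width = v / steps
    total = 0.0
    for k in range(steps):
        u = (k + 0.5) * width
        # Gaussian MGF: E[exp(beta * f(u))] = exp(beta * mu(u) + 0.5 * beta^2 * var(u))
        total += h0(u) * math.exp(gamma_w + beta * mu(u) + 0.5 * beta ** 2 * var(u))
    return total * width

# Hypothetical posterior mean and variance of the latent signal.
mean_fn = lambda u: 0.1 * u
var_fn = lambda u: 0.2
base = lambda u: 0.05  # constant baseline hazard, for illustration only

with_uncertainty = expected_cumulative_hazard(8.0, base, mean_fn, var_fn, beta=0.5)
plug_in = expected_cumulative_hazard(8.0, base, mean_fn, lambda u: 0.0, beta=0.5)
print(with_uncertainty, plug_in)
```

Setting the variance to zero recovers the plug-in estimate used by two-stage procedures; the difference between the two values shows how posterior uncertainty in the signal inflates the expected hazard.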
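The bound $L^*$ is maximized with respect to the parameters $\Theta$, which collect $\theta$ together with the remaining model parameters, by a gradient-based method. Specifically, we obtain the optimal parameters as $\hat{\Theta} = \arg\max_{\Theta} L^*$ subject to $\psi \ge 0$.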

Event Prediction
Without loss of generality, we focus on predicting the event occurrence probability for unit N. Suppose observations from the testing unit N have been collected up to time $t^*$. The survival model computes the event probabilities conditioned on the predicted longitudinal features $f_N(u)$, $u \in [t^*, t^* + \Delta t]$. Given estimated parameters, and following (2), we are interested in calculating $\hat{P}_{\Delta t} = 1 - \exp\{-\int_{t^*}^{t^* + \Delta t} \hat{h}_0(u) \exp\{\hat{\gamma}^T w_N + \hat{\beta} \hat{f}_N(u)\}\,du\}$ (22). Based on equation (22), accurate extrapolation within ∆t is essential. In the MGCP, the predictive distribution for any new input point $T^*$ is Gaussian, and we use $K_{f_N(T^*),f_N(T^*)}$ as notation for the covariance matrix evaluated at $T^*$. Consequently, the predicted signal at the time point $T^*$ for unit N is $\hat{f}_N(T^*) = A D^{-1} Y$.

Experiments
We conduct case studies to demonstrate the performance of our proposed methodology, denoted as MGCP-Cox. Both synthetic and real-world data are used. We also provide an illustrative example in Figure 3 to demonstrate the benefits of the MGCP-Cox model.

Data setting
For the synthetic data we assume that the underlying true path $y_i(t)$ for unit i follows a parametric form governed by a parameter $a \sim \mathrm{Uniform}(0.003, 0.03)$. Without loss of generality, we assume that time is measured in months and that signals were obtained regularly each month up to their failure or censoring time. An example of the signals is shown in the top row of Figure 3. For each unit we specify a time-invariant feature $w_i \in \{0, 1\}$ generated by a Bernoulli distribution with p = 0.5. In the Cox model, we use the Weibull baseline hazard rate function $h_0(t) = \lambda \rho t^{\rho - 1}$ with λ = 0.001 and ρ = 1.05. We generate the failure time $T_i$ for each unit by rejection sampling using its probability density function $h_i(t) S_i(t)$. We set γ = 0 and β = 0.5. Also, we randomly select 5% of the units to be right censored. The number of units generated is N = 20 and the experiment is repeated Q = 100 times. Detailed code for data generation is provided in the supplementary materials.
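The rejection sampling step can be illustrated with the sketch below, which tabulates the density $h(t) S(t)$ on a grid and uses a uniform proposal. This is not the authors' supplementary code: the linear signal path, the truncation point `t_max`, the grid size and the function names are all assumptions introduced for the example; only the Weibull parameters and β follow the values quoted above.

```python
import math
import random

def tabulate_density(hazard, t_max, n=2000):
    """Tabulate p(t) = h(t) * S(t) at the right edge of n equal cells on (0, t_max]."""
    step = t_max / n
    density, cum_hazard = [], 0.0
    for k in range(n):
        cum_hazard += hazard((k + 0.5) * step) * step  # midpoint rule for H(t)
        t = (k + 1) * step
        density.append(hazard(t) * math.exp(-cum_hazard))
    return density, step

def sample_failure_times(hazard, t_max, size, rng, n=2000):
    """Rejection sampling with a uniform proposal on [0, t_max]."""
    density, step = tabulate_density(hazard, t_max, n)
    bound = 1.05 * max(density)  # envelope height with a small safety margin
    draws = []
    while len(draws) < size:
        t = rng.uniform(0.0, t_max)
        cell = min(int(t / step), n - 1)
        if rng.uniform(0.0, bound) <= density[cell]:
            draws.append(t)
    return draws

# Weibull baseline with the values quoted in the text; the signal path is hypothetical.
lam, rho, beta = 0.001, 1.05, 0.5
signal = lambda t: 0.1 * t  # stand-in for the latent path f_i(t)
hazard = lambda t: lam * rho * t ** (rho - 1.0) * math.exp(beta * signal(t))

rng = random.Random(0)
times = sample_failure_times(hazard, t_max=200.0, size=20, rng=rng)
print(sorted(times)[:5])
```

Because the density is tabulated, the sampler draws from a piecewise-constant approximation truncated at `t_max`; the truncation point should be chosen so that the survival probability beyond it is negligible.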
For the real-world case study we use the C-MAPSS dataset provided by the National Aeronautics and Space Administration (NASA). The dataset contains failure time data of aircraft turbofan engines and degradation signals from informative sensors mounted on these engines. Note that in our analysis we standardize all sensor data. We refer readers to Saxena & Goebel (2008) and Liu et al. (2013) for more details about the data. The C-MAPSS data is publicly available at: https://ti.arc.nasa.gov/tech/dash/groups/pcoe/prognostic-data-repository/.

Baselines and Evaluations
We focus on predicting the event probability within a future time interval ∆t. We consider ∆t = 12, 15, 20 months in this simulation study. Prediction performance at varying time points $t^*$ for the partially observed unit N is then reported. The time instant $t^* = \alpha T_N$ is defined as the α-observation percentile, where $T_N$ is the failure time of unit N. The values of α are specified as 30% and 50%. We compare against three baselines. (1) LR classifier. (2) SVM classifier: the SVM here is used as a flexible alternative to the LR. We use the radial basis kernel and determine parameters using 2-fold cross-validation on the training data. (3) Parametric Joint Model (LMM-Joint): we implement a state-of-the-art joint modeling algorithm using the linear mixed-effect model. The LMM-Joint uses a general polynomial function, whose degree is determined through the Akaike information criterion, to model the signal path. Note that this framework estimates parameters from the mixed-effect model and the Cox model separately (Rizopoulos, 2011; Zhou et al., 2014; Mauff et al., 2018). Regarding our MGCP-Cox model, we set the number of pseudo-inputs to M = 10 and the number of latent functions to K = 1. This is a commonly used setting for the MGCP (Álvarez & Lawrence, 2011; Zhao & Sun, 2016). The performance of each method is then assessed by the Receiver Operating Characteristic (ROC) curve, which is a common diagnostic tool for binary classifiers. The ROC curve is created by plotting the true positive rate (TPR) against the false positive rate (FPR). Predictive accuracy is then assessed through the area under the curve (AUC). The results from the synthetic data are shown in Figure 4. Due to the poor performance of both the LR and SVM with N = 20, we also checked whether they can produce results comparable to the MGCP-Cox when N = 200. We denote those models as LR-200 and SVM-200.
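For readers who want to reproduce the AUC metric without a plotting library, the sketch below computes it through its rank interpretation: the AUC equals the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case, with ties counted as one half. The labels and scores are made-up values for illustration, and the function name is hypothetical.

```python
def roc_auc(labels, scores):
    """Area under the ROC curve via pairwise comparison of positive and negative scores."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("both classes must be present to compute the AUC")
    wins = 0.0
    for p in positives:
        for q in negatives:
            if p > q:
                wins += 1.0   # positive ranked above negative
            elif p == q:
                wins += 0.5   # ties count as half
    return wins / (len(positives) * len(negatives))

# Made-up example: 1 = event occurred within the window, scores = predicted probabilities.
labels = [1, 1, 0, 1, 0, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.3, 0.2]
print(roc_auc(labels, scores))
```

In the setting above, the scores would be the predicted event probabilities within ∆t and the labels would indicate whether the event actually occurred in that window.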
For the real data, the true survival probabilities are not available since we do not have information about the underlying parameters used to generate the data. Therefore, to evaluate model performance, we calculate the mean remaining lifetime of the testing unit, which is defined as $\mathrm{mrl}(t^*) = \int_{t^*}^{\infty} \hat{S}(u \mid t^*, w_N, f_N)\,du$. This integration can be obtained by Gauss-Legendre quadrature. The performance is assessed by the absolute error $AE = |rl_j - \mathrm{mrl}_j|$, where $rl_j$ is the true remaining lifetime of testing unit j. We then report the distribution of the errors across all units using the boxplot in Figure 5. Similar to the synthetic data, we use the 30% and 50% percentiles to assess prediction accuracy. We also note that we cannot obtain mrl estimates from the SVM and LR as they transform event prediction into a time series classification problem. Therefore, only results from LMM-Joint and MGCP-Cox are reported in Figure 5.
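A minimal sketch of the mean remaining lifetime computation is given below. It applies a composite three-point Gauss-Legendre rule and truncates the infinite upper limit at a finite horizon, which is an assumption that is reasonable only when the conditional survival function has decayed to nearly zero by then. The exponential survival curve, the horizon and the true remaining lifetime are placeholders for illustration.

```python
import math

# Three-point Gauss-Legendre nodes and weights on [-1, 1].
_NODES = (-math.sqrt(3.0 / 5.0), 0.0, math.sqrt(3.0 / 5.0))
_WEIGHTS = (5.0 / 9.0, 8.0 / 9.0, 5.0 / 9.0)

def mean_residual_life(cond_survival, t_star, horizon, panels=200):
    """Approximate int_{t*}^{t*+horizon} S(u | t*) du by composite Gauss-Legendre."""
    width = horizon / panels
    total = 0.0
    for k in range(panels):
        centre = t_star + (k + 0.5) * width
        for node, weight in zip(_NODES, _WEIGHTS):
            total += weight * cond_survival(centre + 0.5 * width * node)
    return 0.5 * width * total

# Placeholder conditional survival curve: constant hazard 0.1 after t* = 5.
t_star = 5.0
cond_survival = lambda u: math.exp(-0.1 * (u - t_star))

mrl = mean_residual_life(cond_survival, t_star, horizon=200.0)
true_remaining_life = 12.0  # hypothetical ground truth for one testing unit
absolute_error = abs(true_remaining_life - mrl)
print(mrl, absolute_error)
```

For the constant-hazard placeholder the exact answer is the reciprocal of the hazard, which makes it easy to confirm that the truncation horizon and the number of panels are adequate.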

Results
The illustrative example in Figure 3 demonstrates the behavior of our method. As shown in the figure, our joint model framework can provide accurate predictions of both longitudinal signals and event probabilities. The unique smoothing kernel $G_{i,k}$ for each individual allows flexibility in the prediction, since it enables each training signal to have its own characteristics. This substantiates the strength of the MGCP. Equipped with the shared latent processes, the model can infer the similarities among all units and exploit them for prediction. Based on the reported results, we can obtain some important insights. First, as expected, prediction errors significantly decrease as the lifetime percentiles increase. Thus, predictions from the MGCP-Cox model become more accurate as $t^*$ increases and more data are collected from an online monitored unit. Second, the prediction accuracy slightly decreases as we predict over a longer horizon (i.e., prediction is better for the near future). This is intuitively understandable, as accuracy might decrease when predicting over a large region where few training data are observed. Third, the results show that the MGCP-Cox clearly outperforms LMM-Joint. This result highlights the danger of parametric modeling and demonstrates the ability of our non-parametric approach to avoid model misspecification. Fourth, even when the LR and SVM had a much larger number of units, the MGCP-Cox was still superior. This observation, which also holds for the LMM-Joint, highlights the strength of joint models. Lastly, one striking feature, shown in Figures 3, 4 and 5, is that even with a small number of observations (30% observation percentile) from the testing unit we were still able to get accurate predictive results. This is crucial in many applications as it allows early prediction of an event occurrence such as a disease or machine failure.

Conclusion
We have presented a flexible and efficient non-parametric joint modeling framework for longitudinal and time-to-event data. A variational inference framework using pseudo-inputs is established to jointly estimate parameters from the MGCP-Cox model. Empirical studies highlight the advantageous features of our model in predicting signal trajectories and providing reliable event prediction.