Factorization Machines as a Tool for Healthcare: Case Study on Type 2 Diabetes Detection

Ioakeim Perros and Jimeng Sun
Georgia Institute of Technology, Atlanta, U.S.A.
perros@gatech.edu, jsun@cc.gatech.edu

Abstract. Exploiting high-dimensional data for prediction purposes has been one of the most attractive targets in Healthcare Analytics. However, reliable model learning under severe sparsity is challenging for state-of-the-art methods, such as Support Vector Machines (SVMs). In this work, we propose the application of a recently developed model class, called Factorization Machines (FMs), to Electronic Health Record (EHR) data. The purpose of this analysis is to detect patients diagnosed with Type 2 Diabetes Mellitus (T2DM). The wide range of comorbidities associated with T2DM and its prevalence in a large portion of the population reflect the task's difficulty and importance. Our experiments on publicly available EHR data confirm the suitability of FMs for Healthcare Analytics, and the insights we provide serve as a stepping stone for future work.

Keywords: Factorization Machines, Classification, Health Analytics

1 Introduction

The prospect of applying Predictive Modeling to Healthcare data has been acclaimed and established by numerous success stories [6,12]. One of the most popular applications is the detection of a specific disease in a population [5], given its characteristics, ranging from patients' clinical status to their demographic and genetic information. However, this task is far from reaching its full potential, due to the heterogeneous and noisy nature of the underlying features captured, as well as the complex relations emerging between the risk factors and a specific disease. Interestingly, carefully chosen additional features can significantly improve a classifier's potential [12]. As a result, a fundamental challenge lies in leveraging multiple features in a data-driven fashion while capturing their complex interactions. In many cases, however, as the number of incorporated features increases, so does the data sparsity, which poses yet another challenge even for state-of-the-art classification approaches, such as Support Vector Machines (SVMs).

In this work, we propose the application of Factorization Machines (FMs) [10] in order to deal with the inherent sparsity of multi-feature Electronic Health Record (EHR) data. The FM model class is mostly known for its success in Collaborative Filtering settings [7,11], such as the rating prediction problem in movie recommendation systems, where the underlying dataset is very sparse. The prediction goal we are targeting is the detection of Type 2 Diabetes Mellitus (T2DM). This task has attracted increasing interest [1,2,8], due to the disease's impact on a large share of the global population and due to the fact that early detection can significantly reduce diabetes complications [4].

In the remainder of this paper, we present the Factorization Machines (FM) framework as well as other related approaches in Section 2. In Section 3, we describe our experiments and the observations based on them. Finally, Section 4 provides a summary of this work.

2 Approaches

In this Section, we briefly review several model learning methods. This will facilitate the formulation of the FM framework, which follows, and will highlight its differences from other well-known model learning approaches.

2.1 Polynomial Regression

We first state the classic polynomial regression model and stress its connections to the model class of FMs. The model equation of 2nd-degree (d = 2) polynomial regression for an input vector $\mathbf{x} \in \mathbb{R}^p$ is:

$$y(\mathbf{x}) = w_0 + \underbrace{\sum_{i=1}^{p} w_i x_i}_{\text{linear terms}} + \underbrace{\sum_{i=1}^{p}\sum_{j \geq i}^{p} W_{i,j}\, x_i x_j}_{\text{non-linear terms}}, \quad (1)$$

where $w_0 \in \mathbb{R}$, $\mathbf{w} \in \mathbb{R}^p$ and $W \in \mathbb{R}^{p \times p}$. A difference between the above model and standard linear regression lies in its capability of representing non-linear interactions between the input $\mathbf{x}$ and the target $y(\mathbf{x})$, up to the specified degree $d$. In practice, though, it cannot estimate non-linear interactions reliably under the severe sparsity imposed by a large number of variables. We illustrate this shortcoming with the following example: assume the design matrix of Table 1, where each row corresponds to a specific patient and each column represents a data feature. In this case, we model each patient by sex (male/female), year of birth (normalized), as well as binary features indicating several diagnostic conditions.

Table 1: Example of a design matrix from EHR data

Female  Male  Year of Birth (norm.)  Dyslipidemia  Hypertension  Obesity
1       0     0.1                    1             0             0
0       1     1                      0             1             1
0       1     0.4                    1             1             0

Notably, a standard polynomial regression model would not take an interaction between Dyslipidemia and Obesity into account, despite the fact that those two diagnostic conditions are obviously related to each other, simply because they never co-occur in any row of the design matrix. For d > 2, when higher-order interaction relationships are taken into account, this issue gets even worse. In general, the estimation of n-way interaction effects for n features requires all of them to co-occur during the training phase, which is restrictive for multi-dimensional, sparse datasets.
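To make this concrete, the following minimal sketch (assuming NumPy and scikit-learn are available; the target values are hypothetical, invented purely for illustration) fits a ridge-regularized degree-2 polynomial model on the design matrix of Table 1. Because the Dyslipidemia x Obesity product column is identically zero in the training data, the L2 penalty pins the corresponding interaction weight at zero, exactly the failure mode described above:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import Ridge

# Design matrix from Table 1:
# [Female, Male, YearOfBirth, Dyslipidemia, Hypertension, Obesity]
X = np.array([[1, 0, 0.1, 1, 0, 0],
              [0, 1, 1.0, 0, 1, 1],
              [0, 1, 0.4, 1, 1, 0]])
y = np.array([0.0, 1.0, 1.0])  # hypothetical targets, for illustration only

# Explicit quadratic expansion: bias, linear terms, squares, pairwise products
poly = PolynomialFeatures(degree=2, include_bias=True)
Z = poly.fit_transform(X)
model = Ridge(alpha=1.0, fit_intercept=False).fit(Z, y)

names = poly.get_feature_names_out(
    ["Female", "Male", "YearOfBirth", "Dyslipidemia", "Hypertension", "Obesity"])
w = dict(zip(names, model.coef_))
# The pair never co-occurs in any row, so its column in Z is all zeros and
# the L2-regularized solution keeps its weight at zero (up to floating point).
print(w["Dyslipidemia Obesity"])  # -> 0.0
```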
2.2 SVMs with Polynomial Kernel

State-of-the-art classifiers, such as SVMs, model variable interactions through the polynomial kernel. The model equation for a 2nd-degree kernel is as follows [9]:

$$y(\mathbf{x}) = w_0 + \sqrt{2}\sum_{i=1}^{p} w_i x_i + \sum_{i=1}^{p} W_{i,i}\, x_i^2 + \sqrt{2}\sum_{i=1}^{p}\sum_{j>i}^{p} W_{i,j}\, x_i x_j, \quad (2)$$

where $w_0 \in \mathbb{R}$, $\mathbf{w} \in \mathbb{R}^p$ and $W \in \mathbb{R}^{p \times p}$. Evidently, Equation 1 is equivalent to Equation 2 in terms of the variable interactions that the corresponding models capture; the only difference in the SVM case is the scaling imposed on some of the terms. In both cases, though, the interaction parameters $W_{i,j}$ and $W_{i,l}$ are independent.

2.3 Factorization Machines

The equation describing the model class of FMs for d = 2 is as follows [10]:

$$y(\mathbf{x}) = w_0 + \sum_{i=1}^{p} w_i x_i + \sum_{i=1}^{p}\sum_{j>i}^{p} \langle V_i, V_j \rangle\, x_i x_j \quad (3)$$

$$= w_0 + \sum_{i=1}^{p} w_i x_i + \frac{1}{2}\sum_{f=1}^{k}\left[\left(\sum_{i=1}^{p} V_{i,f}\, x_i\right)^{\!2} - \sum_{i=1}^{p} V_{i,f}^2\, x_i^2\right], \quad (4)$$

where $w_0 \in \mathbb{R}$, $\mathbf{w} \in \mathbb{R}^p$, $V \in \mathbb{R}^{p \times k}$, $\langle\cdot,\cdot\rangle$ denotes the vector inner product and $k$ is a low-rank hyperparameter. The equivalence between Equation 3 and Equation 4 has been proved in [9].

A fundamental idea behind FMs is the replacement of the model parameters corresponding to feature interactions ($W \in \mathbb{R}^{p \times p}$) with a low-rank factorized surrogate ($V \in \mathbb{R}^{p \times k}$). A stark difference between this representation and Equations 1 and 2 lies in the fact that $\langle V_i, V_j \rangle$ and $\langle V_i, V_l \rangle$ are now dependent, since they share overlapping parameters ($V_i$). Thus, variable interaction parameters can be estimated with little or no observation of the corresponding pairs, since data for one interaction facilitate the parameter estimation of related interactions. As can be seen in Equation 4, there is no model parameter that directly depends on two variables (i.e., no parameter indexed by a pair (i, j)), due to the factorized representation. Moreover, Equation 4 has only linear complexity in both k and p, establishing the model's suitability for multi-dimensional settings and large datasets.
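As a quick numerical check of this equivalence, the following NumPy sketch (with randomly drawn parameters rather than learned ones, and small dimensions chosen only to keep the naive loop fast) evaluates both the $O(kp^2)$ pairwise form of Equation 3 and the reorganized $O(kp)$ form of Equation 4:

```python
import numpy as np

rng = np.random.default_rng(0)
p, k = 50, 4                       # small dimensions for a quick check
x = rng.random(p)
w0, w = rng.normal(), rng.normal(size=p)
V = rng.normal(scale=0.01, size=(p, k))

# Naive pairwise form, Equation 3: O(k p^2)
y_naive = w0 + w @ x + sum(V[i] @ V[j] * x[i] * x[j]
                           for i in range(p) for j in range(i + 1, p))

# Reorganized form, Equation 4: O(k p)
Vx = V.T @ x                       # k sums, each over the p inputs
y_fast = w0 + w @ x + 0.5 * ((Vx ** 2).sum() - ((V ** 2).T @ (x ** 2)).sum())

assert np.isclose(y_naive, y_fast)  # the two forms agree
```

The second form also makes the sparsity advantage apparent: only the non-zero entries of $\mathbf{x}$ contribute to the inner sums, so the prediction cost scales with the number of non-zeros rather than with $p$.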
MCMC learning. In the remainder of this Section, we describe the Markov Chain Monte Carlo (MCMC) inference method for Factorization Machines [3] in high-level terms. This description justifies the choice of this inference method for our experiments over the available alternatives, Stochastic Gradient Descent (SGD) and Alternating Least Squares (ALS) [10].

[Fig. 1: Left panel: the standard FM model. Right panel: the FM model with hyperpriors, used by the MCMC algorithm.]

The parameters of the standard FM model, as well as those used by the MCMC approach, are illustrated in Figure 1. In general, L2 regularization is assumed, so as to prevent overfitting. Regarding the standard FM framework, the variables are the predicted target $y$, the input features $\mathbf{x}$, the model parameters we intend to learn ($w_0$, $w_j$, $V_{j,f}$), as well as the hyperparameters/priors $\mu$, $\lambda$, $\alpha$, where $\mu$, $\lambda$ are the mean and the regularization for each of the model parameters and $\alpha$ corresponds to the prior precision (reciprocal of variance). The FM framework also assumes that an initialization parameter $\sigma$ is given, so that the low-rank parameter matrix $V$ is initialized from $\mathcal{N}(0, \sigma)$. The ALS method requires the regularization values $\lambda$ for each model parameter, as well as $\sigma$ for initialization; on top of these, the SGD method also requires a learning rate $\eta$.

Unlike ALS and SGD, which use point estimates, the MCMC method generates the target through sampling; Gibbs sampling is used as a computationally tractable way to sample from the posterior distribution. As can be seen in Figure 1, a Gaussian hyperprior is assumed for every mean $\mu$ of the model parameters except $\mu_0$, and a Gamma hyperprior is assumed for every regularization parameter $\lambda$ except $\lambda_0$, as well as for the precision $\alpha$. The core difference between MCMC and the other methods mentioned lies in the fact that the regularization values are automatically determined (via sampling). At the same time, the method is insensitive to the values of the newly introduced set of parameters $\{\alpha_0, \beta_0, \mu_0, \gamma_0, \alpha_\lambda, \beta_\lambda\}$. As a result, the only parameter that needs to be given as input is the initialization $\sigma$. The complete MCMC sampling algorithm can be found in [10].

3 Experiments

3.1 Data description

The data we used originate from a Kaggle challenge on diabetes classification (https://www.kaggle.com/c/pf2012-diabetes/data), developed in partnership with Practice Fusion (http://www.practicefusion.com/). The dataset consisted of 9948 patients, of whom 1904 were diabetic. The design matrix we built is similar to the example given in Table 1. It included information about the patients' sex, year of birth and the diagnostic conditions recorded (first three digits in the International Classification of Diseases 9 (ICD-9) format). The resulting number of features was 702 and the density of the whole dataset was approximately 0.01, i.e., the design matrix was sparse.
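For concreteness, libFM consumes training data in the libSVM-style sparse text format, in which only non-zero entries are written, a natural fit for such a design matrix. The following is a minimal sketch of serializing patient rows into that format; the patient records and the feature-index assignment are hypothetical:

```python
# Hypothetical patient records as (label, {feature_index: value}) pairs.
# Indices 0/1 one-hot encode sex, index 2 holds the normalized year of
# birth, and indices 3+ are binary indicators of ICD-9 three-digit codes.
patients = [
    (0, {0: 1, 2: 0.1, 3: 1}),
    (1, {1: 1, 2: 1.0, 4: 1, 5: 1}),
    (1, {1: 1, 2: 0.4, 3: 1, 4: 1}),
]

# libFM reads lines of the form "label index:value index:value ...";
# zero entries are simply omitted, so sparsity costs nothing on disk.
with open("train.libfm", "w") as f:
    for label, feats in patients:
        cols = " ".join(f"{i}:{v}" for i, v in sorted(feats.items()))
        f.write(f"{label} {cols}\n")
```

Files in this format can then be passed to the libFM binary; based on the libFM manual (and offered here as a sketch rather than the paper's exact command), an invocation along the lines of `./libFM -task c -train train.libfm -test test.libfm -method mcmc -dim '1,1,4' -iter 500 -init_stdev 0.01` would correspond to the setup described in Section 3.2.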
3.2 Experimental Setup

We used the libFM package (version 1.4.2) [10] to experiment with the Factorization Machines framework. The learning method used was Markov Chain Monte Carlo (MCMC), since it is the simplest one, requiring the smallest number of hyperparameters. The maximum number of iterations was set to 500. The parameter $\sigma$ of the Gaussian distribution used to initialize the $V$ matrix was set to 0.01; we empirically observed that this value provided favorable convergence properties compared to the alternatives proposed in [10] (e.g., 0.1, 1). We also used Matlab's built-in SVM with Gaussian kernel, as well as Matlab's built-in Random Forests implementation, for which we picked 500 trees (selected via cross-validation).

3.3 Results

Evaluation of classifier's quality. The results of our experiments are provided in Table 2. We used 5-fold cross-validation for all methods tested.

Table 2: Results on detecting Type 2 Diabetes Mellitus

                        Accuracy  Precision  Recall  F1-Score  AUC
Factorization Machines  0.8187    0.5656     0.2334  0.3302    0.8048
SVM                     0.8125    0.5438     0.1263  0.2046    0.7598
Random Forests          0.8178    0.6232     0.1242  0.2069    0.8048

The Factorization Machines result reported in this table uses a low-rank parameter k = 4. The results show that the FM framework is competitive, and even outperforms state-of-the-art classifiers under specific settings (such as in terms of the F1-score). In terms of Precision, its performance is approximately the same as that of SVM, and a bit lower than that of Random Forests. In terms of Recall, on the other hand, it is roughly twice as good as the other methods. Finally, in terms of the Area Under the Curve (AUC) measure, the FM model achieves a slightly higher AUC than SVM and is in the same regime as Random Forests.

In Figure 2, we illustrate the effect of the low-rank parameter k on the FM framework. Interestingly, for k > 4 the model seems to overfit the data and the average precision gradually diminishes.

[Fig. 2: Average precision of FMs of increasing model complexity (varying k = {1, 2, 4, 8, 16, 32, 64}); x-axis: model dimensionality, y-axis: average precision.]

Evaluation of classifier's scalability. In Figure 3, we provide time measurements as the low-rank parameter k increases, whereas in Figure 4 the training time is measured as the training dataset grows. In both cases, the method maintains a linear scale-up, confirming its theoretical properties.

[Fig. 3: Effect of increasing model complexity (varying k = {1, 2, 4, 8, 16, 32, 64}) on the FM training time; x-axis: model dimensionality, y-axis: average time per fold (seconds).]

[Fig. 4: Effect of increasing dataset size on the FM training time (fixed k = 4); x-axis: size of training data (#samples), y-axis: training time (seconds).]
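The evaluation protocol above can also be summarized in code. The sketch below uses scikit-learn on a synthetic stand-in for the EHR design matrix (the paper's actual pipeline runs through libFM and Matlab, and a logistic regression is used here purely as a placeholder classifier, so the numbers it prints are meaningless); it shows how the five reported metrics would be computed under 5-fold cross-validation:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score)

# Synthetic stand-in: 702 features, ~1% density, class imbalance as in Sec. 3.1.
rng = np.random.default_rng(0)
X = (rng.random((2000, 702)) < 0.01).astype(float)
y = (rng.random(2000) < 1904 / 9948).astype(int)

scores = {m: [] for m in ("accuracy", "precision", "recall", "f1", "auc")}
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in cv.split(X, y):
    # Placeholder classifier; the paper trains FMs, SVMs and Random Forests here.
    clf = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    pred = clf.predict(X[test_idx])
    prob = clf.predict_proba(X[test_idx])[:, 1]
    scores["accuracy"].append(accuracy_score(y[test_idx], pred))
    scores["precision"].append(precision_score(y[test_idx], pred, zero_division=0))
    scores["recall"].append(recall_score(y[test_idx], pred))
    scores["f1"].append(f1_score(y[test_idx], pred))
    scores["auc"].append(roc_auc_score(y[test_idx], prob))

print({m: round(float(np.mean(v)), 4) for m, v in scores.items()})
```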
4 Concluding remarks

In this paper, we applied Factorization Machines to EHR data for the classification of diabetic patients. We argue that this method could prove to be a useful tool in the Healthcare domain, due both to the inherent sparsity of health-related datasets and to the imperfect mechanism of recording patient-related information: e.g., a patient may suffer from a specific disease, but she may not have visited her physician in order to get the corresponding diagnosis recorded. Future work will address the incorporation of more viewpoints of patients' clinical status (such as the medications prescribed) and will investigate ways to boost the efficiency and accuracy of FMs by taking domain knowledge into account.

References

1. Abbasi, A., Peelen, L.M., Corpeleijn, E., van der Schouw, Y.T., Stolk, R.P., Spijkerman, A.M., van der A, D.L., Moons, K.G., Navis, G., Bakker, S.J., et al.: Prediction models for risk of developing type 2 diabetes: systematic literature search and independent external validation study. BMJ: British Medical Journal 345 (2012)
2. Ali, R., Siddiqi, M.H., Idris, M., Kang, B.H., Lee, S.: Prediction of diabetes mellitus based on boosting ensemble modeling. In: Ubiquitous Computing and Ambient Intelligence. Personalisation and User Adapted Services, pp. 25–28. Springer (2014)
3. Freudenthaler, C., Schmidt-Thieme, L., Rendle, S.: Bayesian factorization machines (2011)
4. Harris, M.I., Eastman, R.C.: Early detection of undiagnosed diabetes mellitus: a US perspective. Diabetes/Metabolism Research and Reviews 16(4), 230–236 (2000)
5. Hirsch, A.T., Criqui, M.H., Treat-Jacobson, D., Regensteiner, J.G., Creager, M.A., Olin, J.W., Krook, S.H., Hunninghake, D.B., Comerota, A.J., Walsh, M.E., et al.: Peripheral arterial disease detection, awareness, and treatment in primary care. JAMA 286(11), 1317–1324 (2001)
6. Koh, H.C., Tan, G., et al.: Data mining applications in healthcare. Journal of Healthcare Information Management 19(2), 65 (2011)
7. Loni, B., Said, A., Larson, M., Hanjalic, A.: 'Free lunch' enhancement for collaborative filtering with factorization machines. In: Proceedings of the 8th ACM Conference on Recommender Systems, pp. 281–284. ACM (2014)
8. Mani, S., Chen, Y., Elasy, T., Clayton, W., Denny, J.: Type 2 diabetes risk forecasting from EMR data using machine learning. In: AMIA Annual Symposium Proceedings, vol. 2012, p. 606. American Medical Informatics Association (2012)
9. Rendle, S.: Factorization machines. In: Proceedings of the 2010 IEEE International Conference on Data Mining, pp. 995–1000. ICDM '10 (2010)
10. Rendle, S.: Factorization machines with libFM. ACM Transactions on Intelligent Systems and Technology (TIST) 3(3), 57 (2012)
11. Rendle, S.: Scaling factorization machines to relational data. In: Proceedings of the VLDB Endowment, vol. 6, pp. 337–348. VLDB Endowment (2013)
12. Sun, J., Reddy, C.K.: Big data analytics for healthcare. In: Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1525–1525. ACM (2013)