Factorization Machines as a Tool for Healthcare:
Case Study on Type 2 Diabetes Detection
Ioakeim Perros and Jimeng Sun
Georgia Institute of Technology, Atlanta, U.S.A.
perros@gatech.edu, jsun@cc.gatech.edu
Abstract. Exploiting high-dimensional data for prediction purposes has been one of the most attractive targets in Healthcare Analytics. However, reliable model learning under severe sparsity is challenging for state-of-the-art methods, such as Support Vector Machines (SVMs). In this work, we propose applying a recently developed model class, called Factorization Machines (FMs), to Electronic Health Record (EHR) data. The purpose of this analysis is to detect patients diagnosed with Type 2 Diabetes Mellitus (T2DM). The wide range of comorbidities associated with T2DM and its prevalence in a large portion of the population reflect the task's difficulty and importance. Our experiments on publicly available EHR data confirm the suitability of FMs for Healthcare Analytics, and the insights we provide serve as a stepping stone for future work.
Keywords: Factorization Machines, Classification, Health Analytics
1 Introduction
The prospect of applying Predictive Modeling to Healthcare data has been widely acclaimed and is supported by numerous success stories [6,12]. One of the most popular applications is the detection of a specific disease in a population [5], based on patient characteristics ranging from clinical status to demographic and genetic information. However, this task is far from reaching its full potential, due to the heterogeneous and noisy nature of the captured features, as well as the complex relations between risk factors and a specific disease. Notably, carefully chosen additional features can significantly improve a classifier's performance [12]. A fundamental challenge therefore lies in leveraging multiple features in a data-driven way while capturing their complex interactions. In many cases, however, data sparsity grows with the number of features incorporated, which poses a further challenge even for state-of-the-art classification approaches, such as Support Vector Machines (SVMs).
In this work, we propose the application of Factorization Machines (FMs) [10]
in order to deal with the inherent sparsity of multi-feature Electronic Health
Record (EHR) data. The FM model class is mostly known for its success in
Collaborative Filtering settings [7, 11], such as the rating prediction problem in
movie recommendation systems, where the underlying dataset is very sparse.
The prediction goal we are targeting is the detection of Type 2 Diabetes Mellitus (T2DM). This task has attracted increasing interest [1, 2, 8], because T2DM affects a large portion of the global population and because early detection can significantly reduce diabetes complications [4].
The remainder of this paper is organized as follows. Section 2 presents the Factorization Machines (FM) framework together with related approaches. Section 3 describes our experiments and the observations drawn from them. Finally, Section 4 summarizes this work.
2 Approaches
In this Section, we briefly explain several model learning methods. This sets the stage for the formulation of the FM framework that follows and highlights how it differs from other well-known model learning approaches.
2.1 Polynomial Regression
We first state the classic polynomial regression model and highlight its connections to the model class of FMs. The model equation of 2nd-degree (d = 2) polynomial regression for an input vector $x \in \mathbb{R}^p$ is as follows:
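\begin{equation}
y(x) = w_0 + \underbrace{\sum_{i=1}^{p} w_i x_i}_{\text{linear terms}} + \underbrace{\sum_{i=1}^{p}\sum_{j \ge i}^{p} W_{i,j}\, x_i x_j}_{\text{non-linear terms}} \tag{1}
\end{equation}
where $w_0 \in \mathbb{R}$, $w \in \mathbb{R}^p$ and $W \in \mathbb{R}^{p \times p}$.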
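For concreteness, Equation 1 can be evaluated directly as below. This is a minimal NumPy sketch written for this explanation (the function name poly2_predict is hypothetical, not part of any package used in this work). It stores the interaction weights in a full p × p matrix and uses only its upper triangle (j ≥ i), so the model has 1 + p + p(p+1)/2 free parameters, each interaction weight being independent of all the others.

```python
import numpy as np

def poly2_predict(x, w0, w, W):
    """Degree-2 polynomial regression (Equation 1):
    y(x) = w0 + sum_i w_i x_i + sum_i sum_{j >= i} W[i, j] x_i x_j."""
    x = np.asarray(x, dtype=float)
    linear = w @ x
    # np.triu zeroes the entries below the diagonal, keeping the j >= i terms
    # (including the squared terms W[i, i] * x_i^2).
    nonlinear = x @ np.triu(W) @ x
    return w0 + linear + nonlinear

w0, w = 0.5, np.array([1.0, 1.0])
W = np.array([[1.0, 2.0],
              [3.0, 4.0]])  # W[1, 0] = 3 lies below the diagonal and is ignored
print(poly2_predict([1.0, 2.0], w0, w, W))  # 24.5
```

Here the prediction is 0.5 + (1 + 2) + (1·1·1 + 2·1·2 + 4·2·2) = 24.5; changing W[1, 0] has no effect.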
The key difference between the above model and standard linear regression is its ability to represent non-linear interactions between the input x and the target y(x), up to the specified degree d. In practice, however, it cannot estimate these non-linear interactions reliably under the severe sparsity imposed by a large number of variables.
We illustrate this shortcoming with the following example. Consider the design matrix in Table 1, where each row corresponds to a specific patient and each column represents a data feature. In this case, each patient is described by sex (male/female), normalized year of birth, and binary features indicating several diagnosed conditions.
Notably, a standard polynomial regression model cannot take an interaction between Dyslipidemia and Obesity into account, even though these two diagnostic conditions are clearly related to each other. This is because the two conditions never co-occur in the same row of the design matrix: the product $x_{\text{Dyslipidemia}}\, x_{\text{Obesity}}$ is zero for every patient, so the corresponding weight receives no information from the training data. For d > 2, when higher-order interactions are taken into account, this issue becomes even worse. In general, estimating n-way interaction effects among n features requires all of them to co-occur during the training phase, which is restrictive for high-dimensional, sparse datasets.

Female  Male  Year of Birth (norm.)  Dyslipidemia  Hypertension  Obesity
1       0     0.1                    1             0             0
0       1     1                      0             1             1
0       1     0.4                    1             1             0
Table 1: Example of a design matrix from EHR data
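The co-occurrence argument can be checked mechanically. In the sketch below (our own illustration; the variable names are hypothetical), entry (i, j) of the support matrix counts the rows of Table 1 in which features i and j are both non-zero. Under, for example, a squared loss, the gradient with respect to W_{i,j} in Equation 1 is a sum of residuals weighted by x_i x_j over the training rows, so a pair with zero support never receives any update.

```python
import numpy as np

FEATURES = ["Female", "Male", "Year of Birth (norm.)", "Dyslipidemia", "Hypertension", "Obesity"]
X = np.array([  # the design matrix of Table 1
    [1, 0, 0.1, 1, 0, 0],
    [0, 1, 1.0, 0, 1, 1],
    [0, 1, 0.4, 1, 1, 0],
])

# support[i, j] = number of rows where features i and j are both non-zero,
# i.e. the rows that contribute to the gradient of W[i, j] in Equation 1.
nz = (X != 0).astype(int)
support = nz.T @ nz

unsupported = [(FEATURES[i], FEATURES[j])
               for i in range(len(FEATURES))
               for j in range(i + 1, len(FEATURES))
               if support[i, j] == 0]
print(unsupported)
print(("Dyslipidemia", "Obesity") in unsupported)  # True
```

Besides Dyslipidemia and Obesity, the pairs involving Female together with Male, Hypertension or Obesity also have zero support in this tiny example.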
2.2 SVMs with Polynomial Kernel
State-of-the-art classifiers such as SVMs model variable interactions by means of the polynomial kernel. The model equation for a 2nd-degree kernel is as follows [9]:
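\begin{equation}
y(x) = w_0 + \sqrt{2}\sum_{i=1}^{p} w_i x_i + \sum_{i=1}^{p} W_{i,i}\, x_i^2 + \sqrt{2}\sum_{i=1}^{p}\sum_{j > i}^{p} W_{i,j}\, x_i x_j \tag{2}
\end{equation}
where $w_0 \in \mathbb{R}$, $w \in \mathbb{R}^p$ and $W \in \mathbb{R}^{p \times p}$.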
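Equation 1 is therefore equivalent to Equation 2 in terms of the variable interactions that the corresponding models capture. The only difference in the SVM case is that a scaling factor of √2 is applied to the linear and pairwise interaction terms. In both cases, however, the interaction parameters $W_{i,j}$ and $W_{i,l}$ are independent: each one is estimated only from the rows in which its own two features co-occur, so data about one pair provide no information about another.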
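The √2 factors in Equation 2 come from the explicit feature map of the kernel. The sketch below assumes the common form K(x, z) = (⟨x, z⟩ + 1)^2 (our assumption; [9] should be consulted for the exact kernel used there). A linear model over this feature map has one independent weight per coordinate, which, after renaming the weights, is exactly the form of Equation 2; the code checks numerically that the feature map reproduces the kernel. The name poly2_feature_map is hypothetical.

```python
import itertools
import numpy as np

def poly2_feature_map(x):
    """Explicit feature map of K(x, z) = (<x, z> + 1)^2:
    phi(x) = (1, sqrt(2) x_i, x_i^2, sqrt(2) x_i x_j for i < j)."""
    x = np.asarray(x, dtype=float)
    cross = [np.sqrt(2) * x[i] * x[j]
             for i, j in itertools.combinations(range(len(x)), 2)]
    return np.concatenate(([1.0], np.sqrt(2) * x, x ** 2, cross))

# Two patients from Table 1, restricted to (Female, Male, Year of Birth)
x = np.array([1.0, 0.0, 0.1])
z = np.array([0.0, 1.0, 0.4])
lhs = poly2_feature_map(x) @ poly2_feature_map(z)
rhs = (x @ z + 1) ** 2
print(np.isclose(lhs, rhs))  # True
```

Expanding the inner product gives 1 + 2⟨x, z⟩ + ⟨x, z⟩², which is why the linear and cross terms carry the √2 scaling while the squared terms do not.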
2.3 Factorization Machines
The equation describing the model class of FMs for d = 2 is as follows [10]:
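\begin{align}
y(x) &= w_0 + \sum_{i=1}^{p} w_i x_i + \sum_{i=1}^{p}\sum_{j > i}^{p} \langle V_i, V_j \rangle\, x_i x_j \tag{3}\\
     &= w_0 + \sum_{i=1}^{p} w_i x_i + \frac{1}{2}\sum_{f=1}^{k}\left[\left(\sum_{i=1}^{p} V_{i,f}\, x_i\right)^{2} - \sum_{i=1}^{p} V_{i,f}^{2}\, x_i^{2}\right] \tag{4}
\end{align}
where $w_0 \in \mathbb{R}$, $w \in \mathbb{R}^p$, $V \in \mathbb{R}^{p \times k}$ (with $V_i$ denoting its $i$-th row), $\langle \cdot, \cdot \rangle$ denotes the vector inner product and $k$ is the low-rank hyperparameter. The equivalence between Equation 3 and Equation 4 has been proved in [9].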
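The equivalence of Equations 3 and 4 is easy to confirm numerically. The following is a minimal NumPy sketch (function names are hypothetical; it is not libFM's implementation): the first function sums over all pairs as in Equation 3, at O(kp²) cost, and the second uses the reformulation of Equation 4, at O(kp) cost. The random parameter values are arbitrary.

```python
import numpy as np

def fm_predict_naive(x, w0, w, V):
    """Equation 3: explicit sum over all pairs i < j, O(k p^2)."""
    p = len(x)
    pairwise = sum(V[i] @ V[j] * x[i] * x[j]
                   for i in range(p) for j in range(i + 1, p))
    return w0 + w @ x + pairwise

def fm_predict_linear(x, w0, w, V):
    """Equation 4: 0.5 * sum_f [(sum_i V[i,f] x_i)^2 - sum_i V[i,f]^2 x_i^2], O(k p)."""
    s = V.T @ x                  # shape (k,): sum_i V[i, f] * x_i
    s2 = (V ** 2).T @ (x ** 2)   # shape (k,): sum_i V[i, f]^2 * x_i^2
    return w0 + w @ x + 0.5 * np.sum(s ** 2 - s2)

rng = np.random.default_rng(0)
p, k = 6, 4
x = np.array([0.0, 1.0, 0.4, 1.0, 1.0, 0.0])  # third patient of Table 1
w0, w, V = 0.1, rng.normal(size=p), rng.normal(scale=0.1, size=(p, k))
print(np.isclose(fm_predict_naive(x, w0, w, V), fm_predict_linear(x, w0, w, V)))  # True
```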
A fundamental idea behind FMs is to replace the model parameters corresponding to feature interactions ($W \in \mathbb{R}^{p \times p}$) with a low-rank factorized surrogate ($V \in \mathbb{R}^{p \times k}$). A stark difference between this representation and Equations 1 and 2 is that the interaction parameters $\langle V_i, V_j \rangle$ and $\langle V_i, V_l \rangle$ are now dependent, since they share overlapping parameters ($V_i$). Thus, variable interaction parameters can be estimated with few or even no observations of the corresponding pairs, since data for one interaction help estimate the parameters of related interactions. In Table 1, for example, $V_{\text{Dyslipidemia}}$ is informed by the third patient (who also has Hypertension) and $V_{\text{Obesity}}$ by the second patient (who also has Hypertension), so $\langle V_{\text{Dyslipidemia}}, V_{\text{Obesity}} \rangle$ can be estimated even though this pair never co-occurs. As can be seen in Equation 4, owing to the factorized representation, no model parameter depends directly on two variables (i.e., no parameter is indexed by $(i, j)$). Moreover, Equation 4 can be computed in $O(kp)$ time, i.e., linearly in both $k$ and $p$, compared with $O(kp^2)$ for a direct evaluation of Equation 3; this makes the model suitable for high-dimensional settings and large datasets.
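The parameter sharing can be made concrete with the gradient of Equation 4 with respect to V, namely ∂y/∂V_{i,f} = x_i Σ_j V_{j,f} x_j − V_{i,f} x_i². The sketch below (our illustration, with a hypothetical helper grad_V and k = 4, σ = 0.01 chosen only for the example) shows that, on the data of Table 1, the factor vectors of Dyslipidemia and Obesity both receive non-zero gradients even though the two conditions never appear together.

```python
import numpy as np

X = np.array([  # Table 1: Female, Male, Year of Birth (norm.), Dyslipidemia, Hypertension, Obesity
    [1, 0, 0.1, 1, 0, 0],
    [0, 1, 1.0, 0, 1, 1],
    [0, 1, 0.4, 1, 1, 0],
])
DYS, HYP, OBE = 3, 4, 5

def grad_V(x, V):
    """Gradient of the FM output (Equation 4) w.r.t. V for a single input x:
    dy/dV[i, f] = x_i * sum_j V[j, f] x_j - V[i, f] * x_i^2."""
    s = V.T @ x                                   # shape (k,)
    return np.outer(x, s) - V * (x ** 2)[:, None]

rng = np.random.default_rng(0)
V = rng.normal(scale=0.01, size=(X.shape[1], 4))  # k = 4, initialized from N(0, 0.01)

print((X[:, DYS] * X[:, OBE]).tolist())  # [0.0, 0.0, 0.0]: the pair never co-occurs
# Both factor vectors are still updated, through features they do co-occur with
# (Male, Year of Birth and Hypertension):
print(np.abs(grad_V(X[2], V)[DYS]).sum() > 0)  # third patient updates V_Dyslipidemia
print(np.abs(grad_V(X[1], V)[OBE]).sum() > 0)  # second patient updates V_Obesity
```

Since the estimate of the interaction is the inner product of these two vectors, it is shaped by this indirect evidence rather than requiring direct co-occurrence.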
MCMC learning In the remainder of this Section, we describe in high-level terms the Markov Chain Monte Carlo (MCMC) inference method for Factorization Machines [3]. This description justifies why we chose this inference method for our experiments over the other available alternatives: Stochastic Gradient Descent (SGD) and Alternating Least Squares (ALS) [10].
Fig. 1: Left panel: Standard FM model. Right panel: FM model with hyperpriors, used by the MCMC algorithm.
The parameters of the standard FM model, as well as those used by the MCMC approach, are illustrated in Figure 1. In general, L2 regularization is assumed in order to prevent overfitting. In the standard FM framework, the variables are the predicted target y, the input features x, the model parameters to be learned ($w_0$, $w_j$, $V_{j,f}$), and the hyperparameters/priors µ, λ, α: µ and λ are the mean and the regularization (precision) of the prior on each of the model parameters, while α is the precision (reciprocal of the variance) of the Gaussian noise on the target y. The FM framework also assumes that an initialization parameter σ is given, so that the low-rank parameter matrix V is initialized by sampling from N(0, σ), i.e., a zero-mean normal distribution with standard deviation σ.
The ALS method requires the regularization values λ for each model parameter, as well as the initialization parameter σ. In addition to these, the SGD method requires a learning rate η.
Unlike ALS and SGD, which compute point estimates of the model parameters, the MCMC method draws samples from the posterior distribution of the parameters. Gibbs sampling is used as a computationally tractable way to sample from this posterior. As can be seen in Figure 1, a (Gaussian) hyperprior is placed on every mean µ of the model parameters except µ0, and a (Gamma) hyperprior is placed on every regularization parameter λ except λ0, as well as on the precision α. The core difference between MCMC and the other methods mentioned is that the regularization values are determined automatically (via sampling). At the same time, the method is insensitive to the values of the newly introduced set of parameters {α0, β0, µ0, γ0, αλ, βλ}. As a result, the only parameter that needs to be provided as input is the initialization σ. The complete MCMC sampling algorithm can be found in [10].
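To illustrate how sampling can determine a regularization value automatically, here is a minimal sketch of a single Gibbs step for one precision λ. It is a simplification, not the algorithm of [3] or [10]: we assume the mean µ is held fixed and λ has a Gamma(α_λ, β_λ) hyperprior in shape-rate form, whereas Bayesian FMs place a joint hyperprior on µ and λ. Standard Normal-Gamma conjugacy then gives a Gamma posterior whose rate grows with the spread of the current weights. The function name sample_lambda and the default hyperprior values are hypothetical.

```python
import numpy as np

def sample_lambda(w, mu, alpha_lambda=1.0, beta_lambda=1.0, rng=None):
    """One Gibbs draw of the precision lambda of a Gaussian prior w_j ~ N(mu, 1/lambda),
    given the current parameter values w and a fixed mean mu (simplifying assumption).

    With a Gamma(alpha_lambda, beta_lambda) hyperprior (shape, rate), conjugacy gives
    lambda | w ~ Gamma(alpha_lambda + p / 2, beta_lambda + sum_j (w_j - mu)^2 / 2).
    """
    rng = np.random.default_rng() if rng is None else rng
    w = np.asarray(w, dtype=float)
    shape = alpha_lambda + w.size / 2.0
    rate = beta_lambda + 0.5 * np.sum((w - mu) ** 2)
    # NumPy's Generator.gamma is parameterized by shape and scale = 1 / rate.
    return rng.gamma(shape, 1.0 / rate)

rng = np.random.default_rng(0)
w = rng.normal(loc=0.0, scale=0.5, size=5000)  # weights whose true precision is 1 / 0.5^2 = 4
draws = [sample_lambda(w, mu=0.0, rng=rng) for _ in range(200)]
print(np.mean(draws))  # close to 4: the regularization strength is inferred from w
```

In a full sampler, such draws alternate with draws of the model parameters given λ, so no regularization value has to be tuned by hand.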
3 Experiments
3.1 Data description
The data we used originate from a Kaggle challenge on Diabetes classification¹, developed in partnership with Practice Fusion². The dataset consisted of 9,948 patients, 1,904 of whom were diabetic. The design matrix we built is similar to the example given in Table 1. It included the patients' sex, year of birth and recorded diagnostic conditions (first three digits of the International Classification of Diseases, Ninth Revision (ICD-9) code). The resulting number of features was 702 and the density of the whole dataset was approximately 0.01, i.e., only about 1% of the entries of the design matrix were non-zero.
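A design matrix in the layout of Table 1 could be built from per-patient records roughly as follows. This is a hypothetical sketch: the record fields ('sex', 'yob', 'icd9') and the min-max normalization of the year of birth are our assumptions, since the exact preprocessing is not specified above, and the toy records are invented for illustration. The last lines apply simple arithmetic to the statistics reported for the full dataset.

```python
import numpy as np

def icd9_category(code):
    """First three characters of an ICD-9 code, e.g. '250.00' -> '250', 'V70.0' -> 'V70'."""
    return code.replace(".", "")[:3]

def build_design_matrix(patients):
    """Design matrix in the layout of Table 1.

    patients: list of dicts with hypothetical keys 'sex' ('F' or 'M'), 'yob' (int)
    and 'icd9' (list of ICD-9 code strings). Returns (X, feature_names).
    """
    categories = sorted({icd9_category(c) for p in patients for c in p["icd9"]})
    names = ["Female", "Male", "YearOfBirth (norm.)"] + categories
    col = {name: j for j, name in enumerate(names)}
    yobs = np.array([p["yob"] for p in patients], dtype=float)
    lo, span = yobs.min(), max(yobs.max() - yobs.min(), 1.0)  # guard against zero span
    X = np.zeros((len(patients), len(names)))  # dense for brevity; use a sparse format at scale
    for i, p in enumerate(patients):
        X[i, col["Female" if p["sex"] == "F" else "Male"]] = 1.0
        X[i, col["YearOfBirth (norm.)"]] = (p["yob"] - lo) / span  # assumed min-max scaling
        for c in p["icd9"]:
            X[i, col[icd9_category(c)]] = 1.0
    return X, names

patients = [  # toy records (272: lipid disorders, 278: obesity, 401: essential hypertension)
    {"sex": "F", "yob": 1950, "icd9": ["272.4"]},
    {"sex": "M", "yob": 1990, "icd9": ["401.9", "278.00"]},
    {"sex": "M", "yob": 1962, "icd9": ["272.0", "401.1"]},
]
X, names = build_design_matrix(patients)
print(names)                         # ['Female', 'Male', 'YearOfBirth (norm.)', '272', '278', '401']
print(np.count_nonzero(X) / X.size)  # density of the toy matrix

# Arithmetic on the statistics reported for the full dataset
n_patients, n_diabetic, n_features, density = 9948, 1904, 702, 0.01
print(round(n_diabetic / n_patients, 3))      # 0.191: fraction of diabetic patients
print(round(1 - n_diabetic / n_patients, 3))  # 0.809: accuracy of always predicting "non-diabetic"
print(round(n_features * density, 1))         # 7.0: non-zero features per patient, on average
```

The majority-class figure is a useful reference point when reading accuracy values on this dataset, and the last figure shows how few entries per row actually carry information.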
3.2 Experimental Setup
We used the libFM package (version 1.4.2) [10] to experiment with the Factorization Machines framework. As the learning method we used Markov Chain Monte Carlo (MCMC), since it is the simplest, requiring the fewest hyperparameters. The maximum number of iterations was set to 500. The standard deviation σ of the Gaussian distribution used to initialize the V matrix was set to 0.01; we empirically observed that this value yielded better convergence than the alternatives proposed in [10] (e.g., 0.1 and 1).
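For readers who want to reproduce a similar setup, the sketch below serializes design-matrix rows into libFM's sparse text input format ("target index:value ..."), writing only non-zero entries, and assembles a command line with the settings listed above. The exact command used for our experiments is not given here: the flag names reflect our reading of the libFM 1.4.2 manual and should be checked against it, and the file names, binary path and label value are hypothetical.

```python
def to_libfm_line(label, x):
    """One row in libFM's text input format: '<target> <index>:<value> ...'.
    Only non-zero entries are written (0-based feature indices), which keeps
    files small for a design matrix with density around 0.01."""
    feats = " ".join(f"{j}:{v:g}" for j, v in enumerate(x) if v != 0)
    return f"{label} {feats}".rstrip()

def libfm_command(train_file, test_file, k=4, n_iter=500, init_stdev=0.01):
    """Argument list mirroring the setup above: binary classification with MCMC,
    bias + linear terms + rank-k pairwise interactions ('-dim 1,1,k')."""
    return ["./libFM", "-task", "c",
            "-train", train_file, "-test", test_file,
            "-dim", f"1,1,{k}",
            "-iter", str(n_iter),
            "-method", "mcmc",
            "-init_stdev", str(init_stdev)]

# First patient of Table 1, with a hypothetical label of 1
print(to_libfm_line(1, [1, 0, 0.1, 1, 0, 0]))  # 1 0:1 2:0.1 3:1
print(" ".join(libfm_command("train.libfm", "test.libfm")))
```

The argument list could be passed to subprocess.run once a libFM binary is available; it is only printed here.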
For comparison, we used MATLAB's built-in SVM implementation with a Gaussian kernel, as well as MATLAB's built-in Random Forests implementation with 500 trees (selected via cross-validation).
¹ https://www.kaggle.com/c/pf2012-diabetes/data
² http://www.practicefusion.com/
Fig. 2: Average precision of FM for increasing model complexity (varying k = {1, 2, 4, 8, 16, 32, 64}); x-axis: model dimensionality k, y-axis: average precision.
3.3 Results
Evaluation of classifier’s quality The results of our experiments are provided in Table 2. We used 5-fold cross-validation for all methods tested. The Factorization Machines results reported in this table use a low-rank parameter k = 4.

Method                   Accuracy  Precision  Recall  F1-Score  AUC
Factorization Machines   0.8187    0.5656     0.2334  0.3302    0.8048
SVM                      0.8125    0.5438     0.1263  0.2046    0.7598
Random Forests           0.8178    0.6232     0.1242  0.2069    0.8048
Table 2: Results on detecting Type 2 Diabetes Mellitus

The results show that the FM framework is competitive with, and on some measures outperforms, state-of-the-art classifiers, most notably in terms of the F1-score (0.3302 versus 0.2046 for SVM and 0.2069 for Random Forests). In terms of Precision, its performance is approximately the same as that of SVM (slightly higher) and somewhat lower than that of Random Forests. In terms of Recall, on the other hand, it is nearly twice as good as the other methods (about 1.85 and 1.88 times the recall of SVM and Random Forests, respectively). Finally, as concerns the Area Under the Curve (AUC), the FM model achieves a higher AUC than SVM (0.8048 versus 0.7598) and the same AUC as Random Forests.
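As a sanity check on Table 2, the F1-score can be recomputed as the harmonic mean of the reported precision and recall, and the recall ratios quoted above can be reproduced. This is plain arithmetic on the published numbers; we assume that small discrepancies in F1 (below 0.001) arise because per-fold scores were averaged, which does not commute with the harmonic mean.

```python
# Mean precision, recall and F1 per method, copied from Table 2
table2 = {
    "Factorization Machines": {"precision": 0.5656, "recall": 0.2334, "f1": 0.3302},
    "SVM":                    {"precision": 0.5438, "recall": 0.1263, "f1": 0.2046},
    "Random Forests":         {"precision": 0.6232, "recall": 0.1242, "f1": 0.2069},
}

def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

for name, m in table2.items():
    print(f"{name}: recomputed F1 = {f1(m['precision'], m['recall']):.4f} (reported {m['f1']})")

fm_recall = table2["Factorization Machines"]["recall"]
print(round(fm_recall / table2["SVM"]["recall"], 2))             # 1.85
print(round(fm_recall / table2["Random Forests"]["recall"], 2))  # 1.88
```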
In Figure 2, we illustrate the effect of the low-rank parameter k on the FM framework. Interestingly, for k > 4 the model appears to overfit the data, and the average precision gradually decreases.
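One way to read the model-complexity axis is through the number of free parameters. The arithmetic below (our illustration; the helper names are hypothetical) counts parameters for Equation 1 and for an FM (Equation 3) with the p = 702 features of our design matrix, over the values of k used in Figure 2.

```python
def n_params_poly2(p):
    """Equation 1: w0, p linear weights and one W[i, j] for each j >= i."""
    return 1 + p + p * (p + 1) // 2

def n_params_fm(p, k):
    """Equation 3: w0, p linear weights and the p x k factor matrix V."""
    return 1 + p + p * k

p = 702  # number of features in our design matrix (Section 3.1)
print(n_params_poly2(p))  # 247456
print({k: n_params_fm(p, k) for k in (1, 2, 4, 8, 16, 32, 64)})
# {1: 1405, 2: 2107, 4: 3511, 8: 6319, 16: 11935, 32: 23167, 64: 45631}
```

Even at k = 64 the FM has far fewer parameters than the full degree-2 polynomial model, yet each doubling of k adds p·k parameters that the sparse data must support.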
Fig. 3: Effect of increasing model complexity (varying k = {1, 2, 4, 8, 16, 32, 64}) on the FM training time; x-axis: model dimensionality k, y-axis: average time per fold (seconds).
Evaluation of classifier’s scalability In Figure 3, we report training time measurements as the low-rank parameter k increases, whereas in Figure 4 the training time is measured as the size of the training dataset increases. In both cases, the training time grows linearly, consistent with the theoretical complexity of the method.
4 Concluding remarks
In this paper, we applied Factorization Machines to EHR data for the classification of diabetic patients. We argue that this method could prove to be a useful tool in the Healthcare domain, due both to the inherent sparsity of health-related datasets and to the imperfect way in which patient-related information is recorded: e.g., a patient may suffer from a specific disease but may not have visited her physician to get the corresponding diagnosis recorded.
Future work will address the incorporation of additional aspects of patients' clinical status (such as the medications prescribed) and will investigate ways to boost the efficiency and accuracy of FMs by taking domain knowledge into account.
References
1. Abbasi, A., Peelen, L.M., Corpeleijn, E., van der Schouw, Y.T., Stolk, R.P., Spijkerman, A.M., van der A, D.L., Moons, K.G., Navis, G., Bakker, S.J., et al.: Prediction models for risk of developing type 2 diabetes: systematic literature search and
independent external validation study. BMJ: British Medical Journal 345 (2012)
2. Ali, R., Siddiqi, M.H., Idris, M., Kang, B.H., Lee, S.: Prediction of diabetes mellitus
based on boosting ensemble modeling. In: Ubiquitous Computing and Ambient
Intelligence. Personalisation and User Adapted Services, pp. 25–28. Springer (2014)
Fig. 4: Effect of increasing dataset size on the FM training time (fixed k = 4); x-axis: size of training data (#samples), y-axis: training time (seconds).
3. Freudenthaler, C., Schmidt-Thieme, L., Rendle, S.: Bayesian factorization machines (2011)
4. Harris, M.I., Eastman, R.C.: Early detection of undiagnosed diabetes mellitus: a US perspective. Diabetes/Metabolism Research and Reviews 16(4), 230–236 (2000)
5. Hirsch, A.T., Criqui, M.H., Treat-Jacobson, D., Regensteiner, J.G., Creager, M.A., Olin, J.W., Krook, S.H., Hunninghake, D.B., Comerota, A.J., Walsh, M.E., et al.: Peripheral arterial disease detection, awareness, and treatment in primary care. JAMA 286(11), 1317–1324 (2001)
6. Koh, H.C., Tan, G., et al.: Data mining applications in healthcare. Journal of Healthcare Information Management 19(2), 65 (2011)
7. Loni, B., Said, A., Larson, M., Hanjalic, A.: 'Free lunch' enhancement for collaborative filtering with factorization machines. In: Proceedings of the 8th ACM Conference on Recommender Systems. pp. 281–284. ACM (2014)
8. Mani, S., Chen, Y., Elasy, T., Clayton, W., Denny, J.: Type 2 diabetes risk forecasting from EMR data using machine learning. In: AMIA Annual Symposium Proceedings. vol. 2012, p. 606. American Medical Informatics Association (2012)
9. Rendle, S.: Factorization machines. In: Proceedings of the 2010 IEEE International
Conference on Data Mining. pp. 995–1000. ICDM ’10 (2010)
10. Rendle, S.: Factorization machines with libfm. ACM Transactions on Intelligent
Systems and Technology (TIST) 3(3), 57 (2012)
11. Rendle, S.: Scaling factorization machines to relational data. In: Proceedings of
the VLDB Endowment. vol. 6, pp. 337–348. VLDB Endowment (2013)
12. Sun, J., Reddy, C.K.: Big data analytics for healthcare. In: Proceedings of the 19th
ACM SIGKDD international conference on Knowledge discovery and data mining.
pp. 1525–1525. ACM (2013)