Heart disease detection using machine learning and the big data stack.

Rajat M
May 27, 2021


The combination of big data and machine learning is a revolutionary technology that can have a great impact on any industry if used properly. In healthcare it is useful for early disease detection, spotting early signs of epidemic outbreaks, using clustering to identify epidemic-prone regions (e.g. ‘Zika’-prone areas), and finding the zones with the best air quality in countries with high air pollution.

In this article I explore predicting the presence of heart disease using standard machine learning algorithms and big data tools such as Apache Spark, Parquet, Spark MLlib and Spark SQL.

Source Code :

The source code for this article is available on GitHub here

Dataset Used :

The Heart Disease dataset is well studied by machine learning researchers and is freely available from the UCI Machine Learning Repository here. Of the 4 datasets it contains, I used the Cleveland dataset, which has 14 main attributes. The features or attributes are:

  • age: age in years
  • sex: sex (1 = male; 0 = female)
  • cp: chest pain type
  • trestbps: resting blood pressure (in mm Hg on admission to the hospital)
  • chol: serum cholesterol in mg/dl
  • fbs: fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
  • restecg: resting electrocardiographic results
  • thalach: maximum heart rate achieved
  • exang: exercise-induced angina (1 = yes; 0 = no)
  • oldpeak: ST depression induced by exercise relative to rest
  • slope: slope of the peak exercise ST segment
  • ca: number of major vessels (0–3) colored by fluoroscopy
  • thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
  • num: diagnosis of heart disease (angiographic disease status)

Technology Used :

  • Apache Spark : Apache Spark is one of the tools in the big data stack and is essentially the big brother of the older MapReduce technology. It performs much better, largely because it can keep intermediate data in memory instead of writing it to disk between steps, and code is much easier to write in Spark than in MapReduce. The RDD (Resilient Distributed Dataset), which many developers use like an ordinary variable, is the crux of Spark; behind the scenes it handles all the distributed computing work. Spark comes with other useful packages such as Spark Streaming, Spark SQL (which I use in this article to analyse the dataset) and Spark MLlib (which I use for the machine learning part). The official Spark documentation is excellent and can be found here.
  • Spark SQL : Spark's SQL-like API, which supports DataFrames (similar to DataFrames in Python's pandas library, but running over a fully distributed dataset, so it does not offer all of the same functions).
  • Parquet : Parquet is a columnar file format. The raw data files are parsed and stored in Parquet format, which speeds up aggregation queries considerably: because the data is stored column by column, a query reads only the columns it needs, which greatly reduces disk I/O.
  • Spark MLlib : Spark's machine learning library. Its algorithms are optimized to run over a distributed dataset, which is the main difference between it and other popular libraries such as scikit-learn that run in a single process.
  • HDFS : used for storing the raw files, the generated model and the results.

Design :

1. Model Generation and Storage Layer

  • As shown in the image above, the raw files are either pulled into HDFS or pushed directly into HDFS by other programs. The data can also arrive via Kafka topics and be read with Spark Streaming. For this article and the sample code on GitHub, I assume that the raw files already reside in HDFS.
  • The files are read by a Spark program written in Java (it could equally be written in Python or Scala).
  • The data has to be adapted into the format the model requires, which is all numbers. Some data points have null or missing values; these are replaced by a large value such as ‘99.0’, which has no specific meaning and only serves to pass the null validation. The last field, ‘num’, is converted to either ‘1’ or ‘0’ depending on whether the patient has heart disease: every value greater than ‘1’ is converted to ‘1’, so any non-zero ‘num’ (1 to 4 in the raw data) becomes ‘1’, meaning heart disease is present.
  • The cleaned data is then read into an RDD.
  • For this dataset I used the Naive Bayes algorithm (the same algorithm used in spam filters). Using Spark's machine learning library (MLlib), the algorithm is trained on the data from the dataset. Note: a decision tree algorithm might also give good results here.
  • Once trained, the model is saved to external storage on HDFS so it can later be used to make predictions on the test data.
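The cleaning rule in the third step can be sketched in plain Java, without Spark, so that it can be unit tested on its own before being called from a map over the RDD. This is a minimal sketch under assumptions: each raw line holds the 14 attributes as comma-separated values in the order listed above, and a missing value shows up as `?` (as in the UCI Cleveland file) or as an empty field. The names `HeartRowCleaner`, `CleanRow` and `clean` are hypothetical, not taken from the article's repository.

```java
import java.util.Arrays;

// Hypothetical helper: turns one raw CSV line into numeric features plus a 0/1 label.
final class HeartRowCleaner {
    static final int COLUMN_COUNT = 14;       // 13 features followed by the 'num' diagnosis
    static final double MISSING_VALUE = 99.0; // placeholder that only serves to pass null validation

    // Result of cleaning one line: the feature vector and the binary label.
    static final class CleanRow {
        final double[] features;
        final double label;

        CleanRow(double[] features, double label) {
            this.features = features;
            this.label = label;
        }
    }

    static CleanRow clean(String line) {
        // The -1 limit keeps trailing empty fields, so every column is accounted for.
        String[] fields = line.split(",", -1);
        if (fields.length != COLUMN_COUNT) {
            throw new IllegalArgumentException(
                    "Expected " + COLUMN_COUNT + " fields but got " + fields.length);
        }
        double[] values = new double[COLUMN_COUNT];
        for (int i = 0; i < COLUMN_COUNT; i++) {
            String field = fields[i].trim();
            // Missing values ('?' or empty) become the placeholder so the model sees only numbers.
            values[i] = (field.isEmpty() || field.equals("?"))
                    ? MISSING_VALUE
                    : Double.parseDouble(field);
        }
        // Any diagnosis of 1 or more means heart disease is present.
        double label = values[COLUMN_COUNT - 1] >= 1.0 ? 1.0 : 0.0;
        // The first 13 columns are the features; the label column is dropped from them.
        double[] features = Arrays.copyOf(values, COLUMN_COUNT - 1);
        return new CleanRow(features, label);
    }
}
```

In the Spark job, something like `rawLines.map(line -> { HeartRowCleaner.CleanRow r = HeartRowCleaner.clean(line); return new LabeledPoint(r.label, Vectors.dense(r.features)); })` would then produce the MLlib `LabeledPoint` RDD used for training; keeping the parsing in a plain class means the same rule can be reused unchanged for the test data.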

2. Data Analysis Layer

This layer is used to analyse the training data with queries such as the minimum age of a person with the disease, the number of women versus men with the disease, which attribute is almost always present when the disease occurs, and the number of people who show symptoms but do not have the disease.

  • The code for the analysis layer is also bundled in the GitHub repository here.
  • To run data analysis on the training data, first load the full (cleaned) data into an RDD using textFile.
  • Convert this RDD into a DataFrame and save it to external storage in Parquet format.
  • From another program, load the data from this Parquet storage into a DataFrame.
  • Use Spark SQL to run queries on the DataFrame for general analytics. A sample Spark SQL query is shown below.

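import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Spark 2.x+ API; assumes an existing SparkSession named spark and a placeholder Parquet path
Dataset<Row> df = spark.read().parquet("<Parquet Storage Location>"); // built from the training data
df.createOrReplaceTempView("heartDisData");
Dataset<Row> results = spark.sql("select min(age) from heartDisData");
results.show();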
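On a small sample it can help to cross-check the Spark SQL answers locally. Below is a minimal plain-Java sketch (no Spark) of two of the queries mentioned for this layer: the minimum age of a person with the disease, and the number of women versus men with the disease. It assumes rows have already been cleaned into the hypothetical `PatientRow` class (sex: 1 = male, 0 = female; label 1 = disease present), and the SQL in the comments assumes `num` was stored already converted to 0/1.

```java
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.Collectors;

// Hypothetical cleaned row holding only the fields these queries need.
final class PatientRow {
    final double age;
    final double sex;   // 1 = male, 0 = female
    final double label; // 1 = heart disease present, 0 = absent

    PatientRow(double age, double sex, double label) {
        this.age = age;
        this.sex = sex;
        this.label = label;
    }
}

final class LocalAnalysis {
    // Local equivalent of: select min(age) from heartDisData where num = 1
    static OptionalDouble minAgeWithDisease(List<PatientRow> rows) {
        return rows.stream()
                .filter(r -> r.label == 1.0)
                .mapToDouble(r -> r.age)
                .min();
    }

    // Local equivalent of: select sex, count(*) from heartDisData where num = 1 group by sex
    static Map<Double, Long> diseaseCountBySex(List<PatientRow> rows) {
        return rows.stream()
                .filter(r -> r.label == 1.0)
                .collect(Collectors.groupingBy(r -> r.sex, Collectors.counting()));
    }
}
```

The empty `OptionalDouble` case mirrors SQL returning NULL for `min` when no row matches the filter.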

3. Disease Prediction Layer (refer to the code on GitHub)

  • Copy the test data into HDFS.
  • Load the test data into an RDD using Apache Spark.
  • Clean and adapt the test data to the format the model expects, using the same transformations as for the training data.
  • Load the model from storage using Spark MLlib, e.g. NaiveBayesModel _model = NaiveBayesModel.load(<Spark Context>, <Model Storage Location>);
  • Use the model object to predict the presence of disease, e.g. JavaRDD<Double> predictedResults = _model.predict(fRdd); where fRdd is the JavaRDD<Vector> of cleaned test feature vectors.
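Behind the training and `predict` calls, multinomial Naive Bayes (the default model type of Spark MLlib's `NaiveBayes`) comes down to counting: a smoothed log prior per class, smoothed log weights per feature and class, and a prediction that picks the class with the highest score. The sketch below is a single-machine, plain-Java illustration of that textbook computation with additive smoothing `lambda`; it is not MLlib's implementation, and `SimpleNaiveBayes` is a hypothetical name. Multinomial Naive Bayes requires non-negative features, which the sketch checks.

```java
// Hypothetical single-machine multinomial Naive Bayes for labels 0 and 1.
final class SimpleNaiveBayes {
    private final double[] logPrior;   // log P(class)
    private final double[][] logTheta; // log P(feature j | class), one row per class

    private SimpleNaiveBayes(double[] logPrior, double[][] logTheta) {
        this.logPrior = logPrior;
        this.logTheta = logTheta;
    }

    // Trains with additive smoothing lambda; labels must be 0 or 1.
    static SimpleNaiveBayes train(double[][] features, double[] labels, double lambda) {
        int numClasses = 2;
        int numFeatures = features[0].length;
        double[] classCounts = new double[numClasses];
        double[][] featureSums = new double[numClasses][numFeatures];

        // Count examples per class and sum each feature's values per class.
        for (int i = 0; i < features.length; i++) {
            int c = (int) labels[i];
            classCounts[c]++;
            for (int j = 0; j < numFeatures; j++) {
                if (features[i][j] < 0) {
                    throw new IllegalArgumentException("Multinomial Naive Bayes needs non-negative features");
                }
                featureSums[c][j] += features[i][j];
            }
        }

        // Turn the counts into smoothed log probabilities.
        double[] logPrior = new double[numClasses];
        double[][] logTheta = new double[numClasses][numFeatures];
        double priorDenom = Math.log(features.length + numClasses * lambda);
        for (int c = 0; c < numClasses; c++) {
            logPrior[c] = Math.log(classCounts[c] + lambda) - priorDenom;
            double total = 0;
            for (double s : featureSums[c]) {
                total += s;
            }
            double thetaDenom = Math.log(total + numFeatures * lambda);
            for (int j = 0; j < numFeatures; j++) {
                logTheta[c][j] = Math.log(featureSums[c][j] + lambda) - thetaDenom;
            }
        }
        return new SimpleNaiveBayes(logPrior, logTheta);
    }

    // Returns the class (0.0 or 1.0) maximizing log prior + sum of x_j * log theta_cj.
    double predict(double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < logPrior.length; c++) {
            double score = logPrior[c];
            for (int j = 0; j < x.length; j++) {
                score += x[j] * logTheta[c][j];
            }
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return best;
    }
}
```

With MLlib itself, the training step from the design section would be along the lines of `NaiveBayesModel model = NaiveBayes.train(trainingData.rdd(), 1.0);` followed by `model.save(<Spark Context>, <Model Storage Location>);`, where `trainingData` is assumed to be a `JavaRDD<LabeledPoint>` of cleaned rows; MLlib distributes the same counting across the cluster.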

Problem with the Above Design :

The most important issue with any disease prediction system is accuracy. A false negative is especially dangerous, because it lets a disease go unnoticed.
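One way to make the false-negative concern measurable is to compare the predictions with the known labels of a held-out test set and report recall (sensitivity) alongside plain accuracy. A minimal plain-Java sketch, assuming predictions and labels are both 0/1 arrays of equal length (for example, `predictedResults` collected to the driver next to the test labels); `ConfusionCounts` is a hypothetical name.

```java
// Hypothetical helper: tallies binary prediction outcomes (1 = disease present).
final class ConfusionCounts {
    long truePositives, falsePositives, trueNegatives, falseNegatives;

    static ConfusionCounts of(double[] predicted, double[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException("predicted and actual must have the same length");
        }
        ConfusionCounts c = new ConfusionCounts();
        for (int i = 0; i < predicted.length; i++) {
            boolean predictedSick = predicted[i] == 1.0;
            boolean actuallySick = actual[i] == 1.0;
            if (predictedSick && actuallySick) {
                c.truePositives++;
            } else if (predictedSick) {
                c.falsePositives++;
            } else if (actuallySick) {
                c.falseNegatives++; // disease missed: the dangerous case
            } else {
                c.trueNegatives++;
            }
        }
        return c;
    }

    // Fraction of all predictions that were correct.
    double accuracy() {
        long total = truePositives + falsePositives + trueNegatives + falseNegatives;
        return total == 0 ? 0.0 : (double) (truePositives + trueNegatives) / total;
    }

    // Fraction of real disease cases the model caught; 1 - recall is the false-negative rate.
    double recall() {
        long positives = truePositives + falseNegatives;
        return positives == 0 ? 0.0 : (double) truePositives / positives;
    }
}
```

A model can score well on accuracy while still missing many sick patients, which is why recall is the number to watch here. On a cluster, Spark MLlib's `MulticlassMetrics` class in `org.apache.spark.mllib.evaluation` computes the same confusion matrix from an RDD of (prediction, label) pairs.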

Deep learning can often give better predictions than traditional machine learning algorithms. In a future article I will explore doing the same disease prediction with deep learning neural networks.

Summary :

Using tools like Apache Spark and its machine learning library, we were easily able to load a heart disease dataset (from UCI) and train a standard machine learning model. This model was then used to predict the presence of heart disease on test samples.

The source code for this is available on GitHub here

Written by Rajat M

Software Engineer | Engineering Manager | MicroServices | Java | Machine Learning | NoSQL | Distributed Systems and more
