How to Build a Classification Model Using Apache Spark

By NIIT Editorial

Published on 18/08/2021

8 minutes

Apache Spark is an open-source, distributed processing system used for big data workloads. It uses in-memory caching and optimized query execution to run fast queries against data of any size. In other words, Spark is a fast, general-purpose engine for large-scale data processing.

Spark’s library for machine learning is known as MLlib (Machine Learning library). Its pipeline API is based on scikit-learn’s ideas on pipelines. The basic concepts needed to create an ML model with this library are:

  • DataFrame: This ML API makes use of Spark SQL's DataFrame as an ML dataset that can hold a range of data types. E.g., a DataFrame can have different columns storing text, feature vectors, true labels, and predictions.
  • Transformer: A Transformer is an algorithm for transforming one DataFrame into another.
  • Estimator: An Estimator is an algorithm that can be fit on a DataFrame to produce a Transformer. A learning algorithm, for example, is an Estimator that trains on a DataFrame and produces a model.
  • Pipeline: To specify an ML workflow, a Pipeline chains multiple Transformers and Estimators together.
  • Parameter: All Transformers and Estimators share a common API for specifying parameters.

Load and Analyze data

To analyze the data, we first have to load it into Spark. Download the train.csv file, open it, and check its contents.

$ head train_data.csv

You will notice that the file contains a header row with the fields PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin and Embarked. You can find more information about each of these fields on the Kaggle website. Transfer this file to a folder in HDFS (I have kept mine at /kaggle/titanic/train.csv). The data is in csv format, so we'll use the spark-csv library to load it.

We will define a simple load function that can be used to load csv files. First, start your spark-shell with the spark-csv package available.


Note: You will have to import some classes for this project.


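import org.apache.spark.sql.{DataFrame, SQLContext}

def loadData(file_path: String, sql_data: SQLContext, data_features: String*): DataFrame = {
    val source = "com.databricks.spark.csv"
    val dataValues = sql_data.read
      .format(source)
      .option("header", "true")
      .option("inferSchema", "true")
      .load(file_path)
    // Rename columns only when names are supplied; otherwise keep the header names
    if (data_features.isEmpty) dataValues else dataValues.toDF(data_features: _*)
}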


The function takes three parameters: the path to the csv file, an SQLContext (sql_data), and data_features, which supplies the column names to use as the data is loaded. We don’t strictly have to pass data_features in this case because our csv file includes header information. Without a header, the columns would be given default names such as C0, C1, etc.

Use the loadData function defined above to load the csv file and create a DataFrame.

val file_path = "train_data.csv"
// toDF renames columns by position, so the names follow the header order
var all_data = loadData(file_path, sql_data, "PassengerId", "Survived", "Pclass", "Name", "Sex", "Age", "SibSp", "Parch", "Ticket", "Fare", "Cabin", "Embarked").cache()

Note: We keep the DataFrame in memory by calling cache(), which improves performance during model building.

Now we will examine the loaded DataFrame to understand the data better. We can inspect the schema of the loaded data by calling printSchema():

all_data.printSchema()

The spark-csv library has inferred the data type of each column. If we go back and check the load function, we can see that we used .option("inferSchema", "true"), which instructs the library to do so. Without it, all the fields would be set to type string. The show() method of DataFrame displays the DataFrame in tabular form. We can also pass an int to this method to specify how many rows are to be displayed. E.g.,

all_data.show(10)

To see statistics for any numerical column, use dataFrame.describe("column"). E.g.,

all_data.describe("Age").show()



Fill missing values

On examining the data, we find some irregularities. For instance, a few values are missing in the Age column. Likewise, there are null/missing values in Cabin, Fare and Embarked. There are several techniques for filling in missing values.

  1. Ignore/drop the rows that have missing values. In Spark this can be done as follows:

// returns a new DataFrame without the rows that contain nulls
val data_without_nulls = all_data.na.drop()

  2. If the column is numeric, we can fill in the missing values with the mean/avg value of the column. We will replace the missing values in the Age column using the following method.

import org.apache.spark.sql.functions.mean

// select(mean("Age")) returns a single-row DataFrame holding the average age
val average_age = all_data.select(mean("Age")).first()(0).asInstanceOf[Double]
// fill missing values with the average
all_data = all_data.na.fill(average_age, Seq("Age"))

  3. If the column is categorical, we can fill it with the most frequently occurring category.

// handling categorical data
all_data = all_data.na.fill("M", Seq("Sex"))

Note: We are not going to use the above in this example. “Embarked” is also a categorical column.

  4. Build a machine learning model that can predict the missing values.
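The mean and most-frequent-category strategies above can be illustrated without Spark. The following is a minimal plain-Scala sketch, not Spark's implementation; the object and function names (ImputeSketch, fillWithMean, fillWithMode) are hypothetical, and missing values are modelled as None.

```scala
object ImputeSketch {
  // Replace missing numeric values (None) with the mean of the values that are present.
  def fillWithMean(values: Seq[Option[Double]]): Seq[Double] = {
    val present = values.flatten
    require(present.nonEmpty, "need at least one non-missing value")
    val mean = present.sum / present.size
    values.map(_.getOrElse(mean))
  }

  // Replace missing categorical values (None) with the most frequent category.
  // If several categories are tied, which one wins is unspecified in this sketch.
  def fillWithMode(values: Seq[Option[String]]): Seq[String] = {
    val present = values.flatten
    require(present.nonEmpty, "need at least one non-missing value")
    val mode = present.groupBy(identity).maxBy(_._2.size)._1
    values.map(_.getOrElse(mode))
  }
}
```

For example, ImputeSketch.fillWithMean(Seq(Some(20.0), None, Some(40.0))) fills the gap with 30.0, the mean of the two known ages.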

Discover new features

In many situations, the input data contains features that can be used to derive new features, which helps build a better model. This is known as Feature Engineering. For instance, we can use the age in each row to derive a new category. The UDF check_infant derives a group from a passenger's age and gender.

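val check_infant = sql_data.udf.register("check_infant", (Age: Double, gender: String) => {
    // Both branch values are missing from the source; the "Infant" label is an assumption
    if (Age > 5) gender else "Infant"
})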

DataFrame provides a method, withColumn, that can be used to add a new column or replace an existing one. It accepts two parameters: the new column's name and a Column of the current DataFrame.

val test_column_add = all_data.withColumn("test_add_column", all_data("Age") - 10)
// See the result: test_add_column holds each Age value minus 10
test_column_add.select("Age", "test_add_column").show(10)


We will now apply the function check_infant to the Age and Sex columns.


all_data = all_data.withColumn("Sex", check_infant(all_data("Age"), all_data("Sex")))
all_data.select("Age", "Sex").show()


Similarly, we will define other UDFs to generate new features.

// Integer to double
val double_val = sql_data.udf.register("double_val", (value: Int) => value.toDouble)

// create new column with family if there are more than 5 members
val check_family = sql_data.udf.register("check_family", (Parents: Int, Siblings: Int) => {
    // Branch values are missing from the source; the 0.0/1.0 flags are an assumption
    if (Siblings + Parents < 5) 0.0 else 1.0
})

// adding columns to dataframe
all_data = all_data.withColumn("Age", double_val(all_data("Age")))
all_data = all_data.withColumn("Survived", double_val(all_data("Survived")))
all_data = all_data.withColumn("With_Family", check_family(all_data("Parch"), all_data("SibSp")))
// check new columns
all_data.select("Age", "Survived", "With_Family").show(10)


Pipeline Components

An ML pipeline consists of a series of pipeline components. There are two kinds of components: Transformers and Estimators. A Transformer converts the input DataFrame into a new DataFrame using the method transform(). An Estimator first fits a model to the data using the method fit(), and the resulting model then performs the transform. Let's explore this concept through the following components.


The features in a Spark model should all be of type Double, but a handful of ours are of type String. Spark provides a feature transformer, StringIndexer, that can be used for this conversion.

import org.apache.spark.ml.feature.StringIndexer
val Name_Rank = new StringIndexer().setInputCol("Name").setOutputCol("Name_Rank")

StringIndexer is an Estimator that takes the column Name, generates indices for its values, and creates a new column named Name_Rank. The fit() method of StringIndexer converts the column to StringType (if it isn't already) and then counts the number of times each value appears. It then sorts the values in descending order of frequency and assigns an index to each value. The method returns a Transformer, StringIndexerModel.

// Executing fit and transform
val n_model = Name_Rank.fit(all_data)
// See the rank of each name in the dataset
n_model.transform(all_data).select("Name", "Name_Rank").show(10)

.transform() assigns the generated index to each value of the column in the given DataFrame.
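To make the frequency-based indexing concrete, here is a small plain-Scala sketch of the idea: count, sort by descending frequency, assign indices. It is an illustration only, not Spark's StringIndexer code. The object name StringIndexSketch is hypothetical, and breaking ties alphabetically is an assumption of this sketch.

```scala
object StringIndexSketch {
  // "fit": count occurrences, sort by descending frequency, assign indices 0.0, 1.0, ...
  // Ties are broken alphabetically here; that tie-break is an assumption of this sketch.
  def fit(values: Seq[String]): Map[String, Double] =
    values.groupBy(identity).toSeq
      .map { case (value, occurrences) => (value, occurrences.size) }
      .sortBy { case (value, count) => (-count, value) }
      .zipWithIndex
      .map { case ((value, _), idx) => value -> idx.toDouble }
      .toMap

  // "transform": replace each value with its learned index.
  // A value that was not seen during fit raises NoSuchElementException.
  def transform(index: Map[String, Double], values: Seq[String]): Seq[Double] =
    values.map(index)
}
```

The most frequent value receives index 0.0, the next most frequent 1.0, and so on, which mirrors the behaviour described above.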

Note that we do not need to call fit() or transform() ourselves again; the Pipeline takes care of that. The pipeline executes each stage and passes the result of the current stage on to the next. If a stage is a Transformer, the Pipeline calls transform() on it; if it is an Estimator, the Pipeline first calls fit() and then transform(). transform() is not called if the Estimator is the last stage in the pipeline.
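The stage-by-stage behaviour described above can be sketched with a toy pipeline in plain Scala. This is a simplified illustration under stated assumptions, not Spark's Pipeline class: rows are modelled as Map[String, Double], and every name here (PipelineSketch, Scale, Center, fitPipeline, applyAll) is hypothetical.

```scala
object PipelineSketch {
  type Row = Map[String, Double]
  type Data = Seq[Row]

  trait Stage
  trait Transformer extends Stage { def transform(data: Data): Data }
  trait Estimator extends Stage { def fit(data: Data): Transformer }

  // Transformer: adds column `out` = column `in` * factor.
  final case class Scale(in: String, out: String, factor: Double) extends Transformer {
    def transform(data: Data): Data = data.map(r => r + (out -> r(in) * factor))
  }

  // Estimator: learns the mean of `in` and yields a Transformer that subtracts it.
  final case class Center(in: String, out: String) extends Estimator {
    def fit(data: Data): Transformer = {
      val mean = data.map(_(in)).sum / data.size
      new Transformer {
        def transform(d: Data): Data = d.map(r => r + (out -> (r(in) - mean)))
      }
    }
  }

  // Fit the stages in order, feeding each stage's output to the next one.
  def fitPipeline(stages: Seq[Stage], data: Data): Seq[Transformer] = {
    var current = data
    stages.zipWithIndex.map { case (stage, i) =>
      val fitted = stage match {
        case e: Estimator   => e.fit(current)
        case t: Transformer => t
      }
      // While fitting, the output of the final stage is not needed, so skip its transform
      if (i < stages.size - 1) current = fitted.transform(current)
      fitted
    }
  }

  // Apply an already fitted pipeline to new data.
  def applyAll(model: Seq[Transformer], data: Data): Data =
    model.foldLeft(data)((d, t) => t.transform(d))
}
```

Fitting returns only Transformers, which is why the fitted pipeline can later be applied to unseen data without refitting.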

Binning / Bucketing

During binning/bucketing, a column with continuous values is converted into buckets. When building the Bucketizer, which is a Transformer, we define the start and end values of each bucket. We are going to bucketize the column ‘Age’.

// setSplits expects an Array[Double]; the last bucket is open-ended
val splits = Array(0.0, 5.0, 15.0, 25.0, 30.0, 40.0, Double.PositiveInfinity)
val Age_Bracket = new Bucketizer().setInputCol("Age").setOutputCol("Age_Bracket").setSplits(splits)
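As an illustration of how split points map a value to a bucket index, here is a plain-Scala sketch. It follows the documented Bucketizer convention that a bucket between splits x and y holds values in [x, y), with the last bucket also including its upper bound. The object and function names (BucketSketch, bucketOf) are hypothetical, and this is not Spark's implementation.

```scala
object BucketSketch {
  // Bucket i covers [splits(i), splits(i + 1)); the last bucket also includes its upper bound.
  // Splits are assumed to be strictly increasing.
  def bucketOf(splits: Array[Double], value: Double): Double = {
    require(splits.length >= 2, "need at least one bucket")
    require(value >= splits.head && value <= splits.last, s"$value is outside the splits")
    if (value == splits.last) (splits.length - 2).toDouble
    else (splits.indexWhere(_ > value) - 1).toDouble
  }
}
```

With the splits used above, an age of 4 falls into bucket 0.0, an age of 5 into bucket 1.0, and an age of 80 into the open-ended bucket 5.0.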

Vector Assembler

VectorAssembler is used to gather features into a single vector. We pass all of the columns needed for the prediction to the VectorAssembler, which creates a new vector column.

val assemble_model = new VectorAssembler().setInputCols(Array("Name_Rank", "Fare", "Age_Rank","With_Family","Age_Bracket")).setOutputCol("Factors")



Next, we will normalize the data using the Transformer Normalizer. The Normalizer takes the column created by the VectorAssembler, normalizes it, and generates a new column.


val normalizer_model = new Normalizer().setInputCol("Factors").setOutputCol("Normalized_Factors")
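Spark's Normalizer scales each feature vector to unit p-norm, with p = 2 by default. The plain-Scala sketch below shows that computation for a single vector. The names NormalizeSketch and l2Normalize are hypothetical, and returning an all-zero vector unchanged is an assumption of this sketch.

```scala
object NormalizeSketch {
  // Scale a feature vector to unit L2 norm.
  // An all-zero vector is returned unchanged (an assumption of this sketch).
  def l2Normalize(v: Array[Double]): Array[Double] = {
    val norm = math.sqrt(v.map(x => x * x).sum)
    if (norm == 0.0) v.clone() else v.map(_ / norm)
  }
}
```

For instance, the vector (3.0, 4.0) has L2 norm 5.0, so it is scaled to roughly (0.6, 0.8).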


Building and Evaluating Model

We will build our model using the LogisticRegression algorithm, which is used for classification. The variable being predicted is called the dependent variable, and the variables that determine its value are called independent variables.

Based on the values of the independent variables, logistic regression predicts the probability that the dependent variable takes one of its categorical values (classes). In our case there are two possible classes, 0 and 1. To create a LogisticRegression component:

// Point the estimator at our label and feature columns (its defaults are "label" and "features")
val train_model = new LogisticRegression().setMaxIter(20).setLabelCol("Survived").setFeaturesCol("Normalized_Factors")
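How a fitted logistic regression turns feature values into a class probability can be sketched in a few lines of plain Scala. This shows only the prediction formula, not the training procedure or Spark's implementation; the names (LogisticSketch, sigmoid, probability, predict) and the 0.5 default threshold are assumptions of this sketch.

```scala
object LogisticSketch {
  // Logistic (sigmoid) function: maps any real number into the interval (0, 1).
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Probability of class 1 given learned weights, an intercept and a feature vector.
  def probability(weights: Array[Double], intercept: Double, features: Array[Double]): Double = {
    require(weights.length == features.length, "dimension mismatch")
    val z = weights.zip(features).map { case (w, x) => w * x }.sum + intercept
    sigmoid(z)
  }

  // Predict class 1.0 when the probability exceeds the threshold, otherwise 0.0.
  def predict(p: Double, threshold: Double = 0.5): Double =
    if (p > threshold) 1.0 else 0.0
}
```

The independent variables enter through the weighted sum z, and the sigmoid converts that sum into the probability of the dependent variable being 1.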

Create Pipeline

Using all the components defined so far, create a pipeline object. As mentioned above, a pipeline has a set of stages, and each component we add is a stage in the pipeline. The pipeline will execute the stages one after another, first calling fit() (if the stage is an Estimator) and then passing the result of transform() on to the next stage.

val pip = new Pipeline().setStages(Array(Age_Rank, Name_Rank, Age_Bracket, assemble_model, normalizer_model,train_model))

Training set & Test set

To evaluate the model, we will split our data into two parts: a training set (80%) and a test set (20%). We will build the model using the training set and evaluate it using the test set, using the area under the ROC curve to judge how good the model is. To split the input data:


val split_values = all_data.randomSplit(Array(0.8, 0.2))
// test data
val test_data = split_values(1).cache()
// train data
val train_data = split_values(0).cache()

The pipeline will now be fitted to our training data. The result of fitting the pipeline is a PipelineModel object, which can be used to make predictions on the test data.

var final_model = pip.fit(train_data)
var predictions = final_model.transform(test_data)


Note that the model object here is an instance of PipelineModel, not LogisticRegression. This is because LogisticRegression is only one of the components of our PipelineModel. When predictions are made on a data set, the data first goes through all of the transformations performed by the other components in the Pipeline before being used by the LogisticRegression component for prediction.

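To compute how well the model did, select the columns ‘prediction’ and ‘Survived’ from the result, create an RDD of (Double, Double) pairs, and pass it on to BinaryClassificationMetrics.


import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
predictions = predictions.select("prediction", "Survived")
// BinaryClassificationMetrics expects (score, label) pairs
val labels = predictions.rdd.map(row => (row.getDouble(0), row.getDouble(1)))
val ROC = new BinaryClassificationMetrics(labels).areaUnderROC()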
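For intuition about what the area under the ROC curve measures, the plain-Scala sketch below computes it from (score, label) pairs using the pairwise-ranking definition: the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counted as half. This is an illustration under those assumptions, not how BinaryClassificationMetrics computes it internally; the names AucSketch and areaUnderROC are hypothetical.

```scala
object AucSketch {
  // Area under the ROC curve as the fraction of (positive, negative) pairs in which
  // the positive example gets the higher score; tied scores count as half.
  // Labels are expected to be exactly 1.0 (positive) or 0.0 (negative).
  def areaUnderROC(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val pos = scoreAndLabels.collect { case (score, label) if label == 1.0 => score }
    val neg = scoreAndLabels.collect { case (score, label) if label == 0.0 => score }
    require(pos.nonEmpty && neg.nonEmpty, "need at least one example of each class")
    val wins = for (p <- pos; n <- neg) yield {
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    }
    wins.sum / (pos.size.toLong * neg.size)
  }
}
```

A value of 1.0 means every survivor was scored above every non-survivor, while 0.5 is what random guessing would achieve. The pairwise loop is quadratic, so this sketch is only suitable for small samples.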


The prediction we just performed was on input data for which we already knew the actual classification. We split the data into training and test sets so that we could compare the actual results with the predicted results when evaluating the model. Now we can use the whole input data to train the model again.

val model = pip.fit(all_data)

Doing the Prediction

We can download test.csv from Kaggle and put it in HDFS. The test data (submission data) has to go through all the loading and pre-processing steps applied to the training data, with the additional step of adding the column ‘Survived’, because test.csv does not contain it. Loading and pre-processing the test data is the same as above.

To make the prediction, use the PipelineModel object that was created during model construction.

val prediction = model.transform(test_data)

Let's see what our model predicted for the first three passengers in the test data.

prediction.select("PassengerId", "prediction").show(10)

Passengers with IDs 892 and 894 are not likely to survive, whereas Passenger 893 is expected to live.


The Apache Spark machine learning library (MLlib) enables data scientists to concentrate on their data problems and models rather than on the complexities surrounding distributed data (such as infrastructure, configurations, and so on). Apache Spark is known as a simple, fast, and easy-to-use big data processing engine with built-in streaming, SQL, Machine Learning (ML), and graph processing components. It is an in-demand skill for data engineers, and data scientists can also benefit from learning Spark for Exploratory Data Analysis (EDA), feature extraction and, of course, ML. You can gain better insight into the concepts of Machine Learning by opting for the Advanced PGP in Data Science and Machine Learning (Full Time) or Advanced PGP in Data Science and Machine Learning (Part Time) courses from NIIT.
