Please Note: This notebook was originally created in Databricks, then exported as an ".ipynb" file for publishing here. You can import this notebook into Databricks via a link to the file on GitHub, or by opening the publicly published version of the notebook on Databricks. Both links are presented at the bottom of this notebook. Sign up for the free Community Edition of Databricks, and learn more about Databricks Community Edition by reading the FAQ.

pyspark_linear_regression

Linear Regression in PySpark

This is a very basic introduction to building a linear regression model on Spark using Python.

Here are the reference docs for linear regression in PySpark.

In [2]:
import numpy as np

# Generate a 2-D sample of correlated data from a multivariate *normal*
# distribution (not uniform). Column 0 is centred on the mean of `xx`,
# column 1 on the mean of `yy`, and the two columns have correlation `corr`.
# Source: https://stackoverflow.com/a/18684433/5356898

SEED = 42          # fixed seed so Restart & Run All reproduces the same data
N_SAMPLES = 1000   # number of rows to draw

xx = np.array([-0.51, 51.2])
yy = np.array([0.33, 51.6])
means = [xx.mean(), yy.mean()]
stds = [xx.std() / 3, yy.std() / 3]
corr = 0.8  # target correlation between the two columns

# Covariance matrix assembled from the per-column standard deviations and `corr`.
covs = [[stds[0] ** 2,             stds[0] * stds[1] * corr],
        [stds[0] * stds[1] * corr, stds[1] ** 2]]

rng = np.random.default_rng(SEED)
data = rng.multivariate_normal(means, covs, N_SAMPLES)
In [3]:
# (rows, columns) of the generated sample.
np.shape(data)
In [4]:
# Preview only the first few rows; the full shape is shown in the previous cell.
data[:5]
In [5]:
# Distribute the NumPy rows with the SparkContext `sc`, convert each entry to a
# built-in Python float before Spark infers the schema, and name the columns:
# column 0 becomes the label "y", column 1 the feature "x".
rdd1 = sc.parallelize(data)
rdd2 = rdd1.map(lambda row: list(map(float, row)))
df = rdd2.toDF(["y", "x"])
In [6]:
# Render the Spark DataFrame using the notebook's `display` helper (not imported
# here; presumably the Databricks built-in, which can also plot the data).
display(df)
In [7]:
from pyspark.ml.regression import LinearRegression, LinearRegressionSummary
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline

# Fit y ~ x with a two-stage pipeline (assemble features, then regress) and
# report RMSE and R-squared on a held-out test split.

SPLIT_SEED = 42  # fixed so the train/test split, and hence the metrics, are reproducible

# Spark ML estimators read features from a single vector column.
assembler = VectorAssembler(inputCols=["x"], outputCol="features")

lr = LinearRegression(featuresCol="features", labelCol="y")

pipeline = Pipeline(stages=[assembler, lr])

train, test = df.randomSplit([0.75, 0.25], seed=SPLIT_SEED)

model = pipeline.fit(train)

predictions = model.transform(test)

# Named `evaluator` (not `eval`) so the Python built-in `eval` is not shadowed.
# Run `print(evaluator.explainParams())` to list the available params and metrics.
evaluator = RegressionEvaluator(labelCol="y", predictionCol="prediction")

# The param-map argument overrides metricName for that single evaluate() call.
print('RMSE:', evaluator.evaluate(predictions, {evaluator.metricName: "rmse"}))
print('R-squared:', evaluator.evaluate(predictions, {evaluator.metricName: "r2"}))