Testing in Apache Spark Structured Streaming


Currently, there are not many sample tests for Spark Structured Streaming based applications. Therefore, this article provides basic test examples with detailed descriptions.

All examples use: Apache Spark 3.0.1.


You need to install:

  • Apache Spark 3.0.x
  • Python 3.7 and the virtual environment for it
  • Conda 4.y
  • scikit-learn 0.22.z
  • Maven 3.v
  • The Scala examples use version 2.12.10.

  1. Download Apache Spark
  2. Unpack: tar -xvzf ./spark-3.0.1-bin-hadoop2.7.tgz
  3. Create an environment, for example with conda: conda create -n sp python=3.7

You need to set up environment variables. Here's an example for running locally.



An example with scikit-learn

When writing tests, you need to structure your code so that the logic is isolated from the actual use of the final API. Good examples of such isolation: DataFrame-pandas, DataFrame-spark.

The following example will be used to write tests: LinearRegression .

So let's test the code using the following Python "template":

class XService:
    """Skeleton of a service under test: construct it, train it, then predict."""

    def __init__(self):
        """Create whatever the concrete service needs (left empty in the template)."""

    def train(self, ds):
        """Train the service on the dataset ``ds`` (left empty in the template)."""

    def predict(self, ds):
        """Produce predictions for the dataset ``ds`` (left empty in the template)."""


For Scala, the template looks like this.

Complete example:

from sklearn import linear_model

class LocalService:
    """Wraps a scikit-learn ``LinearRegression`` behind a train/predict interface."""

    def __init__(self):
        self.model = linear_model.LinearRegression()

    def train(self, ds):
        """Fit the model; ``ds`` is a pair ``(X, y)`` of features and targets."""
        X, y = ds
        self.model.fit(X, y)

    def predict(self, ds):
        """Return the model's predictions for the feature rows in ``ds``."""
        # The result used to be computed and then dropped, which left callers
        # (and tests) with nothing to inspect or assert on.
        r = self.model.predict(ds)
        return r




import unittest
import numpy as np

Main class:

class RunTest(unittest.TestCase):

Running tests:

if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()

Data preparation:

# Four training rows; the trailing comment on each row is its target value.
X = np.array([
    [1, 1],  # 6
    [1, 2],  # 8
    [2, 2],  # 9
    [2, 3]  # 11
])
y = np.dot(X, np.array([1, 2])) + 3  # [ 6  8  9 11], y = 1 * x_0 + 2 * x_1 + 3

Model creation and training:

service = local_service.LocalService()
service.train((X, y))

Getting results:

service.predict(np.array([[3, 5]]))
service.predict(np.array([[4, 6]]))




import unittest
import numpy as np

from spark_streaming_pp import local_service

class RunTest(unittest.TestCase):
    """Trains ``LocalService`` on a tiny exact linear data set and runs two predictions."""

    def test_run(self):
        # Prepare data.
        X = np.array([
            [1, 1],  # 6
            [1, 2],  # 8
            [2, 2],  # 9
            [2, 3]  # 11
        ])
        y = np.dot(X, np.array([1, 2])) + 3  # [ 6  8  9 11], y = 1 * x_0 + 2 * x_1 + 3
        # Create model and train.
        service = local_service.LocalService()
        service.train((X, y))
        # Predict and results.
        service.predict(np.array([[3, 5]]))
        service.predict(np.array([[4, 6]]))
        # [16.]
        # [19.]


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()

Example with Spark and Python

self.service = LinearRegression(maxIter=10, regParam=0.01)
self.model = None


self.model = self.service.fit(ds)


transformed_ds = self.model.transform(ds)
q = transformed_ds.select("label", "prediction").writeStream.format("console").start()
return q


from pyspark.ml.regression import LinearRegression

class StructuredStreamingService:
    """Linear-regression service: fit on one dataset, then score a streaming one."""

    def __init__(self):
        # No fitted model exists until train() has been called.
        self.model = None
        self.service = LinearRegression(regParam=0.01, maxIter=10)

    def train(self, ds):
        """Fit the estimator on ``ds`` and keep the resulting model."""
        estimator = self.service
        self.model = estimator.fit(ds)

    def predict(self, ds):
        """Score ``ds`` with the fitted model and start a console-sink query.

        Only the "label" and "prediction" columns are written. Returns
        whatever ``start()`` returns, i.e. the handle of the started query.
        """
        scored = self.model.transform(ds)
        console_writer = scored.select("label", "prediction").writeStream.format("console")
        return console_writer.start()



# Training rows as (label, features) pairs; the second argument names the columns.
train_ds = spark.createDataFrame([
    (6.0, Vectors.dense([1.0, 1.0])),
    (8.0, Vectors.dense([1.0, 2.0])),
    (9.0, Vectors.dense([2.0, 2.0])),
    (11.0, Vectors.dense([2.0, 3.0]))
],
    ["label", "features"]
)


def test_stream_read_options_overwrite(self):
    """Arguments passed to ``load()`` must override settings made earlier on the reader."""
    bad_schema = StructType([StructField("test", IntegerType(), False)])
    schema = StructType([StructField("data", StringType(), False)])
    # The reader is first given a wrong format, path and schema; bad_schema was
    # previously built but never applied, so the schema overwrite went untested.
    df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
        .schema(bad_schema) \
        .load(path='python/test_support/sql/streaming', schema=schema, format='text')
    self.assertTrue(df.isStreaming)
    self.assertEqual(df.schema.simpleString(), "struct<data:string>")



spark = SparkSession.builder.enableHiveSupport().getOrCreate()

# Training rows as (label, features) pairs; the second argument names the columns.
train_ds = spark.createDataFrame([
    (6.0, Vectors.dense([1.0, 1.0])),
    (8.0, Vectors.dense([1.0, 2.0])),
    (9.0, Vectors.dense([2.0, 2.0])),
    (11.0, Vectors.dense([2.0, 3.0]))
],
    ["label", "features"]
)


service = structure_streaming_service.StructuredStreamingService()

def extract_features(x):
    """Build a dense feature vector from a comma-separated line.

    The first field is skipped (``extract_label`` reads it); every remaining
    field is converted to ``float``.
    """
    values = x.split(",")
    # The loop that filled the list had lost its body, so nothing was collected.
    features_ = [float(i) for i in values[1:]]
    features = Vectors.dense(features_)
    return features

extract_features_udf = udf(extract_features, VectorUDT())

def extract_label(x):
    """Return the first comma-separated field of ``x`` converted to ``float``."""
    first_field = x.split(",")[0]
    return float(first_field)

extract_label_udf = udf(extract_label, FloatType())

predict_ds = spark.readStream.format("text").option("path", "data/structured_streaming").load() \
    .withColumn("features", extract_features_udf(col("value"))) \
    .withColumn("label", extract_label_udf(col("value")))





import unittest
import warnings

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf
from pyspark.sql.types import FloatType
from pyspark.ml.linalg import Vectors, VectorUDT

from spark_streaming_pp import structure_streaming_service

class RunTest(unittest.TestCase):
    """Trains the Spark service on a static DataFrame and scores a text-file stream."""

    def test_run(self):
        spark = SparkSession.builder.enableHiveSupport().getOrCreate()
        # Prepare data.
        train_ds = spark.createDataFrame([
            (6.0, Vectors.dense([1.0, 1.0])),
            (8.0, Vectors.dense([1.0, 2.0])),
            (9.0, Vectors.dense([2.0, 2.0])),
            (11.0, Vectors.dense([2.0, 3.0]))
        ],
            ["label", "features"]
        )
        # Create model and train.
        service = structure_streaming_service.StructuredStreamingService()
        service.train(train_ds)

        # Predict and results.

        def extract_features(x):
            # Every field after the first one is a feature value.
            values = x.split(",")
            features_ = [float(i) for i in values[1:]]
            features = Vectors.dense(features_)
            return features

        extract_features_udf = udf(extract_features, VectorUDT())

        def extract_label(x):
            # The first field of the line is the label.
            values = x.split(",")
            label = float(values[0])
            return label

        extract_label_udf = udf(extract_label, FloatType())

        predict_ds = spark.readStream.format("text").option("path", "data/structured_streaming").load() \
            .withColumn("features", extract_features_udf(col("value"))) \
            .withColumn("label", extract_label_udf(col("value")))

        # NOTE(review): the lines that start the stream were missing from this
        # listing; running the query for a few seconds and then stopping it is
        # the assumed intent - confirm the timeout against the original article.
        q = service.predict(predict_ds)
        q.awaitTermination(3)
        q.stop()

        # +-----+------------------+
        # |label|        prediction|
        # +-----+------------------+
        # |  1.0|15.966990887541273|
        # |  2.0|18.961384020443553|
        # +-----+------------------+

    def setUp(self):
        # Silence the resource/deprecation warnings raised while the test runs.
        warnings.filterwarnings("ignore", category=ResourceWarning)
        warnings.filterwarnings("ignore", category=DeprecationWarning)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()

implicit val sqlCtx = spark.sqlContext
import spark.implicits._
val source = MemoryStream[Record]
source.addData(Record(1.0, Vectors.dense(3.0, 5.0)))
source.addData(Record(2.0, Vectors.dense(4.0, 6.0)))
val predictDs = source.toDF()

Scala (, , sql):

package aaa.abc.dd.spark_streaming_pr.cluster

import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.streaming.StreamingQuery

/** Linear-regression service: fit on a static DataFrame, then score a streaming one. */
class StructuredStreamingService {
  var service: LinearRegression = _
  var model: LinearRegressionModel = _

  /** Creates the estimator and fits it on `ds`, keeping the resulting model. */
  def train(ds: DataFrame): Unit = {
    service = new LinearRegression().setMaxIter(10).setRegParam(0.01)
    model = service.fit(ds)
  }

  /**
   * Adds a "prediction" column computed from the "features" column of `ds`
   * and starts a streaming query over the result.
   *
   * @param ds streaming DataFrame that contains a "features" vector column
   * @return the started streaming query
   */
  def predict(ds: DataFrame): StreamingQuery = {

    // Broadcast the fitted model so the UDF below can read it on the executors.
    val m = ds.sparkSession.sparkContext.broadcast(model)

    // NOTE(review): the body of this function was missing from the listing;
    // delegating to the broadcast model is the assumed intent.
    def transformFun(features: org.apache.spark.ml.linalg.Vector): Double = {
      m.value.predict(features)
    }

    val transform: org.apache.spark.ml.linalg.Vector => Double = transformFun

    val predictUdf = udf(transform)

    val predictionDs = ds.withColumn("prediction", predictUdf(ds("features")))

    // A typed function value selects the Scala overload of foreachBatch; a bare
    // lambda can be ambiguous with the Java VoidFunction2 overload on Scala 2.12.
    // NOTE(review): the batch body was missing from the listing; showing each
    // micro-batch is the assumed intent (it matches the table in the test).
    val showBatch: (DataFrame, Long) => Unit = (r: DataFrame, i: Long) => {
      r.show()
      // scalastyle:off println
      println(s"batch $i")
      // scalastyle:on println
    }

    predictionDs
      .writeStream
      .foreachBatch(showBatch)
      .start()
  }
}


package aaa.abc.dd.spark_streaming_pr.cluster

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.scalatest.{Matchers, Outcome, fixture}

/** Trains the service on a static DataFrame and scores two records fed through a MemoryStream. */
class StructuredStreamingServiceSuite extends fixture.FunSuite with Matchers {
  test("run") { spark =>
    // Prepare data.
    val trainDs = spark.createDataFrame(Seq(
      (6.0, Vectors.dense(1.0, 1.0)),
      (8.0, Vectors.dense(1.0, 2.0)),
      (9.0, Vectors.dense(2.0, 2.0)),
      (11.0, Vectors.dense(2.0, 3.0))
    )).toDF("label", "features")
    // Create model and train.
    val service = new StructuredStreamingService()
    service.train(trainDs)
    // Predict and results.
    implicit val sqlCtx = spark.sqlContext
    import spark.implicits._
    val source = MemoryStream[Record]
    source.addData(Record(1.0, Vectors.dense(3.0, 5.0)))
    source.addData(Record(2.0, Vectors.dense(4.0, 6.0)))
    val predictDs = source.toDF()
    // NOTE(review): the lines that start the query were missing from this
    // listing; draining the stream and stopping the query is the assumed intent.
    val query = service.predict(predictDs)
    try query.processAllAvailable()
    finally query.stop()
    // +-----+---------+------------------+
    // |label| features|        prediction|
    // +-----+---------+------------------+
    // |  1.0|[3.0,5.0]|15.966990887541273|
    // |  2.0|[4.0,6.0]|18.961384020443553|
    // +-----+---------+------------------+
  }

  /** Gives each test its own local SparkSession and stops it afterwards. */
  override protected def withFixture(test: OneArgTest): Outcome = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    try withFixture(test.toNoArgTest(spark))

    finally spark.stop()
  }

  override type FixtureParam = SparkSession
}

// Declared at top level rather than inside the suite: Spark can refuse to build
// an encoder for a case class nested in another class ("inner class" error).
case class Record(label: Double, features: org.apache.spark.ml.linalg.Vector)


