Consecutive grouping in Apache Spark

How do you find the longest streak of active workout days from your workout data?

Introduction

Assume that you have collected your daily active workout minutes from a smartwatch or other health-monitoring device, and you want to find your longest streak of consecutive days with at least 30 minutes of workout, so you can win a workout competition with your friends.

The data might look like this:

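+-----+----------+---------------+
| Name|      Date|Workout minutes|
+-----+----------+---------------+
|Alice|2023-01-01|             15|
|Alice|2023-01-02|             45|
|Alice|2023-01-03|             32|
|Alice|2023-01-04|             48|
|Alice|2023-01-05|             23|
|  Bob|2023-01-01|              0|
|  Bob|2023-01-02|             92|
|  Bob|2023-01-03|             31|
|  Bob|2023-01-04|             12|
|  Bob|2023-01-05|             42|
+-----+----------+---------------+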

After grouping consecutive days that each have at least 30 minutes of workout, the intermediate result should look like this:

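+-----+----------+----------+------+
| Name|      From|        To|Streak|
+-----+----------+----------+------+
|Alice|2023-01-02|2023-01-04|     3|
|  Bob|2023-01-02|2023-01-03|     2|
|  Bob|2023-01-05|2023-01-05|     1|
+-----+----------+----------+------+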

The expected result should look like this:

+-----+-----------+
| Name|Best streak|
+-----+-----------+
|Alice|          3|
|  Bob|          2|
+-----+-----------+

So Alice is the winner.
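
Before moving to Spark, it can help to have a tiny reference implementation to check results against. The following is only a sketch in plain Python (not Spark), using the sample rows from the table above. The function name best_streaks and the THRESHOLD constant are my own illustrative names, and the sketch assumes there is exactly one row per person per day with no missing dates.

```python
from itertools import groupby

# Sample rows from the table above: (name, ISO date, workout minutes).
rows = [
    ("Alice", "2023-01-01", 15),
    ("Alice", "2023-01-02", 45),
    ("Alice", "2023-01-03", 32),
    ("Alice", "2023-01-04", 48),
    ("Alice", "2023-01-05", 23),
    ("Bob", "2023-01-01", 0),
    ("Bob", "2023-01-02", 92),
    ("Bob", "2023-01-03", 31),
    ("Bob", "2023-01-04", 12),
    ("Bob", "2023-01-05", 42),
]

THRESHOLD = 30  # minutes needed for a day to count as active (illustrative name)


def best_streaks(rows, threshold=THRESHOLD):
    """Return {name: longest run of adjacent days with minutes >= threshold}.

    Assumes one row per person per day and no missing dates.
    """
    best = {}
    # Sort by (name, date) so each person's days are adjacent and in order.
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))
    for name, person_rows in groupby(ordered, key=lambda r: r[0]):
        longest = 0
        # The inner groupby splits the days into runs of equal hit/miss flags.
        for hit, run in groupby(person_rows, key=lambda r: r[2] >= threshold):
            if hit:
                longest = max(longest, sum(1 for _ in run))
        best[name] = longest
    return best


print(best_streaks(rows))  # {'Alice': 3, 'Bob': 2}
```

This works on a small in-memory list only. The rest of the post shows how to express the same idea with window functions so that it scales in Spark.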

How do you calculate your longest active streak over the last year with the Spark DataFrame API? And what about Spark SQL?

The answer is a technique called consecutive row grouping.

In this post, I will demonstrate this technique using the Apache Spark DataFrame API and Spark SQL.

Consecutive grouping with Spark DataFrame API

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window
from datetime import date

# create Spark Session
spark = SparkSession.builder.getOrCreate()

# create dataframe
data = [
    ("Alice", date.fromisoformat("2023-01-01"), 32),
    ("Alice", date.fromisoformat("2023-01-02"), 41),
    ("Alice", date.fromisoformat("2023-01-03"), 19),
    ("Alice", date.fromisoformat("2023-01-04"), 52),
    ("Alice", date.fromisoformat("2023-01-05"), 91),
    ("Alice", date.fromisoformat("2023-01-06"), 12),
    ("Alice", date.fromisoformat("2023-01-07"), 49),
    ("Alice", date.fromisoformat("2023-01-08"), 55),
    ("Alice", date.fromisoformat("2023-01-09"), 66),
    ("Alice", date.fromisoformat("2023-01-10"), 77),
    ("Bob", date.fromisoformat("2023-01-01"), 10),
    ("Bob", date.fromisoformat("2023-01-02"), 22),
    ("Bob", date.fromisoformat("2023-01-03"), 31),
    ("Bob", date.fromisoformat("2023-01-04"), 42),
    ("Bob", date.fromisoformat("2023-01-05"), 10),
    ("Bob", date.fromisoformat("2023-01-06"), 120),
    ("Bob", date.fromisoformat("2023-01-07"), 1),
    ("Bob", date.fromisoformat("2023-01-08"), 22),
    ("Bob", date.fromisoformat("2023-01-09"), 33),
    ("Bob", date.fromisoformat("2023-01-10"), 44),
]
df = spark.createDataFrame(data, schema="name string, date date, minutes integer")

# calculate previous day value of hit using lag function with offset=1
window_spec_01 = Window.partitionBy("name").orderBy("date")
lagged_df = df.withColumn("hit_day", F.col("minutes") >= 30).withColumn(
    "previous_day", F.lag("hit_day").over(window_spec_01)
)

# flag the rows where the hit value differs from the previous day
# (or where there is no previous day), then keep a running count of those flags;
# the running count becomes the grouping key for each streak
# (F.count_if is available in PySpark 3.5 and later)
window_spec_02 = (
    Window.partitionBy("name")
    .orderBy("date")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
grouped_df = lagged_df.withColumn(
    "group_start",
    (F.col("previous_day").isNull()) | (F.col("previous_day") != F.col("hit_day")),
).withColumn("streak_group", F.count_if("group_start").over(window_spec_02))

# find start and end date of each streak
streak_df = (
    grouped_df.filter(F.col("hit_day"))
    .groupBy("name", "streak_group")
    .agg(F.min("date").alias("min_date"), F.max("date").alias("max_date"))
)

# find the best streak for each person;
# a streak's length in days is datediff(max_date, min_date) + 1
best_streak_df = streak_df.groupBy("name").agg(
    F.max(F.datediff("max_date", "min_date") + 1).alias("best_streak_duration")
)
best_streak_df.show()

The result:

+-----+--------------------+
| name|best_streak_duration|
+-----+--------------------+
|Alice|                   4|
|  Bob|                   2|
+-----+--------------------+

Consecutive grouping with Spark SQL

df.createOrReplaceTempView("workout_history")

spark.sql(
    """
    WITH temp AS (
    SELECT
        name,
        date,
        minutes,
        (minutes >= 30) AS hit_day
    FROM
        workout_history),
    lagging AS (
    SELECT
        name,
        date,
        minutes,
        hit_day,
        lag(hit_day) OVER (
            PARTITION BY name
            ORDER BY date
        ) AS previous_day
    FROM
        temp),
    streaks AS (
    SELECT
        name,
        date,
        minutes,
        hit_day,
        count_if(previous_day IS NULL OR hit_day != previous_day)
            OVER (
                PARTITION BY name
                ORDER BY date
        ) AS streak_group
    FROM
        lagging
        ),

    grouped AS (
        SELECT
            name,
            min(date) AS min_date,
            max(date) AS max_date,
            sum(minutes) AS total_minutes
        FROM
            streaks
        WHERE
            hit_day IS TRUE
        GROUP BY
            name,
            streak_group
    )

    SELECT
        name,
        max(datediff(max_date, min_date) + 1) AS best_streak_duration
    FROM
        grouped
    GROUP BY
        name
    """
).show()

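Note: when a window has an ORDER BY clause but no explicit frame, Spark SQL defaults to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. Because each person has at most one row per date, this produces the same running count as ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so we don't have to define the frame explicitly in this case.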
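
The choice of frame only matters when several rows share the same ORDER BY value. The sketch below is a plain Python simulation (not Spark code) of a running count under a ROWS frame and under a RANGE frame. The helper names running_count_rows and running_count_range and the sample data are hypothetical and made up for illustration.

```python
# (order_key, flag) pairs, already sorted by order_key.
# The key 2 appears twice, which is where the two frame types disagree.
rows = [(1, True), (2, False), (2, True), (3, True)]


def running_count_rows(rows):
    """ROWS frame: count flags from the first row up to the current row."""
    return [sum(flag for _, flag in rows[: i + 1]) for i in range(len(rows))]


def running_count_range(rows):
    """RANGE frame: also include rows that tie with the current order key."""
    return [
        sum(flag for key, flag in rows if key <= current_key)
        for current_key, _ in rows
    ]


print(running_count_rows(rows))   # [1, 1, 2, 3]
print(running_count_range(rows))  # [1, 2, 2, 3]
```

With one row per person per date there are no ties, so both frames give the same streak_group values in the workout example.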

How does it work?

Consecutive grouping in SQL refers to grouping adjacent rows of an ordered result set that share the same value in a particular column (here, hit_day). It is useful when you want to analyze or summarize data that is organized sequentially or by time.

The window function LAG() is used together with ORDER BY to obtain the value of the previous row, in the order defined by a sequential column such as date.

The trick is then to assign each row to a consecutive group by keeping a running count of how many times the column's value has changed compared with the previous row. This is done with count_if() over a second window whose frame runs from UNBOUNDED PRECEDING to CURRENT ROW. All rows in the same run share the same count, so the count can serve as a grouping key.
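
To make the steps concrete, here is a small trace in plain Python (not Spark) for Alice's ten days from the DataFrame example. It is only a sketch: the list names mirror the column names used earlier, and itertools.accumulate stands in for the running count over the window.

```python
from collections import Counter
from itertools import accumulate

# Alice's days in date order: True means minutes >= 30 on that day.
hit_day = [True, True, False, True, True, False, True, True, True, True]

# LAG(hit_day): the previous row's value, None for the first row.
previous_day = [None] + hit_day[:-1]

# A row opens a new group when it has no predecessor or its value changed.
changed = [prev is None or cur != prev for cur, prev in zip(hit_day, previous_day)]

# Running count of changes from the first row to the current row = group id.
streak_group = list(accumulate(int(c) for c in changed))
print(streak_group)  # [1, 1, 2, 3, 3, 4, 5, 5, 5, 5]

# Keep only the hit days and count the rows in each group.
streak_lengths = Counter(g for g, hit in zip(streak_group, hit_day) if hit)
print(max(streak_lengths.values()))  # 4
```

Groups 1, 3 and 5 are the active streaks, and groups 2 and 4 are the missed days between them, which is why the hit_day filter is applied before aggregating.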

Summary

Consecutive row grouping can be a powerful tool for various data analysis and reporting tasks, allowing you to segment and analyze data in a meaningful way based on the sequential order of rows in your database table.

This technique is useful for various scenarios:

  • Time series analysis, for example grouping consecutive data points that satisfy a particular condition.

  • Identifying sessions, also called sessionization. This term comes up often in the context of stream processing.

  • Finding gaps in sequential data.

  • Tracking changes in state data.
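
As an illustration of the sessionization case, the same "flag a change, then take a running count" idea can assign session ids. This is a minimal sketch in plain Python with made-up timestamps for one user. The 30-minute SESSION_GAP timeout is an assumption chosen for the example, not a standard value.

```python
from datetime import datetime, timedelta
from itertools import accumulate

# Hypothetical event timestamps for one user, already sorted.
events = [
    datetime(2023, 1, 1, 9, 0),
    datetime(2023, 1, 1, 9, 10),
    datetime(2023, 1, 1, 9, 25),
    datetime(2023, 1, 1, 11, 0),
    datetime(2023, 1, 1, 11, 5),
    datetime(2023, 1, 1, 15, 30),
]
SESSION_GAP = timedelta(minutes=30)  # assumed inactivity timeout

# Equivalent of LAG(): the previous event time, None for the first event.
previous = [None] + events[:-1]

# A new session starts at the first event or after a gap longer than the timeout.
new_session = [p is None or (e - p) > SESSION_GAP for e, p in zip(events, previous)]

# The running count of session starts is the session id.
session_id = list(accumulate(int(n) for n in new_session))
print(session_id)  # [1, 1, 1, 2, 2, 3]
```

The only difference from the workout example is the change condition: a time gap instead of a flipped boolean.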

Reference

I learned this technique from this blog. I just wanted to show how to implement it in a framework I am familiar with (Apache Spark).

Future

I want to revisit this technique in the streaming world, where everything is more complicated.

Thank you for reading this post. I hope you find it useful.