Consecutive grouping in Apache Spark
How do you find the longest streak of active workout days from your workout data?
Introduction
Assume that you have collected your daily active workout minutes from your smartwatch or another health monitoring device, and you want to find your longest streak of consecutive days with at least 30 minutes of workout, so you can win a workout competition with your friends.
The data might look like this:
Name | Date | Workout minutes |
Alice | 2023-01-01 | 15 |
Alice | 2023-01-02 | 45 |
Alice | 2023-01-03 | 32 |
Alice | 2023-01-04 | 48 |
Alice | 2023-01-05 | 23 |
Bob | 2023-01-01 | 0 |
Bob | 2023-01-02 | 92 |
Bob | 2023-01-03 | 31 |
Bob | 2023-01-04 | 12 |
Bob | 2023-01-05 | 42 |
After grouping consecutive days that each have at least 30 minutes of workout, the intermediate result should look like this:
Name | From | To | Streak |
Alice | 2023-01-02 | 2023-01-04 | 3 |
Bob | 2023-01-02 | 2023-01-03 | 2 |
Bob | 2023-01-05 | 2023-01-05 | 1 |
The expected result should look like this:
Name | Best streak |
Alice | 3 |
Bob | 2 |
So Alice is the winner.
How do you calculate your longest active streak over the last year with the Spark DataFrame API? And with Spark SQL?
The answer is a technique called consecutive row grouping.
In this blog, I will demonstrate this technique using the Apache Spark DataFrame API and Spark SQL.
Consecutive grouping with Spark DataFrame API
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import Window
from datetime import date
# create Spark Session
spark = SparkSession.builder.getOrCreate()
# create dataframe
data = [
    ("Alice", date.fromisoformat("2023-01-01"), 32),
    ("Alice", date.fromisoformat("2023-01-02"), 41),
    ("Alice", date.fromisoformat("2023-01-03"), 19),
    ("Alice", date.fromisoformat("2023-01-04"), 52),
    ("Alice", date.fromisoformat("2023-01-05"), 91),
    ("Alice", date.fromisoformat("2023-01-06"), 12),
    ("Alice", date.fromisoformat("2023-01-07"), 49),
    ("Alice", date.fromisoformat("2023-01-08"), 55),
    ("Alice", date.fromisoformat("2023-01-09"), 66),
    ("Alice", date.fromisoformat("2023-01-10"), 77),
    ("Bob", date.fromisoformat("2023-01-01"), 10),
    ("Bob", date.fromisoformat("2023-01-02"), 22),
    ("Bob", date.fromisoformat("2023-01-03"), 31),
    ("Bob", date.fromisoformat("2023-01-04"), 42),
    ("Bob", date.fromisoformat("2023-01-05"), 10),
    ("Bob", date.fromisoformat("2023-01-06"), 120),
    ("Bob", date.fromisoformat("2023-01-07"), 1),
    ("Bob", date.fromisoformat("2023-01-08"), 22),
    ("Bob", date.fromisoformat("2023-01-09"), 33),
    ("Bob", date.fromisoformat("2023-01-10"), 44),
]
df = spark.createDataFrame(data, schema="name string, date date, minutes integer")
# flag days with at least 30 minutes of workout, then use lag() (offset 1 by default)
# to fetch the previous day's flag within each person's date-ordered partition
window_spec_01 = Window.partitionBy("name").orderBy("date")
lagged_df = df.withColumn("hit_day", F.col("minutes") >= 30).withColumn(
    "previous_day", F.lag("hit_day").over(window_spec_01)
)
# mark rows where hit_day differs from the previous row (or where the partition
# starts); the cumulative count of these marks becomes the grouping value for streaks
window_spec_02 = (
    Window.partitionBy("name")
    .orderBy("date")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
# F.count_if is available in PySpark 3.5+; on older versions,
# F.sum(F.col("streak_start").cast("int")) gives the same running count
grouped_df = lagged_df.withColumn(
    "streak_start",
    F.col("previous_day").isNull() | (F.col("previous_day") != F.col("hit_day")),
).withColumn("streak_group", F.count_if("streak_start").over(window_spec_02))
# keep only hit days, then find the start and end date of each streak
streak_df = (
    grouped_df.filter(F.col("hit_day"))
    .groupBy("name", "streak_group")
    .agg(F.min("date").alias("min_date"), F.max("date").alias("max_date"))
)
# find the best streak for each person;
# datediff(end, start) + 1 is the streak length in days, counting both ends
best_streak_df = streak_df.groupBy("name").agg(
    F.max(F.datediff("max_date", "min_date") + 1).alias("best_streak_duration")
)
best_streak_df.show()
The result for this ten-day dataset:
+-----+--------------------+
| name|best_streak_duration|
+-----+--------------------+
|Alice| 4|
| Bob| 2|
+-----+--------------------+
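To make the intermediate columns concrete, here is a plain-Python replay of the same steps for Alice's ten days. It illustrates the window logic rather than being Spark code: the function name `trace_streaks` is hypothetical, and the change flag is called `streak_start` in this sketch.

```python
from datetime import date, timedelta

# Alice's ten days from the DataFrame example above
minutes = [32, 41, 19, 52, 91, 12, 49, 55, 66, 77]
dates = [date(2023, 1, 1) + timedelta(days=i) for i in range(len(minutes))]


def trace_streaks(dates, minutes, threshold=30):
    """Replay the hit_day, previous_day, streak_start and streak_group columns."""
    rows = []
    previous_day = None  # lag() yields NULL for the first row of the partition
    streak_group = 0  # running count of streak_start flags so far
    for d, m in zip(dates, minutes):
        hit_day = m >= threshold
        streak_start = previous_day is None or previous_day != hit_day
        streak_group += int(streak_start)
        rows.append((d, hit_day, previous_day, streak_start, streak_group))
        previous_day = hit_day
    return rows


rows = trace_streaks(dates, minutes)
print([row[-1] for row in rows])  # [1, 1, 2, 3, 3, 4, 5, 5, 5, 5]

# keep hit days only, then take the min and max date of each streak_group
spans = {}
for d, hit_day, _, _, group in rows:
    if hit_day:
        first, last = spans.get(group, (d, d))
        spans[group] = (min(first, d), max(last, d))
best = max((last - first).days + 1 for first, last in spans.values())
print(best)  # 4
```

The running count only increases where the hit flag flips, so every unbroken run of hit days (and every run of non-hit days) gets its own group number; filtering to hit days then leaves exactly the streaks.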
Consecutive grouping with Spark SQL
df.createOrReplaceTempView("workout_history")
spark.sql(
    """
    WITH temp AS (
        SELECT
            name,
            date,
            minutes,
            (minutes >= 30) AS hit_day
        FROM workout_history
    ),
    lagging AS (
        SELECT
            name,
            date,
            minutes,
            hit_day,
            lag(hit_day) OVER (
                PARTITION BY name
                ORDER BY date
            ) AS previous_day
        FROM temp
    ),
    streaks AS (
        SELECT
            name,
            date,
            minutes,
            hit_day,
            count_if(previous_day IS NULL OR hit_day != previous_day) OVER (
                PARTITION BY name
                ORDER BY date
            ) AS streak_group
        FROM lagging
    ),
    grouped AS (
        SELECT
            name,
            min(date) AS min_date,
            max(date) AS max_date,
            sum(minutes) AS total_minutes
        FROM streaks
        WHERE hit_day
        GROUP BY
            name,
            streak_group
    )
    SELECT
        name,
        max(datediff(max_date, min_date) + 1) AS best_streak_duration
    FROM grouped
    GROUP BY name
    """
).show()
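If you want to check the query logic without a Spark cluster, the sketch below runs an adapted copy against Python's built-in `sqlite3` module. It assumes an SQLite build with window-function support (version 3.25 or newer) and swaps two Spark SQL functions for SQLite equivalents: `count_if` becomes `SUM(CASE ...)`, and `datediff` becomes a `julianday()` difference. Dates are stored as ISO-8601 text.

```python
import sqlite3

# Same ten-day dataset as the DataFrame example
alice = [32, 41, 19, 52, 91, 12, 49, 55, 66, 77]
bob = [10, 22, 31, 42, 10, 120, 1, 22, 33, 44]
rows = [("Alice", f"2023-01-{i + 1:02d}", m) for i, m in enumerate(alice)]
rows += [("Bob", f"2023-01-{i + 1:02d}", m) for i, m in enumerate(bob)]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workout_history (name TEXT, date TEXT, minutes INTEGER)")
conn.executemany("INSERT INTO workout_history VALUES (?, ?, ?)", rows)

query = """
WITH flagged AS (
    SELECT name, date, minutes, (minutes >= 30) AS hit_day
    FROM workout_history
),
lagging AS (
    SELECT name, date, minutes, hit_day,
           lag(hit_day) OVER (PARTITION BY name ORDER BY date) AS previous_day
    FROM flagged
),
streaks AS (
    SELECT name, date, minutes, hit_day,
           -- SQLite has no count_if, so count the changes with SUM(CASE ...)
           SUM(CASE WHEN previous_day IS NULL OR hit_day != previous_day
                    THEN 1 ELSE 0 END)
               OVER (PARTITION BY name ORDER BY date
                     ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS streak_group
    FROM lagging
),
grouped AS (
    SELECT name, min(date) AS min_date, max(date) AS max_date
    FROM streaks
    WHERE hit_day = 1
    GROUP BY name, streak_group
)
SELECT name,
       MAX(CAST(julianday(max_date) - julianday(min_date) AS INTEGER) + 1)
           AS best_streak_duration
FROM grouped
GROUP BY name
ORDER BY name
"""
result = conn.execute(query).fetchall()
print(result)  # [('Alice', 4), ('Bob', 2)]
```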
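Note: in Spark SQL, when a window has an ORDER BY but no frame clause, the default frame is RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW (not ROWS). Because each person has at most one row per date in our data, RANGE and ROWS give the same result here, so we don't have to define the frame explicitly. If the ordering column can contain duplicates, specify ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW explicitly.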
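The difference between the two frames only shows up when the ORDER BY column has ties. The sketch below uses Python's built-in `sqlite3` module (3.25 or newer for window functions) as a stand-in for Spark, since its default frame follows the same RANGE rule; the table `t` and its columns are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id INTEGER, d INTEGER)")
# rows 2 and 3 share d = 2, so they are peers under ORDER BY d
conn.executemany("INSERT INTO t VALUES (?, ?)", [(1, 1), (2, 2), (3, 2), (4, 3)])

query = """
SELECT
    id,
    COUNT(*) OVER (ORDER BY d) AS default_frame,
    COUNT(*) OVER (
        ORDER BY d ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    ) AS rows_frame
FROM t
ORDER BY id
"""
results = conn.execute(query).fetchall()
default_frame = [r[1] for r in results]
rows_frame = [r[2] for r in results]
print(default_frame)  # [1, 3, 3, 4]: both peers are counted together
print(rows_frame)  # one row at a time; the order of the two peers is not guaranteed
```

With RANGE, a running count_if would give both peers the same group number even if the flag changed between them, which is why ROWS is the safer choice when the ordering column is not unique.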
How does it work?
Consecutive grouping in SQL typically refers to grouping consecutive rows of a result set that have the same value, or meet the same condition, in a particular column. It is closely related to what is often called the gaps-and-islands problem, and it is useful when you want to analyze or summarize data that is organized in a sequential or time-based manner.
The SQL window function LAG() is used together with ORDER BY to retrieve the value from the previous row, in the order defined by a sequential column (here, date).
Then the trick is to assign every row a group number by cumulatively counting how many times the value of the specified column changed compared with the previous row. This is done with the count_if() function over a second window whose frame runs from UNBOUNDED PRECEDING to CURRENT ROW. Rows that share the same running count belong to the same consecutive group.
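Outside SQL, Python's `itertools.groupby` implements the same idea directly: it starts a new group every time the key changes between neighboring elements, which is what the LAG() plus count_if() combination emulates. The sketch below applies it to the sample data from the introduction; it is a single-machine illustration, not a substitute for the distributed version.

```python
from datetime import date
from itertools import groupby

# Sample data from the introduction: (name, ISO date, workout minutes)
data = [
    ("Alice", "2023-01-01", 15), ("Alice", "2023-01-02", 45),
    ("Alice", "2023-01-03", 32), ("Alice", "2023-01-04", 48),
    ("Alice", "2023-01-05", 23),
    ("Bob", "2023-01-01", 0), ("Bob", "2023-01-02", 92),
    ("Bob", "2023-01-03", 31), ("Bob", "2023-01-04", 12),
    ("Bob", "2023-01-05", 42),
]

streaks = []  # (name, from, to, streak) rows, like the intermediate table
# groupby only merges *consecutive* items, so sort by name and date first
for name, person_rows in groupby(sorted(data), key=lambda r: r[0]):
    # a new group starts whenever the hit-day flag flips
    for hit_day, run in groupby(person_rows, key=lambda r: r[2] >= 30):
        if hit_day:
            run = list(run)
            start = date.fromisoformat(run[0][1])
            end = date.fromisoformat(run[-1][1])
            streaks.append((name, start, end, (end - start).days + 1))

best = {}
for name, _, _, length in streaks:
    best[name] = max(best.get(name, 0), length)

print(streaks)
print(best)  # {'Alice': 3, 'Bob': 2}
```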
Summary
Consecutive row grouping can be a powerful tool for various data analysis and reporting tasks, allowing you to segment and analyze data in a meaningful way based on the sequential order of rows in your database table.
This technique is useful for various scenarios:
Time series analysis, where you want to group consecutive data points that meet a particular condition.
Identifying sessions, also called sessionization, a term that comes up often when processing streaming data.
Finding gaps in sequential data.
Tracking changes in state data.
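The grouping in this post compares each row with the previous row, not with the previous calendar day, so it assumes one row per person per day. If some days are missing entirely (for example, the watch was not worn), two streaks separated by a missing day would be merged. A common variant for this case, and for finding gaps in general, is the gaps-and-islands trick of subtracting each row's position from its date: consecutive dates share the same difference. The sketch below shows it in plain Python on hypothetical dates; in Spark, a row_number() window over each person's hit days could supply the position.

```python
from datetime import date, timedelta
from itertools import groupby

# Hypothetical hit days for one person (one entry per date);
# 2023-01-04, 2023-01-05 and 2023-01-08 are missing from the data
hit_days = [
    date(2023, 1, 1), date(2023, 1, 2), date(2023, 1, 3),
    date(2023, 1, 6), date(2023, 1, 7),
    date(2023, 1, 9),
]


def island_key(indexed_day):
    """Date minus its position: constant within a run of consecutive dates."""
    index, day = indexed_day
    return day - timedelta(days=index)


islands = []  # (first day, last day, length) of each run of consecutive days
for _, group in groupby(enumerate(sorted(hit_days)), key=island_key):
    days = [day for _, day in group]
    islands.append((days[0], days[-1], len(days)))

# gaps are the missing date ranges between neighboring islands
gaps = [
    (prev_end + timedelta(days=1), next_start - timedelta(days=1))
    for (_, prev_end, _), (next_start, _, _) in zip(islands, islands[1:])
]

print(islands)
print(gaps)
```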
Reference
I learned this technique from this blog. Here I just wanted to show how to implement it in the framework I am most familiar with, Apache Spark.
Future
I want to revisit this technique, but in the streaming world, where everything is more complicated.
Thank you for reading this blog. Hope you find it useful.