
Top 10 PySpark Scenario-Based Interview Questions (with Solutions)

PySpark Lab
15 min read

If you are interviewing for a data engineering role anywhere from a startup to a FAANG, you will be asked to solve scenarios in PySpark, not to recite definitions. The bar has moved — interviewers want to see how you think when the data is messy, the volume is huge, or the requirement reads like a product spec instead of a clean problem statement.

This post walks through the ten scenarios that show up most often. Every solution uses the pure PySpark DataFrame API (no spark.sql, no selectExpr) and is fully runnable — paste it into the PySpark Lab playground and see the output for yourself.

How to use this post: read the scenario, stop scrolling, and try to solve it yourself. Only then peek at the answer. The point is the thinking pattern, not memorizing code.


1. The skewed join — one user has 90% of the rows

The scenario. You are joining a 200-million-row orders table to a users table on user_id. The job runs for hours, one executor is at 100% CPU while the rest sit idle, and the Spark UI shows a single task processing 50× more data than the others. One whale customer accounts for 90% of all orders. How do you fix it?

Sample input.

from pyspark.sql import functions as F

orders  = spark.createDataFrame(
    [(i,  1, 100) for i in range(10000)]                # whale: user 1
  + [(i+10000, u, 100) for i, u in enumerate([2,3,4,5])],  # everyone else
    ["order_id", "user_id", "amount"]
)
users = spark.createDataFrame([(1,"whale"),(2,"a"),(3,"b"),(4,"c"),(5,"d")], ["user_id","name"])

Solution — salt the hot key on the larger side, replicate on the smaller side.

from pyspark.sql import functions as F

N = 16   # number of salt buckets — tune to your skew factor

# 1. Salt the skewed side (orders): each row gets a random bucket 0..N-1
orders_salted = orders.withColumn("salt", (F.rand() * N).cast("int"))

# 2. Replicate the small side (users) N times, one row per bucket
salt_range = spark.range(N).withColumnRenamed("id", "salt")
users_salted = users.crossJoin(salt_range)

# 3. Join on (user_id, salt) — load is now spread across N tasks for the whale
result = (orders_salted
          .join(users_salted, ["user_id", "salt"], "inner")
          .drop("salt"))

result.groupBy("name").agg(F.count("*").alias("orders")).orderBy(F.col("orders").desc()).show()

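Why this works. A single hot key forces all of its matching rows into one partition during the shuffle. Adding a random salt splits that key into N synthetic keys, so its rows spread across up to N tasks instead of one, and replicating the small side once per salt value guarantees every salted row still finds its match. You trade a little duplication on the small side for a roughly even N-way split of the work. On real data you would usually salt only the known hot keys, so the rest of the small side isn't replicated N times for nothing.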

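Modern shortcut. Since Spark 3.0, Adaptive Query Execution (AQE) can split skewed partitions of a sort-merge join automatically. It needs spark.sql.adaptive.enabled and spark.sql.adaptive.skewJoin.enabled set to "true" (both are on by default from Spark 3.2), e.g. spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true"). Interviewers still ask for the manual salt, and you still need it where AQE's skew handling doesn't apply: Structured Streaming, join strategies it doesn't cover, or Spark 2.x.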
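The load-spreading argument can be checked without a cluster. The sketch below is plain Python (no Spark) that mirrors the three salting steps on the same toy data; counting rows per join key is a simplification that stands in for per-task load, and the seed value is arbitrary.

```python
import random
from collections import Counter

N = 16            # number of salt buckets, same as the PySpark example
random.seed(7)    # fixed seed only so the illustration is repeatable

# Same shape as the sample input: user 1 is the whale
orders = [(i, 1, 100) for i in range(10000)] + [
    (10000 + i, u, 100) for i, u in enumerate([2, 3, 4, 5])
]
users = {1: "whale", 2: "a", 3: "b", 4: "c", 5: "d"}

# Unsalted join key: every whale row shares one key, so one task gets them all
plain_groups = Counter(user_id for _, user_id, _ in orders)

# Salted join key (user_id, salt): the whale's rows are spread over N keys
salted_orders = [(oid, uid, amt, random.randrange(N)) for oid, uid, amt in orders]
salted_groups = Counter((uid, salt) for _, uid, _, salt in salted_orders)

# Small side replicated once per salt value, like users.crossJoin(spark.range(N))
users_salted = {(uid, salt): name for uid, name in users.items() for salt in range(N)}

# Inner join on (user_id, salt): every salted order finds exactly one user row
joined = [(oid, users_salted[(uid, salt)]) for oid, uid, _, salt in salted_orders]

print("largest key group without salt:", max(plain_groups.values()))
print("largest key group with salt:   ", max(salted_groups.values()))
print("joined rows:", len(joined), "of", len(orders))
```

The joined row count equals the original order count, which is the point: salting changes how the work is distributed, not the join result.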


2. Deduplicate, keeping the latest record per user

The scenario. Your CDC ingestion writes every change to a table. The same user_id appears many times with different updated_at timestamps and different field values. You need exactly one row per user — the latest one.

Sample input.

df = spark.createDataFrame([
    (1, "alice@old.com",  "2024-01-01"),
    (1, "alice@new.com",  "2024-06-15"),
    (2, "bob@x.com",      "2024-03-10"),
    (1, "alice@mid.com",  "2024-04-02"),
    (2, "bob@x.com",      "2024-03-10"),    # exact duplicate
], ["user_id", "email", "updated_at"])

Solution — window + row_number() + filter.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("user_id").orderBy(F.col("updated_at").desc())

result = (df
          .withColumn("rn", F.row_number().over(w))
          .filter(F.col("rn") == 1)
          .drop("rn"))

result.orderBy("user_id").show()

Expected output.

+-------+--------------+----------+
|user_id|         email|updated_at|
+-------+--------------+----------+
|      1| alice@new.com|2024-06-15|
|      2|     bob@x.com|2024-03-10|
+-------+--------------+----------+

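Why this works. dropDuplicates(["user_id"]) keeps one row per user_id, but which one is arbitrary, so it won't reliably be the latest. The window gives row number 1 to the newest row per user, and filtering on rn == 1 keeps exactly one row per user_id. The ISO-8601 strings in the sample sort chronologically as text; convert with F.to_date or F.to_timestamp if your format doesn't. If two rows tie at the latest timestamp, add a stable tiebreaker (e.g. orderBy(F.col("updated_at").desc(), "email")).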
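To make the ordering rule concrete, here is a plain-Python sketch of row_number() == 1 with the updated_at desc, email asc tiebreaker. User 3 and its two tied rows are hypothetical additions to the sample, included only to show the tiebreaker deciding.

```python
from operator import itemgetter

rows = [
    (1, "alice@old.com", "2024-01-01"),
    (1, "alice@new.com", "2024-06-15"),
    (2, "bob@x.com", "2024-03-10"),
    (1, "alice@mid.com", "2024-04-02"),
    (2, "bob@x.com", "2024-03-10"),
    # Hypothetical extra user with a tie at the latest timestamp
    (3, "zoe@x.com", "2024-05-01"),
    (3, "amy@x.com", "2024-05-01"),
]

def latest_per_user(rows):
    """Mimic row_number() == 1 over partitionBy(user_id).orderBy(updated_at desc, email asc)."""
    # Stable sorts: apply the tiebreaker (email asc) first, then the main key (updated_at desc).
    # ISO-8601 date strings sort chronologically as plain text.
    ordered = sorted(rows, key=itemgetter(1))
    ordered = sorted(ordered, key=itemgetter(2), reverse=True)
    latest = {}
    for row in ordered:
        latest.setdefault(row[0], row)  # first row seen per user_id plays the role of rn == 1
    return latest

for user_id, row in sorted(latest_per_user(rows).items()):
    print(row)
```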


3. Sessionize a clickstream — new session after 30 min idle

The scenario. You have a stream of pageviews per user. A "session" is a sequence of events where consecutive pageviews are at most 30 minutes apart. Assign a session ID to each event.

Sample input.

from datetime import datetime
df = spark.createDataFrame([
    (1, datetime(2024,1,1, 9,  0)),
    (1, datetime(2024,1,1, 9, 10)),
    (1, datetime(2024,1,1, 9, 50)),   # gap of 40 min -> new session
    (1, datetime(2024,1,1,10, 10)),
    (2, datetime(2024,1,1, 9,  0)),
], ["user_id", "event_at"])

Solution — lag to detect gaps, cumulative sum to assign session IDs.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("user_id").orderBy("event_at")
GAP_SECS = 30 * 60

# 1. Time gap from the previous event for the same user (in seconds)
step1 = df.withColumn("prev",  F.lag("event_at").over(w)) \
          .withColumn("gap_s", F.unix_timestamp("event_at") - F.unix_timestamp("prev"))

# 2. Mark the start of a new session: first event of user OR gap > threshold
step2 = step1.withColumn("is_new_session",
                         F.when(F.col("prev").isNull() | (F.col("gap_s") > GAP_SECS), 1).otherwise(0))

# 3. Cumulative sum of "is_new_session" = a monotonic per-user session number
w_cum = Window.partitionBy("user_id").orderBy("event_at") \
              .rowsBetween(Window.unboundedPreceding, Window.currentRow)
result = step2.withColumn("session_id", F.sum("is_new_session").over(w_cum)) \
              .select("user_id", "event_at", "session_id")

result.orderBy("user_id", "event_at").show()

Why this works. This is the classic "running counter that resets on a flag" trick. Detect the resets (gaps), put a 1 there, and cumsum turns the 1s into a session number that increments at each reset. Works for any "group consecutive events with a condition" problem.
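The same lag-flag-cumsum logic, written as a plain-Python loop over the sample events (a sketch for intuition, not how Spark executes it):

```python
from datetime import datetime, timedelta
from itertools import groupby
from operator import itemgetter

GAP = timedelta(minutes=30)

events = [
    (1, datetime(2024, 1, 1, 9, 0)),
    (1, datetime(2024, 1, 1, 9, 10)),
    (1, datetime(2024, 1, 1, 9, 50)),
    (1, datetime(2024, 1, 1, 10, 10)),
    (2, datetime(2024, 1, 1, 9, 0)),
]

def sessionize(events, gap=GAP):
    out = []
    # sorted() groups rows by user and orders them by time, like partitionBy + orderBy
    for user_id, group in groupby(sorted(events), key=itemgetter(0)):
        prev = None       # lag(event_at) for the current row
        session_id = 0    # running sum of is_new_session
        for _, ts in group:
            is_new_session = prev is None or (ts - prev) > gap
            session_id += int(is_new_session)
            out.append((user_id, ts, session_id))
            prev = ts
    return out

for row in sessionize(events):
    print(row)
```

A gap of exactly 30 minutes stays in the same session, matching the > GAP_SECS condition in the PySpark version.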


4. Top-N highest-paying customers per region

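The scenario. For each region, return the top 3 customers by total spend. The sample below is already one row per customer; with raw transactions, aggregate first with groupBy("region", "customer").agg(F.sum("spend").alias("spend")).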

Sample input.

df = spark.createDataFrame([
    ("US", "alice", 1200), ("US", "bob", 800), ("US", "carol", 300), ("US", "dan", 1500),
    ("IN", "esha",  900),  ("IN", "ravi",1100), ("IN", "priya",  700),
], ["region", "customer", "spend"])

Solution — rank() (or dense_rank()) over a partitioned window.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("region").orderBy(F.col("spend").desc())

result = (df
          .withColumn("rnk", F.dense_rank().over(w))
          .filter(F.col("rnk") <= 3)
          .orderBy("region", "rnk"))

result.show()

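Why this works. partitionBy("region") restarts the ranking for each region. Pick row_number() if you want exactly N rows even with ties (which tied row wins is arbitrary unless you add a tiebreaker); rank() skips numbers after a tie (1, 1, 3); dense_rank() doesn't (1, 1, 2). With rank() or dense_rank(), a filter on <= 3 can return more than three customers when there are ties. Interviewers love asking what changes if two customers tie at #3, so make sure you can explain all three.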
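A quick way to internalize the three functions is to compute them by hand. This plain-Python sketch assigns row_number, rank and dense_rank to descending values; the spend figures are hypothetical, chosen so that two customers tie at #3.

```python
def rank_columns(values):
    """Return (value, row_number, rank, dense_rank) rows for values sorted descending."""
    out = []
    rank = dense_rank = 0
    prev = None
    for row_number, value in enumerate(sorted(values, reverse=True), start=1):
        if row_number == 1 or value != prev:
            rank = row_number   # rank jumps to the current position after a tie
            dense_rank += 1     # dense_rank only ever advances by one
            prev = value
        out.append((value, row_number, rank, dense_rank))
    return out

# Hypothetical spends with two customers tied at #3
spends = [1500, 1200, 800, 800, 300]
table = rank_columns(spends)
for row in table:
    print(row)

# How many rows survive a "<= 3" filter under each function
print({name: sum(1 for r in table if r[idx] <= 3)
       for idx, name in [(1, "row_number"), (2, "rank"), (3, "dense_rank")]})
```

With the tie at #3, a <= 3 filter keeps three rows under row_number() but four under rank() and dense_rank().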


5. Find users active on at least 3 consecutive days

The scenario. You have one row per user per day they logged in. Find every user who has at least one streak of ≥3 consecutive days.

Sample input.

from datetime import date
df = spark.createDataFrame([
    (1, date(2024,1,1)), (1, date(2024,1,2)), (1, date(2024,1,3)),  # streak of 3
    (2, date(2024,1,1)), (2, date(2024,1,3)),                       # gap on day 2
    (3, date(2024,1,5)), (3, date(2024,1,6)), (3, date(2024,1,7)), (3, date(2024,1,8)),
], ["user_id", "active_date"])

Solution — the island-and-gap trick: date − row_number = constant for one streak.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("user_id").orderBy("active_date")

step1 = df.withColumn("rn", F.row_number().over(w)) \
          .withColumn("grp", F.date_sub("active_date", F.col("rn").cast("int")))

streaks = (step1.groupBy("user_id", "grp")
                .agg(F.count("*").alias("streak_len")))

result = streaks.filter(F.col("streak_len") >= 3).select("user_id").distinct().orderBy("user_id")
result.show()

Why this works. Within a run of consecutive dates, the date and the row number both go up by one per row, so date minus row number stays constant for the whole streak and changes as soon as there's a gap. Grouping by that derived value gives you each streak's length. The trick assumes one row per user per day; if logins can repeat on the same day, deduplicate first. This pattern shows up constantly, so memorize it.
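Here is the date-minus-row-number trick in plain Python on the sample data, as a sketch for intuition. The set() call is an assumption beyond the original solution: it drops repeated logins on the same day before numbering.

```python
from collections import Counter
from datetime import date, timedelta

rows = [
    (1, date(2024, 1, 1)), (1, date(2024, 1, 2)), (1, date(2024, 1, 3)),
    (2, date(2024, 1, 1)), (2, date(2024, 1, 3)),
    (3, date(2024, 1, 5)), (3, date(2024, 1, 6)), (3, date(2024, 1, 7)), (3, date(2024, 1, 8)),
]

def streak_users(rows, min_len=3):
    row_number = Counter()   # per-user row_number() ordered by date
    streak_len = Counter()   # count(*) grouped by (user_id, grp)
    for user_id, day in sorted(set(rows)):   # set() drops repeated logins on the same day
        row_number[user_id] += 1
        grp = day - timedelta(days=row_number[user_id])   # same value for every day of one streak
        streak_len[(user_id, grp)] += 1
    return sorted({user_id for (user_id, _), n in streak_len.items() if n >= min_len})

print(streak_users(rows))
```

Because the arithmetic is on real dates, streaks that cross a month boundary (Jan 30, Jan 31, Feb 1) are handled the same way.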


6. SCD Type 2 — track historical price changes

The scenario. Given a stream of (product_id, price, change_date) events, produce a slowly-changing-dimension table with valid_from, valid_to and an is_current flag.

Sample input.

from datetime import date
df = spark.createDataFrame([
    ("P1", 100, date(2024,1,1)),
    ("P1", 120, date(2024,3,15)),
    ("P1", 110, date(2024,7,1)),
    ("P2",  50, date(2024,2,1)),
], ["product_id", "price", "change_date"])

Solution — lead() to find the next change, then derive valid_to and is_current.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Order by valid_from: the window is evaluated after change_date has been renamed
w = Window.partitionBy("product_id").orderBy("valid_from")

result = (df
    .withColumnRenamed("change_date", "valid_from")
    .withColumn("next_change", F.lead("valid_from").over(w))
    .withColumn("valid_to",
                F.coalesce(F.date_sub("next_change", 1),
                           F.lit(date(9999, 12, 31))))
    .withColumn("is_current", F.col("next_change").isNull())
    .drop("next_change")
    .orderBy("product_id", "valid_from"))

result.show()

Why this works. In SCD Type 2, each historical version covers a closed interval [valid_from, valid_to]. lead() peeks at the next change for the same product, and the day before that change becomes this row's valid_to. The row with no next change (next_change is null) is the current, open-ended version; giving it a far-future sentinel valid_to of 9999-12-31 keeps point-in-time queries uniform, since valid_from <= d AND d <= valid_to works for every row.
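The lead()-based derivation and the point-in-time predicate can be sketched in plain Python as below. price_as_of is a hypothetical helper, not part of the original solution; it exists only to show that one predicate covers both closed and current versions.

```python
from datetime import date, timedelta
from itertools import groupby
from operator import itemgetter

OPEN_END = date(9999, 12, 31)   # sentinel valid_to for the current version

events = [
    ("P1", 100, date(2024, 1, 1)),
    ("P1", 120, date(2024, 3, 15)),
    ("P1", 110, date(2024, 7, 1)),
    ("P2", 50, date(2024, 2, 1)),
]

def scd2(events):
    out = []
    ordered = sorted(events, key=itemgetter(0, 2))   # partition by product, order by change_date
    for product_id, group in groupby(ordered, key=itemgetter(0)):
        versions = list(group)
        # Pair each version with the next one: the zip partner plays lead(valid_from)
        for (_, price, valid_from), nxt in zip(versions, versions[1:] + [None]):
            valid_to = nxt[2] - timedelta(days=1) if nxt else OPEN_END
            out.append((product_id, price, valid_from, valid_to, nxt is None))
    return out

def price_as_of(table, product_id, day):
    # Point-in-time lookup: the same predicate works for closed and current rows
    for pid, price, valid_from, valid_to, _ in table:
        if pid == product_id and valid_from <= day <= valid_to:
            return price
    return None

table = scd2(events)
for row in table:
    print(row)
print(price_as_of(table, "P1", date(2024, 5, 1)))
```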


7. Find duplicate transactions within a 5-minute window

The scenario. Fraud rule: flag a (card_number, amount) pair if the same combination appears twice within 5 minutes. Return the duplicate pairs.

Sample input.

from datetime import datetime
df = spark.createDataFrame([
    ("c1", 100.0, datetime(2024,1,1, 9, 0,  0)),
    ("c1", 100.0, datetime(2024,1,1, 9, 3, 10)),   # dup — 3m10s after
    ("c1", 100.0, datetime(2024,1,1, 9,15, 30)),   # not a dup of either
    ("c2",  50.0, datetime(2024,1,1, 9, 0,  0)),
], ["card", "amount", "ts"])

Solution — lag() to compare the previous matching transaction.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("card", "amount").orderBy("ts")
WINDOW_SECS = 5 * 60

result = (df
    .withColumn("prev_ts", F.lag("ts").over(w))
    .withColumn("delta_s",
                F.unix_timestamp("ts") - F.unix_timestamp("prev_ts"))
    .filter(F.col("delta_s") <= WINDOW_SECS)
    .select("card", "amount", "ts", "prev_ts", "delta_s")
    .orderBy("card", "ts"))

result.show()

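Why this works. A self-join on card = card AND amount = amount AND |ts - ts| <= 5 min also works, but it compares every pair of rows that share a (card, amount) key, which is quadratic for busy keys and explodes on real data. The window-lag version needs only a shuffle and a sort, and it still flags every transaction that has at least one earlier match within 5 minutes: rows are sorted by ts, so if the nearest earlier match is outside the window, every earlier one is too. What lag() does not give you is every matching pair, only the nearest one. If you need all earlier matches in the window (for example, a count), use rangeBetween over ts converted to epoch seconds.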
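The claim that lag() already flags every transaction with an earlier match in the window can be checked in plain Python. In this sketch the c3 rows are a hypothetical burst added to the sample, and the bisect-based count stands in for a trailing rangeBetween frame over epoch seconds; it illustrates the semantics rather than Spark's execution.

```python
from bisect import bisect_left
from collections import defaultdict
from datetime import datetime

WINDOW_SECS = 5 * 60
EPOCH = datetime(1970, 1, 1)

txns = [
    ("c1", 100.0, datetime(2024, 1, 1, 9, 0, 0)),
    ("c1", 100.0, datetime(2024, 1, 1, 9, 3, 10)),
    ("c1", 100.0, datetime(2024, 1, 1, 9, 15, 30)),
    ("c2", 50.0, datetime(2024, 1, 1, 9, 0, 0)),
    # Hypothetical burst: three matching charges within four minutes
    ("c3", 20.0, datetime(2024, 1, 1, 9, 0, 0)),
    ("c3", 20.0, datetime(2024, 1, 1, 9, 2, 0)),
    ("c3", 20.0, datetime(2024, 1, 1, 9, 4, 0)),
]

def check(txns, window_secs=WINDOW_SECS):
    by_key = defaultdict(list)   # partitionBy("card", "amount")
    for card, amount, ts in txns:
        by_key[(card, amount)].append((ts - EPOCH).total_seconds())
    out = []
    for (card, amount), secs in by_key.items():
        secs.sort()   # orderBy("ts")
        for i, t in enumerate(secs):
            # lag(): compare only with the immediately preceding match
            lag_flag = i > 0 and t - secs[i - 1] <= window_secs
            # Trailing-window count of all earlier matches within window_secs
            earlier_in_window = i - bisect_left(secs, t - window_secs)
            out.append((card, amount, t, lag_flag, earlier_in_window))
    return out

for card, amount, _, lag_flag, earlier in check(txns):
    print(card, amount, lag_flag, earlier)
```

Every row with lag_flag set also has a nonzero trailing count, and vice versa; only the count tells you the third c3 charge had two earlier matches.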


8. Pivot daily sales into one column per day

The scenario. You have one row per (store, day, sales). Produce a wide table with one column per day so the analytics team can paste it into a spreadsheet.

Sample input.

df = spark.createDataFrame([
    ("S1", "Mon", 100), ("S1", "Tue", 150), ("S1", "Wed", 130),
    ("S2", "Mon",  80), ("S2", "Tue",  90), ("S2", "Wed", 110),
], ["store", "day", "sales"])

Solution — groupBy().pivot().

from pyspark.sql import functions as F

# Specifying the day list avoids a second pass over the data
days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]

result = (df
          .groupBy("store")
          .pivot("day", days)
          .agg(F.sum("sales"))
          .orderBy("store"))

result.show()

Why this works. pivot rotates row values into column headers. Pass the list of pivot values explicitly when you know it ahead of time; otherwise Spark runs an extra job to collect the distinct values before the real aggregation, which is costly on big data. Days in the list with no data (Thu to Sun here) come back as null columns. To unpivot back (wide → long), Spark 3.4+ has DataFrame.unpivot() (alias melt()); on older versions, build it with a union of select statements, one per day column.
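The shape change is easy to see in plain Python. This sketch imitates pivot with an explicit value list and the union-of-selects unpivot; it illustrates the semantics (null cells for listed days with no data), not Spark's implementation.

```python
DAYS = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]

rows = [
    ("S1", "Mon", 100), ("S1", "Tue", 150), ("S1", "Wed", 130),
    ("S2", "Mon", 80), ("S2", "Tue", 90), ("S2", "Wed", 110),
]

def pivot(rows, days=DAYS):
    """Like groupBy("store").pivot("day", days).agg(sum("sales"))."""
    wide = {}
    for store, day, sales in rows:
        if day not in days:
            continue   # values missing from the explicit list get no column
        cells = wide.setdefault(store, dict.fromkeys(days))   # unseen days stay None (null)
        cells[day] = sales if cells[day] is None else cells[day] + sales
    return wide

def unpivot(wide, days=DAYS):
    """Union-of-selects approach: one (store, day, value) projection per day column."""
    long_rows = []
    for day in days:
        long_rows.extend((store, day, cells[day]) for store, cells in wide.items())
    return long_rows

wide = pivot(rows)
print(wide["S1"])
print([r for r in unpivot(wide) if r[2] is not None])
```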


9. Customers who bought product A but never product B

The scenario. Find customers who have at least one purchase of product A and zero purchases of product B.

Sample input.

df = spark.createDataFrame([
    ("u1", "A"), ("u1", "B"),    # has both
    ("u2", "A"),                 # A only — keep
    ("u2", "A"),                 # duplicate A
    ("u3", "B"),                 # B only
    ("u4", "C"),                 # neither — skip
], ["customer", "product"])

Solution — left anti-join.

from pyspark.sql import functions as F

buyers_of_a = df.filter(F.col("product") == "A").select("customer").distinct()
buyers_of_b = df.filter(F.col("product") == "B").select("customer").distinct()

result = buyers_of_a.join(buyers_of_b, "customer", "left_anti").orderBy("customer")
result.show()

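Why this works. left_anti returns rows from the left side that have no match on the right. It's the cleanest expression of "A minus B" set semantics in Spark and runs as an ordinary join. The alternatives are worse: ~F.col("customer").isin(...) needs the B list collected to the driver first, and a SQL NOT IN (subquery) has null-aware semantics (a single NULL on the right side makes the query return no rows) that are harder to plan efficiently.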
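A plain-Python sketch of both behaviors: the set difference that left_anti computes, and SQL's three-valued NOT IN. The NULL customer on the B side is a hypothetical dirty row, and sql_not_in is a hypothetical helper that models UNKNOWN as None.

```python
rows = [
    ("u1", "A"), ("u1", "B"),
    ("u2", "A"), ("u2", "A"),
    ("u3", "B"),
    ("u4", "C"),
]

buyers_of_a = {c for c, p in rows if p == "A"}
buyers_of_b = {c for c, p in rows if p == "B"}

# left_anti: keep left keys with no equal key on the right
print(sorted(buyers_of_a - buyers_of_b))

def sql_not_in(value, candidates):
    """SQL three-valued NOT IN; None stands for NULL, and a None result means UNKNOWN."""
    non_null = [c for c in candidates if c is not None]
    if value in non_null:
        return False
    if len(non_null) < len(candidates):
        return None   # a NULL candidate might have been equal, so the answer is UNKNOWN
    return True

# Suppose the B side picked up one NULL customer (hypothetical dirty row)
b_with_null = sorted(buyers_of_b) + [None]
# WHERE keeps only rows whose predicate is TRUE, so UNKNOWN drops the row
kept_by_not_in = [c for c in sorted(buyers_of_a) if sql_not_in(c, b_with_null) is True]
# An equality join never matches a NULL key, so the NULL row cannot remove u2
kept_by_anti = [c for c in sorted(buyers_of_a) if c not in b_with_null]
print(kept_by_not_in, kept_by_anti)
```

With one NULL on the right, the NOT IN version keeps nobody, while the anti-join logic still returns u2.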


10. Seven-day rolling average revenue per store

The scenario. For each store on each day, compute the average daily revenue over the trailing 7 days (inclusive).

Sample input.

from datetime import date
df = spark.createDataFrame([
    ("S1", date(2024,1,1), 100), ("S1", date(2024,1,2), 120),
    ("S1", date(2024,1,3),  90), ("S1", date(2024,1,4), 110),
    ("S1", date(2024,1,5), 130), ("S1", date(2024,1,6), 140),
    ("S1", date(2024,1,7), 150), ("S1", date(2024,1,8), 160),
    ("S2", date(2024,1,1),  50), ("S2", date(2024,1,2),  60),
], ["store", "day", "revenue"])

Solution — rangeBetween on a window ordered by date.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Convert the date to an integer (days since 1970-01-01) so rangeBetween can do day arithmetic
df_n = df.withColumn("day_n", F.datediff("day", F.lit(date(1970,1,1))))

w = (Window.partitionBy("store")
           .orderBy("day_n")
           .rangeBetween(-6, Window.currentRow))   # last 7 days inclusive

result = (df_n
    .withColumn("avg_7d", F.avg("revenue").over(w))
    .select("store", "day", "revenue", "avg_7d")
    .orderBy("store", "day"))

result.show()

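Why this works. rowsBetween(-6, 0) averages the previous 6 physical rows plus the current row, which is only correct if every day has exactly one row. rangeBetween(-6, 0) over a numeric day column averages by value range instead, so it stays correct when some days are missing, which is the realistic case. Converting the date to an integer (days since epoch) is what makes rangeBetween work cleanly. Two caveats: the first six days of each store average over fewer than 7 days, and F.avg skips missing days rather than counting them as zero; if a missing day should count as zero revenue, use F.sum("revenue") / 7 instead.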
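The difference between the two frames shows up as soon as a day is missing. This plain-Python sketch computes both averages for one store; the gappy series is hypothetical, with no rows between Jan 3 and Jan 10.

```python
from datetime import date

EPOCH = date(1970, 1, 1)

def rolling_avgs(points, span=7):
    """points: (day, revenue) pairs for one store, one row per day, sorted by day."""
    day_n = [(d - EPOCH).days for d, _ in points]   # same idea as datediff("day", 1970-01-01)
    revenue = [r for _, r in points]
    out = []
    for i, n in enumerate(day_n):
        # rangeBetween(-(span - 1), 0): rows whose day_n lies in [n - span + 1, n]
        by_range = [revenue[j] for j in range(i + 1) if n - day_n[j] <= span - 1]
        # rowsBetween(-(span - 1), 0): the current row plus up to span - 1 physical rows before it
        by_rows = revenue[max(0, i - (span - 1)): i + 1]
        out.append((points[i][0], sum(by_range) / len(by_range), sum(by_rows) / len(by_rows)))
    return out

# Hypothetical store with a week-long gap before Jan 10
gappy = [(date(2024, 1, 1), 100), (date(2024, 1, 2), 120), (date(2024, 1, 3), 90), (date(2024, 1, 10), 200)]
for day, avg_range, avg_rows in rolling_avgs(gappy):
    print(day, avg_range, avg_rows)
```

On Jan 10 the range frame averages only that day (200.0), because Jan 4 to Jan 9 have no rows, while the rows frame reaches back to Jan 1 and returns 127.5.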


Patterns to take away

If you squint at the ten solutions above, only a handful of patterns are doing the heavy lifting:

  • Window + row_number()/rank()/dense_rank(): dedup, top-N, consecutive-day streaks.
  • lag() / lead(): gaps, deltas, SCD intervals, "compare to previous row".
  • Cumulative sum of a flag column: turn boundary markers into group IDs (sessions, streaks, restarts).
  • rangeBetween: time-aware moving averages on potentially sparse dates.
  • left_anti: clean set-difference filtering.
  • Salting + crossJoin replication: when one key has most of the work.

Master those six and most "PySpark scenario" interview questions become routine.

Practice these live

Every solution above is runnable in your browser — no install, no cluster. Paste any block into the PySpark Lab playground, hit run, and tweak the inputs until you can do it from memory.

If you want hundreds more (basic → medium → complex), the full interview question bank and the code playground have you covered.

Good luck with the interview.