A PySpark-native way to do recursion

Part 1: the Problem

Categories: data, spark, python

Author: corey

Published: April 21, 2024

I find lots of situations in my work where recursion over Spark dataframes would be useful. Recently, I needed to define a column with values that were each based on the preceding value, recursively. That’s to say, each row could only be computed after the row before it.

a simple example

To make this example concrete, I’ll use the Fibonacci sequence:

\[ 0, 1, 1, 2, 3, 5, 8, 13, \dots \]

Generating the sequence is really easy: the first two values are \(0\) and \(1\), and each subsequent number \(fibonacci(n)\) is the sum of the previous two values.

\[ fibonacci(0) = 0 \]

\[ fibonacci(1) = 1 \]

\[ fibonacci(n) = fibonacci(n-1) + fibonacci(n-2) \]

Plus, the sequence is equally easy to code:

def fibonacci(n: int) -> int:
  if n == 0:
    f = 0
  elif n == 1:
    f = 1
  else:
    # each value is the sum of the previous two
    f = fibonacci(n - 1) + fibonacci(n - 2)
  return f

Running this code yields the same sequence we started with:

[fibonacci(n) for n in range(0, 8)]
[0, 1, 1, 2, 3, 5, 8, 13]

the problem

But try to calculate this on a Spark dataframe and you’ll quickly encounter a problem. The defining feature of distributed computing, that each row of the dataframe is calculated simultaneously, presents as a bug here. Calculating fibonacci(n) requires first calculating fibonacci(n-1), which requires fibonacci(n-2), which requires fibonacci(n-3), et cetera, and Spark just doesn’t like that.

I use the Fibonacci sequence as an example because it’s simple to code and even to calculate by hand, but the approach to calculating it in Spark is the same as for much more interesting problems. For example, my current work requires me to calculate the Poisson cumulative distribution function for various values of \(\lambda\) so that we can turn point forecasts of discrete variables into distributional forecasts at massive scale.
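To see why that’s the same shape of problem, note that the Poisson probabilities satisfy a row-by-row recurrence, in the same spirit as the Fibonacci definition above (the \(pmf\) and \(cdf\) notation here is mine, just for illustration):

\[ pmf(\lambda, 0) = e^{-\lambda} \]

\[ pmf(\lambda, k) = pmf(\lambda, k-1) \cdot \frac{\lambda}{k} \]

\[ cdf(\lambda, k) = cdf(\lambda, k-1) + pmf(\lambda, k) \]

Each cumulative value depends on the value before it, which is exactly the row-after-row dependency that trips Spark up.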

So, how can we calculate a Fibonacci sequence in Spark? One thing we can’t do is use window functions. Though windowing takes previous values into account, it takes them into account statically: a window can only read values that already exist in the column. Try writing this using window functions and you’ll immediately see the problem.
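For concreteness, here’s one way the dataframe might be seeded before the attempt (a minimal sketch; the column names n and Fibonacci and the variable names spark and df are my assumptions, not from the original setup):

from pyspark.sql import SparkSession, functions as f, window as w

spark = SparkSession.builder.getOrCreate()

# one row per sequence index, with only the two seed values filled in
df = spark.createDataFrame(
  [(0, 0), (1, 1)] + [(n, None) for n in range(2, 8)],
  schema = "n INT, Fibonacci INT",
)

With those imports in scope, the windowed attempt looks like this: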

(
  df
  .withColumn(
    "Fibonacci",
    # lag() can only read values that are already present in the column
    f.lag("Fibonacci").over(w.Window.orderBy("Fibonacci"))
    + f.lag("Fibonacci", offset = 2).over(w.Window.orderBy("Fibonacci"))
  )
)

Either you need to have Fibonacci already populated or you need to populate it using a non-recursive function, but what would that be? I’ll explore that in my next post.