A PySpark-native way to do recursion

Part 2: Solutions

data
spark
python
Author

corey

Published

April 21, 2024

In my last post, I described an example of recursive algorithms, the Fibonacci sequence, and showed that it can’t be solved with classic SQL tools like window functions. In this post, I’ll explore possible solutions and demonstrate my preferred, PySpark-native approach.

possible solutions

Most answers to this problem rely on Python user-defined functions (UDFs). In Spark, a Python UDF works by:

1. handing each row off to a Python worker process
2. converting the data from Scala/Java data types to Python data types
3. running the Python code
4. re-converting the results back into Scala data types

The UDF body might be a straightforward recursive Python function or something more concise, but either way, it has to go through this four-step process for each row, which makes even simple UDFs computationally expensive. Plus, adding \(r\) rows to a computation running on a cluster with \(w\) worker nodes scales the computational cost by \(\frac{r}{w}\).
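For illustration, here is a minimal sketch of the row-wise UDF approach; the function and column names are mine, not from the original post:

import pyspark.sql.functions as f
import pyspark.sql.types as t

def fibonacci(n: int) -> int:
  # plain Python recursion, which the UDF machinery runs once per row
  if n <= 1:
    return n
  return fibonacci(n - 1) + fibonacci(n - 2)

# every row pays the JVM -> Python -> JVM round trip described above
fibonacci_udf = f.udf(fibonacci, t.LongType())

# hypothetical usage, assuming a DataFrame with an integer column n:
# df.withColumn("Fibonacci", fibonacci_udf("n"))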

One way to speed these computations up is to use Pandas UDFs. Unlike Python UDFs, Pandas UDFs only go through this process once per partition. So, assuming you are computing this for a bunch of grouped data, you only need to perform the four-step process once per group. Still, there must be a way to do this without all that overhead, right?
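As a rough sketch (assuming Spark 3's type-hinted Pandas UDFs and, again, a hypothetical integer column n), the same computation might look like this; the rows are still processed in Python, but the conversion cost is paid per batch rather than per row:

import pandas as pd
import pyspark.sql.functions as f
import pyspark.sql.types as t

@f.pandas_udf(t.LongType())
def fibonacci_pandas(n: pd.Series) -> pd.Series:
  # iterative Fibonacci, applied to an entire batch of rows at once
  def fib(k: int) -> int:
    a, b = 0, 1
    for _ in range(k):
      a, b = b, a + b
    return a
  return n.apply(fib)

# hypothetical usage: df.withColumn("Fibonacci", fibonacci_pandas("n"))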

One option is to use Java or Scala UDFs, which removes a lot of the overhead, but is of course harder to hand off to people who don't know Java or Scala. With a problem more complex than the Fibonacci sequence, that might be a dealbreaker.

my preferred solution

Enter aggregate() and array functions. You can use these to solve recursive problems using only PySpark. Of course, they compile down to Scala under the hood, so you could easily see this as an intermediate step toward learning a new language, but I think seeing this method made me a more creative programmer.

Essentially, you'll create a vector of input data and then iterate a function over it. For our Fibonacci example, this means building an \(n \times 1\) vector inside a single column, so there's a possibility of an out-of-memory error if \(n\) is big enough.

Let’s start with a new data frame and go from there:

import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.sql import DataFrame, SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(1, ), (2, ), (3, ), (4, ), (5, )], ("n", ))

With row numbers assigned to each row, we can collect them into a base array:

df = (
  df.agg(
    f.sort_array(f.collect_set("n")).alias("BaseArray")
  )
)

df.show()
+---------------+
|      BaseArray|
+---------------+
|[1, 2, 3, 4, 5]|
+---------------+

Now, we can start working with aggregate. Let’s look at the definition of this function:

def aggregate(
    col: ColumnOrName,
    initialValue: ColumnOrName,
    merge: Callable[[Column, Column], Column],
    finish: Optional[Callable[[Column], Column]] = None
) -> Column

This function works like a local accumulator, allowing each row to iterate over the col (BaseArray in this case) and, crucially, carry values over between iteration steps.
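As a quick illustration before tackling Fibonacci, here is a toy running sum over BaseArray (the ArraySum name is mine, not the post's):

df.withColumn(
  "ArraySum",
  f.aggregate(
    "BaseArray",                    # the array to fold over
    f.lit(0).cast(t.LongType()),    # the accumulator starts at 0
    lambda acc, x: acc + x          # carry the running total forward
  )
).show()

The accumulator here is just a single number (1 + 2 + 3 + 4 + 5 = 15), but the same machinery can carry a richer accumulator, like the struct below.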

initialValue is crucial for the Fibonacci sequence because it lets us set fibonacci(0) and fibonacci(1):

f.struct(
  f.lit(0).alias("nMinus2"),
  f.lit(1).alias("nMinus1")
)

merge does the heavy lifting for this function, determining how to get from one step to the next. One annoying necessity is that it must operate on PySpark columns using PySpark column functions. That means we can't write many functions using native Python libraries and operators, but for the Fibonacci sequence that shouldn't be a problem.

def fib(previous, current):
  # the old fibonacci(n - 1) becomes the new fibonacci(n - 2)
  n_2 = previous.nMinus1
  # the new fibonacci(n - 1) is the sum of the previous two terms
  n_1 = previous.nMinus1 + previous.nMinus2
  return f.struct(n_2.alias("nMinus2"), n_1.alias("nMinus1"))

Last, finish takes another function to extract the final value. Since we only want the last value of the accumulator to get the Fibonacci number, this is very straightforward:

lambda x: x.nMinus1

putting it all together

df = (
  df
  .withColumn(
    "Fibonacci",
    f.aggregate(
      col = "BaseArray",
      initialValue = f.struct(
        f.lit(0).cast(t.LongType()).alias("nMinus2"),
        f.lit(1).cast(t.LongType()).alias("nMinus1")
      ),
      merge = fib,
      finish = lambda x: x.nMinus1
    )
  )
)

So, did it work?

df.show()
+---------------+---------+
|      BaseArray|Fibonacci|
+---------------+---------+
|[1, 2, 3, 4, 5]|        8|
+---------------+---------+

As you can see, this returned a single value. Now let’s check it against the native Python implementation:

def fibonacci(n: int) -> int:
  if n == 0:
    f = 0
  elif n == 1:
    f = 1
  else:
    f = fibonacci(n - 1) + fibonacci(n - 2)
  return f

fibonacci(6)
8

You'll notice our PySpark function doesn't actually use current for anything, though it's included to meet the function's signature. You could make this example more complex by scaling each \(fibonacci(n)\) by \(n\). This would change the recursive element to be \(fibonacci(n) = n \times (fibonacci(n-1) + fibonacci(n-2))\).

It’s a contrived example, but it illustrates how you could pull in another variable. This lets us change the function above to make use of the current value:

def fib(previous, current):
  # shift as before: fibonacci(n - 1) becomes fibonacci(n - 2)
  n_2 = previous.nMinus1
  # scale the new term by current, i.e. the value of n from BaseArray
  n_1 = current * (previous.nMinus1 + previous.nMinus2)
  return f.struct(n_2.alias("nMinus2"), n_1.alias("nMinus1"))

In turn, this outputs:

(
    df
    .withColumn(
        "ScaledFibonacci",
        f.aggregate(
            col = "BaseArray",
            initialValue = f.struct(
                f.lit(0).cast(t.LongType()).alias("nMinus2"),
                f.lit(1).cast(t.LongType()).alias("nMinus1")
            ),
            merge = fib,
            finish = lambda x: x.nMinus1
        )
    )
    .show()
)
+---------------+---------+---------------+
|      BaseArray|Fibonacci|ScaledFibonacci|
+---------------+---------+---------------+
|[1, 2, 3, 4, 5]|        8|            455|
+---------------+---------+---------------+

tl;dr

We can perform recursive algorithms using only native PySpark. This lets us expand the types of problems we can solve while:

1. avoiding the overhead associated with Python and Pandas UDFs
2. not assuming our collaborators know any other languages

I demonstrated a classic recursive problem, the Fibonacci sequence, but the same approach will work for lots of harder problems. For instance, how would you implement a Poisson CDF up to the value \(k\)?

\[ CDF(k) = e^{-\lambda}\sum_{j=0}^{k}\frac{\lambda^j}{j!} \]

for \(k = {0, 1, 2, \dots}\)

Hint: start by repeating \(e^{-\lambda}\) in a \(k \times 1\) array.
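For what it's worth, here is one possible sketch of mine (it takes a slightly different route than the hint): fix \(\lambda\) and \(k\) as literals, fold over \(1, \dots, k\), and carry both the running term \(\frac{\lambda^j}{j!}\) and the partial sum, scaling by \(e^{-\lambda}\) at the end. The lam, k, and PoissonCDF names are hypothetical.

lam = 2.0  # hypothetical rate parameter
k = 5      # hypothetical upper bound

(
  spark.range(1)
  .withColumn(
    "PoissonCDF",
    f.aggregate(
      col = f.sequence(f.lit(1), f.lit(k)),  # [1, 2, ..., k]
      initialValue = f.struct(
        f.lit(1.0).alias("term"),            # the j = 0 term: lambda^0 / 0! = 1
        f.lit(1.0).alias("partialSum")       # the running sum starts at that term
      ),
      merge = lambda acc, j: f.struct(
        (acc.term * f.lit(lam) / j).alias("term"),
        (acc.partialSum + acc.term * f.lit(lam) / j).alias("partialSum")
      ),
      finish = lambda acc: acc.partialSum * f.exp(f.lit(-lam))
    )
  )
  .show()
)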