Using PySpark’s transform() function to write better code

Che Kulhan
3 min read · Sep 2, 2023
The transform() method allows you to chain transformations together

I recently posted this code snippet about PySpark’s transform() method in various programming groups on LinkedIn. Given the overwhelming response and interest it received, I thought a more detailed article would be worthwhile, providing further examples of how I have used the feature to improve my code’s readability, separation of concerns, and modularity.

Code snippet posted on LinkedIn
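
The snippet itself was shared as an image, but the idea behind it is straightforward: DataFrame.transform(func) simply calls func on the DataFrame and returns the result, so named transformation functions can be chained instead of nested. Here is a minimal, self-contained sketch of that pattern (the helper functions and column names are my own illustrative choices, not those from the original snippet):

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 100.0), (2, 300.0)], ["id", "amount"])

def double_amount(df: DataFrame) -> DataFrame:
    # Illustrative helper: double the amount column
    return df.withColumn("amount", col("amount") * 2)

def flag_large(df: DataFrame) -> DataFrame:
    # Illustrative helper: flag amounts above 250
    return df.withColumn("large", col("amount") > 250)

# Without transform(), the helper calls nest inside-out:
df_nested = flag_large(double_amount(df))

# With transform(), the same pipeline reads top-to-bottom:
df_chained = df.transform(double_amount).transform(flag_large)

Both versions produce the same result; the chained one simply reads in the order the transformations are applied.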

For my examples, I import some useful functions from PySpark’s sql.functions module and create a sample DataFrame to work with.

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.sql.functions import min, max, avg, col, lit, when

# Define the schema for the sample DataFrame
schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("type", StringType(), True),
    StructField("salary", DoubleType(), True)
])

# Sample employee data
data_persons = [
    (1, "Sara", "CONTRACTOR", 50000.00),
    (2, "Jon", "FULL-TIME", 87500.32),
    (3, "Susan", "FULL-TIME", 98000.89),
    (4, "Axl", "PART-TIME", 25000.00),
    (5, "Adam", "CONTRACTOR", 40000.00)
]

df_persons = spark.createDataFrame(data=data_persons, schema=schema)

df_persons.show(5, False)
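
Running this on a local Spark session prints the five sample rows without truncation:

+---+-----+----------+--------+
|id |name |type      |salary  |
+---+-----+----------+--------+
|1  |Sara |CONTRACTOR|50000.0 |
|2  |Jon  |FULL-TIME |87500.32|
|3  |Susan|FULL-TIME |98000.89|
|4  |Axl  |PART-TIME |25000.0 |
|5  |Adam |CONTRACTOR|40000.0 |
+---+-----+----------+--------+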
