bug: GraphFrames with Structured Streaming results in OOM even with very low data volumes #360

@Will-Hardman

Description

I have a simple PySpark structured streaming app that transforms incoming messages into a graph (using GraphFrames). A simplified example of the code is given below.

The code runs for ~50 batches before crashing with a "GC Overhead..." or "Heap size..." error. By that point the graph is no larger than 100 vertices and 300 edges.

The logs and traceback show the labelPropagation() call as having requested a memory page, which then causes the OOM error. If I swap the function for pageRank() or a custom Pregel algorithm (using graphframes.lib.Pregel), the crash still happens, again after ~50 batches. When using a Pregel algorithm, the traceback points to the aggregate function inside aggMsgs() as the culprit.
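
For illustration, the kind of custom Pregel computation referred to above is along these lines (a minimal sketch, not my exact algorithm: it assumes a GraphFrame g whose edges carry a numeric weight column, and simply sums incoming edge weights per vertex in one superstep):

from pyspark.sql.functions import coalesce, lit, sum as sql_sum
from graphframes.lib import Pregel

# One-superstep Pregel job: each edge sends its weight to its destination,
# and every vertex stores the (coalesced) sum of the messages it received
in_weights = (
    g.pregel
    .setMaxIter(1)
    .withVertexColumn("in_weight", lit(0.0), coalesce(Pregel.msg(), lit(0.0)))
    .sendMsgToDst(Pregel.edge("weight"))
    .aggMsgs(sql_sum(Pregel.msg()))
    .run()
)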

This feels like a memory leak, but I have been unable to find anyone else who has experienced precisely this, so perhaps the problem lies with my settings? (Note that in the code below I am running in local mode, but the problem also occurs when submitting to a cluster, so I have left the cluster-relevant settings in the code.)

I wonder if the use of GraphX under the hood could be part of the problem? (GraphFrames' labelPropagation and pageRank both delegate to GraphX implementations.)

Any insights very gratefully received!

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import udf, window, count, col, sum
from pyspark.sql.types import *
import jsonpickle

LABEL_PROP_MAX_ITER = 5

def main():

    spark = (
        SparkSession
        .builder
        .appName('GraphFrames_Test')
        .master("local[2]")
        .config('spark.jars.packages', 'org.apache.spark:spark-sql-kafka-0-10_2.11:2.4.3,graphframes:graphframes:0.7.0-spark2.4-s_2.11')
        .config('spark.sql.shuffle.partitions', 1)
        .config('spark.python.worker.memory', '2G')
        .config('spark.executor.memory', '2G')
        .config('spark.driver.memory', '2G')
        .config('spark.cleaner.ttl', '10s')
        .config('spark.cleaner.periodicGC.interval', '5min')
        .config('spark.cleaner.referenceTracking.blocking.shuffle', 'true')
        .config('spark.cleaner.referenceTracking.cleanCheckpoints', 'true')
        .config('spark.driver.maxResultSize', '500m')
        .config('spark.graphx.pregel.checkpointInterval', 2)
        .config('spark.executor.cores', '2')
        .config('spark.dynamicAllocation.enabled', 'true')
        .config('spark.shuffle.service.enabled', 'true')
        .config('spark.dynamicAllocation.maxExecutors', 2)
        .getOrCreate()
    )

    # These imports can only happen once the graphframes JAR has been registered
    from graphframes import GraphFrame
    from graphframes.lib import Pregel

    # Checkpointing
    spark.sparkContext.setCheckpointDir('/tmp')

    # Initialise logger
    log4j = spark.sparkContext._jvm.org.apache.log4j
    log4j.LogManager.getRootLogger().setLevel(log4j.Level.WARN)

    # Define the stream to process
    stream = (
        spark
        .readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", 'localhost:9092')
        .option("subscribe", 'raw')
        .option("startingOffsets", "latest")
        .load()
    )

    # Schema for the elements we will use
    schema = StructType([
        StructField("src", LongType(), True),
        StructField("src", TimestampType(), True),
        StructField("dst", LongType(), False)
    ])

    def parser(serialised_packet):
        packet = jsonpickle.decode(serialised_packet)
        src = packet.src
        dst = packet.dst
        created_at = packet.created_at
        return [src, dst, created_at]

    # Register UDF
    parser_udf = udf(lambda value: parser(value), schema)

    msgs = (
        stream
        # Kafka delivers `value` as binary, so cast to string before decoding
        .select(parser_udf(col('value').cast('string')).alias('msg'))
        .select(
            col('msg.src').alias('src'),
            col('msg.dst').alias('dst'),
            col('msg.created_at').alias('created_at')
        )
        .withWatermark('created_at', '1 day')
    )

    # Weighted edges
    edges_df = (
        msgs
        .groupBy(window('created_at', '7 days', '1 day'), 'src', 'dst')
        .count()
        .withColumnRenamed('count', 'weight')
    )

    def label_clusters(edges_df: DataFrame, batch_id):

        # Skip empty micro-batches
        if edges_df.rdd.isEmpty():
            return

        # Generate a DF of vertices from the endpoints of every edge
        # (distinct() deduplicates ids, which GraphFrames expects to be unique)
        vertices_df = (
            edges_df.select(edges_df.src.alias('id'))
            .union(
                edges_df.select(edges_df.dst.alias('id'))
            )
            .distinct()
        )

        g = GraphFrame(vertices_df, edges_df)

        clusters_df = g.labelPropagation(maxIter=LABEL_PROP_MAX_ITER)

        labels = (
            clusters_df
            .groupBy('label')
            .count()
            .toPandas()
            .label
            .values
        )

        print(f"Batch: [{batch_id}]")
        print(f"Vertex Count: [{g.vertices.count()}]")
        print(f"Edge Count: [{g.edges.count()}]")
        print(f"Num. Labels: [{len(labels)}]")
        print("********************************")

        g.unpersist(blocking=True)

    query = (
        edges_df
        .writeStream
        .outputMode("complete")
        .foreachBatch(label_clusters)
        .start()
    )

    # Block until the streaming query terminates; without this, main() returns
    # and the driver exits before any batches are processed
    query.awaitTermination()

if __name__ == '__main__':
    main()
