How to select the max row for every group in Spark Structured Streaming 2.3.0 without using order by or mapGroupsWithState?

Input:

id | amount     | my_timestamp
-------------------------------------------
1  |      5     |  2018-04-01T01:00:00.000Z
1  |     10     |  2018-04-01T01:10:00.000Z
2  |     20     |  2018-04-01T01:20:00.000Z
2  |     30     |  2018-04-01T01:25:00.000Z
2  |     40     |  2018-04-01T01:30:00.000Z

Expected Output:

id | amount     | my_timestamp
-------------------------------------------
1  |     10     |  2018-04-01T01:10:00.000Z
2  |     40     |  2018-04-01T01:30:00.000Z

Looking for a streaming solution using either raw SQL, e.g. sparkSession.sql("sql query"), or something similar to raw SQL, but not something like mapGroupsWithState.

user1870400

1 Answer

There are multiple approaches to solving this problem.

Approach 1:

You can use window functions in Spark:

import org.apache.spark.sql.expressions.{Window, WindowSpec}
import org.apache.spark.sql.functions.{col, desc, rank}

val filterWindow: WindowSpec = Window.partitionBy("id").orderBy(desc("amount"))

val df = ??? // your input DataFrame

df.withColumn("temp_rank", rank().over(filterWindow))
  .filter(col("temp_rank") === 1)
  .drop("temp_rank")

The problem with this is that it does not work with Structured Streaming: non-time-based window functions are not supported on streaming Datasets, since windowing there is only supported on event-time (TIMESTAMP) columns. This approach works for batch jobs.
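
For reference, since this does work in batch, here is a minimal self-contained sketch of Approach 1 run against the sample data from the question (the object name and toDF column names are illustrative, chosen to match the question's schema):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, desc, rank}

object BatchMaxPerGroup {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("BatchMaxPerGroup").getOrCreate()
    import spark.implicits._

    // Sample rows from the question
    val df = Seq(
      (1, 5.0, "2018-04-01T01:00:00.000Z"),
      (1, 10.0, "2018-04-01T01:10:00.000Z"),
      (2, 20.0, "2018-04-01T01:20:00.000Z"),
      (2, 30.0, "2018-04-01T01:25:00.000Z"),
      (2, 40.0, "2018-04-01T01:30:00.000Z")
    ).toDF("id", "amount", "my_timestamp")

    // Rank rows within each id by descending amount and keep rank 1
    df.withColumn("temp_rank", rank().over(Window.partitionBy("id").orderBy(desc("amount"))))
      .filter(col("temp_rank") === 1)
      .drop("temp_rank")
      .show(false)
    // Prints the max-amount row per id: (1, 10.0, ...) and (2, 40.0, ...)
  }
}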

Approach 2:

Given the constraints specified in the question, you could go with something like the following. The rows are grouped on id, and each group's contents are collected into a Seq[A], where A is a struct of (time, amount). A UDF then picks the record with the maximum amount from that Seq.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, collect_list, struct, udf}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{DoubleType, StructType, TimestampType}

object StreamingDeDuplication {

  case class SubRecord(time: java.sql.Timestamp, amount: Double)

  // Schema of the struct the UDF returns
  val subSchema: StructType = new StructType()
    .add("time", TimestampType)
    .add("amount", DoubleType)

  // Picks the struct with the maximum amount from the collected group
  def deDupe: UserDefinedFunction =
    udf((data: Seq[Row]) => data.maxBy(_.getAs[Double]("amount")), subSchema)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local")
      .appName("StreamingDeDuplication")
      .getOrCreate()

    import spark.implicits._

    // Parse comma-separated lines "id,amount,timestamp" from the socket source
    val records = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", 9999)
      .load()
      .as[String]
      .map(_.split(","))
      .withColumn("id", $"value".getItem(0).cast("STRING"))
      .withColumn("amount", $"value".getItem(1).cast("DOUBLE"))
      .withColumn("time", $"value".getItem(2).cast("TIMESTAMP"))
      .drop("value")

    val results = records
      .withColumn("temp", struct("time", "amount"))
      .groupByKey(a => a.getAs[String]("id"))
      .agg(collect_list("temp").as[Seq[SubRecord]])
      .withColumnRenamed("value", "id") // groupByKey names the key column "value"
      .withColumnRenamed("collect_list(temp)", "temp_agg")
      .withColumn("af", deDupe($"temp_agg"))
      .withColumn("amount", col("af").getField("amount"))
      .withColumn("time", col("af").getField("time"))
      .drop("af", "temp_agg")

    results
      .writeStream
      .outputMode(OutputMode.Update())
      .option("truncate", "false")
      .format("console")
      .start()
      .awaitTermination()
  }

}
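
To try this locally, start a socket server (for example with nc -lk 9999, assuming netcat is available) and paste the sample rows from the question as comma-separated lines:

1,5,2018-04-01T01:00:00.000Z
1,10,2018-04-01T01:10:00.000Z
2,20,2018-04-01T01:20:00.000Z
2,30,2018-04-01T01:25:00.000Z
2,40,2018-04-01T01:30:00.000Z

Since the output mode is Update, each micro-batch re-emits the current maximum per id as new rows arrive, so the console ends up showing id 1 with amount 10.0 and id 2 with amount 40.0.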

Chitral Verma