Of course you can, but what do you expect the output to look like? If you expect something like:
Age Amount   A    B
  2    200 500 1450
  3    450 500 1450
  5    500 500 1450
  0    200 500 1450
  8    300 500 1450
  9    200 500 1450
  1    100 500 1450
then this is a windowed aggregate (a sum over a window). A window function places the aggregated value on every row; here the window spans the whole dataset:
import org.apache.spark.sql.functions._

df
  // A: total Amount over all rows with Age < 3
  .withColumn(
    "A",
    sum(when(col("Age") < 3, col("Amount")).otherwise(lit(0))).over()
  )
  // B: total Amount over all rows with Age >= 3
  .withColumn(
    "B",
    sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0))).over()
  )
Note that using over without any partitioning is not performant at all; use a partitioned window whenever your data has a key to partition on (see the sketch after the output below).
Here's the output:
22/04/25 23:54:01 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---+------+---+----+
|Age|Amount| A| B|
+---+------+---+----+
| 2| 200|500|1450|
| 3| 450|500|1450|
| 5| 500|500|1450|
| 0| 200|500|1450|
| 8| 300|500|1450|
| 9| 200|500|1450|
| 1| 100|500|1450|
+---+------+---+----+
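For reference, here is a minimal sketch of the same pattern with a partitioned window, assuming your real data has some key to partition on (group_id below is a hypothetical column); the totals then become per-group instead of global, but the single-partition warning goes away:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// hypothetical partitioning key; A and B are computed per group_id instead of globally
val w = Window.partitionBy("group_id")

df
  .withColumn("A", sum(when(col("Age") < 3, col("Amount")).otherwise(lit(0))).over(w))
  .withColumn("B", sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0))).over(w))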
Update:
After you updated the question, I suggest you do this:
df
  // collapse to one row per (JoinKey, period, Age < 3) with its Amount sum
  .groupBy(col("JoinKey"), col("period"), expr("Age < 3").as("under3"))
  .agg(sum(col("Amount")).as("grouped_age_sum"))
  // A: total of the under-3 sums over all rows; B: total of the rest
  .withColumn("A", sum(when(col("under3") === true, col("grouped_age_sum")).otherwise(lit(0))).over())
  .withColumn("B", sum(when(col("under3") === false, col("grouped_age_sum")).otherwise(lit(0))).over())
  .drop("grouped_age_sum", "under3")
  // deduplicate back to one row per (JoinKey, period); A and B are identical on every row
  .groupBy(col("JoinKey"), col("period")).min()
  .withColumnRenamed("min(A)", "A")
  .withColumnRenamed("min(B)", "B")
  .show
Please note that the same caveat about partitioning still applies. I only had a few sample rows and didn't really need the performance (it would also have added some logic-dependent boilerplate to the solution), but you should do it. Here's the output:
22/04/26 22:36:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+-------+-------+---+----+
|JoinKey| period| A| B|
+-------+-------+---+----+
| JK1|2022-02|500|1450|
| JK2|2022-03|500|1450|
| JK3|2022-03|500|1450|
| JK2|2022-02|500|1450|
| JK2|2022-04|500|1450|
+-------+-------+---+----+
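As a rough alternative sketch that avoids the unpartitioned window entirely, you could compute the two global totals once and cross-join them onto the distinct (JoinKey, period) pairs; the totals DataFrame is a single row, so the crossJoin stays cheap:

import org.apache.spark.sql.functions._

// single-row DataFrame holding the two global totals
val totals = df.agg(
  sum(when(col("Age") < 3, col("Amount")).otherwise(lit(0))).as("A"),
  sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0))).as("B")
)

// attach the totals to every distinct (JoinKey, period) pair
df.select("JoinKey", "period").distinct().crossJoin(totals).show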
Update #2:
After the clearer explanation you provided: you just need a grouping with two simple conditional aggregates:
df
  .groupBy(col("JoinKey"), col("period"))
  .agg(
    sum(when(col("Age") <= 3, col("Amount")).otherwise(lit(0))).as("Amount (Age <= 3)"),
    sum(when(col("Age") > 3, col("Amount")).otherwise(lit(0))).as("Amount (Age > 3)")
  )
Output:
+-------+-------+-----------------+----------------+
|JoinKey| period|Amount (Age <= 3)|Amount (Age > 3)|
+-------+-------+-----------------+----------------+
| JK1|2022-02| 650| 0|
| JK2|2022-03| 0| 500|
| JK3|2022-03| 200| 200|
| JK2|2022-02| 0| 300|
| JK2|2022-04| 100| 0|
+-------+-------+-----------------+----------------+
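For completeness, here is one possible sample DataFrame consistent with the outputs above (the exact row-to-key mapping is my reconstruction, so treat it as illustrative), in case you want to reproduce the snippets end to end:

import spark.implicits._

// illustrative sample data matching the outputs shown in this answer
val df = Seq(
  ("JK1", "2022-02", 2, 200),
  ("JK1", "2022-02", 3, 450),
  ("JK2", "2022-03", 5, 500),
  ("JK3", "2022-03", 0, 200),
  ("JK3", "2022-03", 9, 200),
  ("JK2", "2022-02", 8, 300),
  ("JK2", "2022-04", 1, 100)
).toDF("JoinKey", "period", "Age", "Amount")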