How to select the first row of each group?

Asked on November 15, 2018 in Apache-spark.

  • 3 Answer(s)

    Window functions:

    Here is the cleanest solution:

    import org.apache.spark.sql.functions.{row_number, max, broadcast}
    import org.apache.spark.sql.expressions.Window
     
    val df = sc.parallelize(Seq(
        (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
        (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
        (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
        (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
     
    val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
     
    val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
     
    dfTop.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
    

    This method will be inefficient in case of significant data skew.
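    For reference, the same window logic can be expressed in plain Spark SQL. A minimal sketch, assuming the DataFrame is registered as a temporary view (the view name df_table is illustrative):

    df.createOrReplaceTempView("df_table")
    spark.sql("""
        SELECT Hour, Category, TotalValue
        FROM (
            SELECT *,
                   row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) AS rn
            FROM df_table
        ) t
        WHERE rn = 1
    """).show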

     

    Plain SQL aggregation followed by join:

    Another solution is to join the original data frame with an aggregated one:

    val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
     
    val dfTopByJoin = df.join(broadcast(dfMax),
        ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
      .drop("max_hour")
      .drop("max_value")
    dfTopByJoin.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
    

    This can return duplicate rows (if there is more than one category per hour with the same total value).

    They can be removed as follows:

    import org.apache.spark.sql.functions.first
     
    dfTopByJoin
        .groupBy($"hour")
        .agg(
            first("category").alias("category"),
            first("TotalValue").alias("TotalValue"))
    

    Ordering over structs:

    This approach requires neither joins nor window functions. Structs are compared field by field, so with TotalValue as the first field, max picks the row with the largest total value per group:

    import org.apache.spark.sql.functions.struct
     
    val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
        .groupBy($"hour")
        .agg(max("vs").alias("vs"))
        .select($"Hour", $"vs.Category", $"vs.TotalValue")
     
    dfTop.show
    // +----+--------+----------+
    // |Hour|Category|TotalValue|
    // +----+--------+----------+
    // |   0|   cat26|      30.9|
    // |   1|   cat67|      28.5|
    // |   2|   cat56|      39.6|
    // |   3|    cat8|      35.6|
    // +----+--------+----------+
    
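    The same trick gives the lowest value per group by swapping max for min. A minimal sketch against the same df (struct comparison is field by field, so TotalValue must remain the first field):

    import org.apache.spark.sql.functions.{min, struct}
     
    val dfBottom = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
        .groupBy($"Hour")
        .agg(min("vs").alias("vs"))
        .select($"Hour", $"vs.Category", $"vs.TotalValue")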

     

    Dataset API (Spark 1.6+, 2.0+):

    Spark 1.6:

    case class Record(Hour: Integer, Category: String, TotalValue: Double)
    df.as[Record]
        .groupBy($"hour")
        .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
        .show
    // +---+--------------+
    // | _1|            _2|
    // +---+--------------+
    // |[0]|[0,cat26,30.9]|
    // |[1]|[1,cat67,28.5]|
    // |[2]|[2,cat56,39.6]|
    // |[3]| [3,cat8,35.6]|
    // +---+--------------+
    

    Spark 2.0 or later:

    df.as[Record]
        .groupByKey(_.Hour)
        .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
    
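    reduceGroups returns a Dataset of (key, Record) pairs; a minimal sketch of keeping just the winning records by mapping out the values:

    df.as[Record]
        .groupByKey(_.Hour)
        .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
        .map(_._2) // drop the grouping key, keep one Record per Hour
        .show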

    The last two methods can take advantage of map-side combine and do not require a full shuffle, so most of the time they should exhibit better performance than window functions and joins. They can also be used with Structured Streaming in complete output mode.

    Do not use:

    df.orderBy(...).groupBy(...).agg(first(...), ...)
    

    The same note applies to:

    df.orderBy(...).dropDuplicates(...)
    

    It uses an equivalent execution plan internally; in both cases the ordering is not guaranteed to survive the aggregation, so the result can be arbitrary.
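    One way to see this is to compare the physical plans; a sketch against the same df (the chosen columns are illustrative):

    import org.apache.spark.sql.functions.first
     
    // both plans aggregate after a shuffle, with nothing tying that shuffle
    // to the preceding orderBy
    df.orderBy($"TotalValue".desc).groupBy($"Hour").agg(first($"Category")).explain()
    df.orderBy($"TotalValue".desc).dropDuplicates("Hour").explain()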

     

    Answered on November 15, 2018.

    Here is how this is done in Spark 2.0.2 when grouping by multiple columns:

    import org.apache.spark.sql.functions.row_number
    import org.apache.spark.sql.expressions.Window
     
    val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
     
    val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
    
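    The same pattern keeps the top n rows per group by relaxing the filter; a sketch (n is illustrative):

    val n = 3
    val topN = df.withColumn("rn", row_number.over(w)).where($"rn" <= n).drop("rn")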
    Answered on November 15, 2018.

    The answer below uses a single groupByKey and extracts the rows of the dataframe that contain the max value in one pass, with no need for further joins or windows.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.DataFrame
     
    //df is the dataframe with Day, Category, TotalValue
     
    implicit val dfEnc = RowEncoder(df.schema)
     
    val res: DataFrame = df.groupByKey{ (r) => r.getInt(0) }
        .mapGroups[Row]{ (day: Int, rows: Iterator[Row]) => rows.maxBy{ (r) => r.getDouble(2) } }
    
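    If the Record case class from the earlier answer is in scope and df uses the same Hour, Category, TotalValue schema, the same idea can be written in typed form without a RowEncoder; a minimal sketch under that assumption:

    val resTyped = df.as[Record]
        .groupByKey(_.Hour)
        .mapGroups { (hour, rows) => rows.maxBy(_.TotalValue) }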
    Answered on November 15, 2018.

