Retrieve top n in each group of a DataFrame in pyspark

Asked on January 12, 2019 in Apache-spark.


  • 4 Answer(s)

    This answer uses a window function to compute the rank of each row within its user_id partition, ordered by score descending, and then filters the result to keep only the top two rows per user.

    from pyspark.sql.window import Window
    from pyspark.sql.functions import rank, col
     
    window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
     
    df.select('*', rank().over(window).alias('rank')) \
      .filter(col('rank') <= 2) \
      .show()
    #+-------+---------+-----+----+
    #|user_id|object_id|score|rank|
    #+-------+---------+-----+----+
    #| user_1| object_1|    3|   1|
    #| user_1| object_2|    2|   2|
    #| user_2| object_2|    6|   1|
    #| user_2| object_1|    5|   2|
    #+-------+---------+-----+----+
    

    For learning Spark, the official programming guide is a useful reference.

    Data

    rdd = sc.parallelize([("user_1", "object_1", 3),
                        ("user_1", "object_2", 2),
                        ("user_2", "object_1", 5),
                        ("user_2", "object_2", 2),
                        ("user_2", "object_2", 6)])
    df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
    
    Answered on January 12, 2019.

    Using row_number instead of rank gives an exact top-n when scores tie: rank assigns the same value to tied rows, so more than n rows can pass the filter, while row_number is strictly sequential.

    from pyspark.sql.functions import row_number, col
     
    n = 5
    df.select(col('*'), row_number().over(window).alias('row_number')) \
      .where(col('row_number') <= n) \
      .limit(20) \
      .toPandas()
    

    In Jupyter notebooks, use limit(20).toPandas() instead of show() to render the result as a table.
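    The difference is easy to see without Spark at all. The following pure-Python sketch mimics the tie semantics of Spark's rank() and row_number() window functions (the helper names are illustrative, not part of any API):

    ```python
    # Pure-Python sketch of Spark's rank() vs row_number() tie handling.
    # Scores are assumed pre-sorted in descending order, as the window would order them.

    def rank_values(scores):
        """rank(): tied values share a rank; the next rank skips ahead."""
        ranks = []
        for i, s in enumerate(scores):
            if i > 0 and s == scores[i - 1]:
                ranks.append(ranks[-1])   # tie: reuse the previous rank
            else:
                ranks.append(i + 1)       # gap after ties, like SQL RANK()
        return ranks

    def row_numbers(scores):
        """row_number(): strictly sequential, ties broken arbitrarily."""
        return list(range(1, len(scores) + 1))

    scores = [6, 5, 5, 2]
    print(rank_values(scores))   # [1, 2, 2, 4]
    print(row_numbers(scores))   # [1, 2, 3, 4]
    ```

    With a top-2 filter, rank would keep three rows here (ranks 1, 2, 2) while row_number keeps exactly two.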

     

    Answered on January 12, 2019.

    The same windowing approach works for, e.g., finding the top 3 salaries in each department. Sample data:

    Dept_id | name | salary
       1    |  A   |   10
       2    |  B   |  100
       1    |  D   |  100
       2    |  C   |  105
       1    |  N   |  103
       2    |  F   |  102
       1    |  K   |   90
       2    |  E   |  110
    
    
    from pyspark.sql.window import Window
    from pyspark.sql.functions import rank, col
     
    df.withColumn("rank", rank().over(Window.partitionBy("Dept_id").orderBy(col("salary").desc()))) \
        .filter(col("rank") <= 3) \
        .drop("rank")
    Dept_id | name | salary
       1    |  N   |  103
       1    |  D   |  100
       1    |  K   |   90
       2    |  E   |  110
       2    |  C   |  105
       2    |  F   |  102
    
    
    
    

     

    Answered 7 days ago.

