目录

SQL N 个组中的 Top-K 问题

问题

这是个在半年前遇到的一个有趣的 SQL 查询问题,让我先描述一下:

假定我有一张表:

1
2
3
4
5
CREATE TABLE foods (
  id INT NOT NULL PRIMARY KEY AUTO_INCREMENT COMMENT '没啥意义的主键',
  category VARCHAR(10) COMMENT '类别',
  price INT COMMENT '价格'
);

我现在想知道每个 category 里面价格前 3 高的商品 id,能用一个查询解决吗?

这个问题就是 N 个组中的 Top-K 问题。

窗口函数

DB Fiddle:https://dbfiddle.uk/0tkoovd2

如果你在用比较新的 MySQL/MariaDB 或者 PostgreSQL 的话,你可以用 over(partition by) 来分组,并用窗口函数来解决这个问题:

1
2
3
4
5
6
7
8
9
SELECT
  id,
  category,
  price,
  ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) AS rn,
  RANK() OVER (PARTITION BY category ORDER BY price DESC) AS rk,
  DENSE_RANK() OVER (PARTITION BY category ORDER BY price DESC) AS drk
FROM
  foods;

这里面三个窗口函数的含义:

  • ROW_NUMBER():从上往下,从 1 开始按行号排序。
  • RANK():从上往下,根据 value 排序,其中并列的 rank 相同,并且会重复排名会占位(比如有两个第 5 名,那么第三个从第 7 名开始,而不是第 6 名。)
  • DENSE_RANK():同上,但是不跳过重复排名。

那么要解决上面这个问题,就非常简单了:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
SELECT
  w.id,
  w.category,
  w.price
FROM
  (
    SELECT
      id,
      category,
      price,
      ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) AS rn
    FROM
      foods
  ) w
WHERE
  w.rn <= 3;

SQL 解法

DB Fiddle:https://dbfiddle.uk/0kwueEv8

当然,在种种情况下,可能你并没有这么方便的窗口函数可以用(比如一些 custom 的 SQL 实现或者在 Flink SQL 下的一些非标准 DBMS 环境),怎么用 SQL 解这个问题呢?

我们先看一个求 RANK() 的例子:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
SELECT
  f.id,
  f.category,
  f.price,
  COUNT(af.price) + 1 AS rk
FROM
  foods f
  LEFT JOIN foods af ON (
    f.category = af.category
    AND f.price < af.price
  )
GROUP BY
  f.id
ORDER BY
  rk;

我们先看子查询:子查询做了一次 self join,f.category = af.category 实际上做了拆组的工作,保证在同组间做比较,而 f.price < af.price 这条比较很关键,这使得对于任意一行数据,比当前数据大的 entry 都出现了一次。例如,假定我们的数据是 20, 30, 40, 50,则对于 30 这条 entry,我们的结果是 (30, 40), (30, 50)。我们很容易发现,实际上有几条数据,我们就排在第几名(一个简单的观察是,因为有 k 条比当前 entry 大的数据,所以他们都在前面,则当前数据自然排在第 k+1 名),所以,我们针对 af.price(注意到我们应该计数后面的数据,而不是这一行的比较基准 f.price)做一下 count 即可。由于我们的 Rank 基于 1 开始,所以我们需要加一。

这个情况对于重复数值有用吗?同样有用。考虑 20, 30, 30, 40, 50,对于 30,我们会有两组相同数据 (30, 40), (30, 50) 满足要求,所以都排在第 3 名,而对于 20,此时满足要求的行数有 4 行,则 20 排在第 5 名。

基于 RANK(),我们可以迅速发现,DENSE_RANK() 实际上将多个 30 当做一组处理,则在后一个例子中,我们对 af.price 去重即可。所以,DENSE_RANK() 的 SQL 很直观:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
SELECT
  f.id,
  f.category,
  f.price,
  COUNT(DISTINCT af.price) + 1 AS drk
FROM
  foods f
  LEFT JOIN foods af ON (
    f.category = af.category
    AND f.price < af.price
  )
GROUP BY
  f.id
ORDER BY
  drk;

ROW_NUMBER() 需要多一层注意:由于我们希望每一行都是递增的,则对于相同的行,我们还需要另一个东西来保序。通常来说,我们使用主键 id 来保序,不过如果有特殊需求,也可以变。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
SELECT
  f.id,
  f.category,
  f.price,
  COUNT(af.price) + 1 AS rn
FROM
  foods f
  LEFT JOIN foods af ON (
    f.category = af.category
    AND (
      f.price < af.price
      OR (
        f.price = af.price
        AND f.id > af.id
      )
    )
  )
GROUP BY
  f.id
ORDER BY
  rn;

我们直接修改 join 的条件即可:当两个值相等的时候,我们选择一个其他 key 决定顺序。注意到我们把下面的 f.id > af.id 换成了大于号,其实是一个比较有趣的搞法:由于我们的前一个 f.id 有序,我们希望 f.id 更小的 entry 出现在前面,在最后的筛选时才能让 f.id 更大的包含更多结果。比如我们有 (1, 20), (2, 30), (3, 30), (4, 40), (5, 50),我们期望让 (2, 30) 的 entry 只有两个,(3, 30) 的 entry 有三个,所以我们需要 f.id > af.id,此时 (3, 30, 2, 30) 这条 entry 会被统计到后面,才能保证 (2, 30) 这条排在前面。

由于我们每一行都会有一个唯一 ROW_NUMBER(),所以我们不能插入 DISTINCT,不然重复值这一条的统计会出问题。

所以,我们要解决最初的问题,只需要对 rn 筛选一下即可:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
SELECT
  f.id,
  f.category,
  f.price,
  COUNT(af.price) + 1 AS rn
FROM
  foods f
  LEFT JOIN foods af ON (
    f.category = af.category
    AND (
      f.price < af.price
      OR (
        f.price = af.price
        AND f.id > af.id
      )
    )
  )
GROUP BY
  f.id
HAVING
  rn <= 3
ORDER BY
  rn;

Reference