Type traits 练习:编译期快排

前排提醒:本文仅做学习交流之用,实际项目中如有条件请尽量使用 Constraints and concepts (since C++20) 。

上回我们通过对 tuple 的类型操作,尝试了一下 Type traits 之后,本文我们开始加点强度,实现一个基于模板元编程的编译期快排(不妨在这里思考一下,如果你来实现,大概分成哪几个步骤呢?跟运行时的快排有什么区别?)

本文涉及到的所有代码均在 godbolts 中,大家有需要可以自取~

定义数据类型

思考一下,我们要排序的是什么数据,这些数据存在哪里?

答案比较显然,存在模板参数里,作为结构体的编译期信息。

template <int... Ns> // 为简化问题,我们直接使用 int 作为待排序的数据类型
struct Nums {};

在开始编写快排之前,我们可以定义一些操作 Nums 的操作,比如将两个 Nums 合并起来。

// 以下这段代码你应该可以很快理解了
template <typename T1, typename T2>
struct Concat {};
template <int... Ns1, int... Ns2>
struct Concat<Nums<Ns1...>, Nums<Ns2...>> {
using type = Nums<Ns1..., Ns2...>;
};

拆解快排流程

温馨提醒:如果还不了解快排的流程,可以参考下维基百科,再阅读剩下的内容。

编译期快排整体流程跟传统快排并没有什么不同:

  1. 基于某个 pivot(一般取首个元素) 将当前数据划分为左右两部分或左、中、右三部分(本文使用三路快排)
  2. 对 1 中划分的左、右部分递归执行快排,直到数组长度小于 2

因此,我们可以将编译期快排拆解成两个步骤:

  1. 划分数据
  2. 递归执行

接下来我们将针对这两个步骤分别思考和实现。

划分数据

直接取 Nums 中的数据做 pivot 划分复杂度会高一些,我们可以先考虑在给定一个 pivot 元素的情况下,怎样把将 Nums 划分为三个部分。现在我们可以开始定义 Divide 类模板了。

template <int Pivot, typename T>
struct Divide {};
template <int Pivot, int... Ns>
struct Divide<Pivot, Nums<Ns...>> {};
template <int Pivot>
struct Divide<Pivot, Nums<>> {
using left = Nums<>;
using mid = Nums<>;
using right = Nums<>;
};

写出上面三个模板还是比较简单的,关键在于怎样把 Divide 的递推关系找出来。其实也比较简单,我们只要不断萃取 Nums 中的第一个元素,将其与 pivot 比较,根据比较结果分类讨论写出递推式就可以了~

实现见以下代码。这里我们用到了 std::conditional_t(自己实现也是很简单的)

template <int Pivot, int N0, int... Ns>
struct Divide<Pivot, Nums<N0, Ns...>> {
using left =
std::conditional_t<
N0 < Pivot,
typename Concat<Nums<N0>,
typename Divide<Pivot, Nums<Ns...>>::left
>::type,
typename Divide<Pivot, Nums<Ns...>>::left
>;
using mid =
std::conditional_t<
N0 == Pivot,
typename Concat<Nums<N0>, typename Divide<Pivot, Nums<Ns...>>::mid>::type,
typename Divide<Pivot, Nums<Ns...>>::mid
>;
using right =
std::conditional_t<
Pivot < N0,
typename Concat<Nums<N0>, typename Divide<Pivot, Nums<Ns...>>::right>::type,
typename Divide<Pivot, Nums<Ns...>>::right
>;
};

测试一下,符合预期:

print_type<Divide<5, Nums<1, 5>>::left>(); // Nums<1>
print_type<Divide<5, Nums<1, 5>>::mid>(); // Nums<5>
print_type<Divide<5, Nums<1, 5>>::right>(); // Nums<>
print_type<Divide<5, Nums<1, 2, 9, 4, 5, 5, 6, 7, 8>>::left>(); // Nums<1, 2, 4>
print_type<Divide<5, Nums<1, 2, 9, 4, 5, 5, 6, 7, 8>>::mid>(); // Nums<5, 5>
print_type<Divide<5, Nums<1, 2, 9, 4, 5, 5, 6, 7, 8>>::right>(); // Nums<9, 6, 7, 8>

以上实现的 Divide 模板中,pivot 是由外部传入的,但快排需要直接取 Nums 中的元素作为 pivot,怎么实现呢?

其实只要继承一下现有的 Divide,我们就能很轻松地写一个取第一个元素作为 pivot 的 DivideInplace 了~

template <typename T>
struct DivideInplace {};
template <>
struct DivideInplace<Nums<>> : Divide<0, Nums<>> {};
template <int N0, int... Ns>
struct DivideInplace<Nums<N0, Ns...>> : Divide<N0, Nums<N0, Ns...>> {};

再测试一下,也没有问题。

print_type<DivideInplace<Nums<1, 2, 3>>::left>(); // Nums<>
print_type<DivideInplace<Nums<1, 2, 3>>::mid>(); // Nums<1>
print_type<DivideInplace<Nums<1, 2, 3>>::right>(); // Nums<2, 3>
print_type<DivideInplace<Nums<3, 2, 1>>::left>(); // Nums<2, 1>
print_type<DivideInplace<Nums<3, 2, 1>>::mid>(); // Nums<3>
print_type<DivideInplace<Nums<3, 2, 1>>::right>(); // Nums<>

至此,我们就把最难的一步——划分数据写完了,接下来实际上就比较简单了。

递归执行

首先我们还是定义一个快排模板,并考虑当 Nums 为空或只有一个元素的情况。

template <typename T>
struct QuickSort {};
template <>
struct QuickSort<Nums<>> {
using type = Nums<>;
};
template <int N>
struct QuickSort<Nums<N>> {
using type = Nums<N>;
};

然后就是递推式了,也不太复杂:

template <int... Ns>
struct QuickSort<Nums<Ns...>> {
using divide = DivideInplace<Nums<Ns...>>;
using left = typename QuickSort<typename divide::left>::type;
using mid = typename divide::mid;
using right = typename QuickSort<typename divide::right>::type;
using type = typename Concat<typename Concat<left, mid>::type, right>::type;
};

最终测试,没问题!

print_type<QuickSort<Nums<>>::type>(); // Nums<>
print_type<QuickSort<Nums<3>>::type>(); // Nums<3>
print_type<QuickSort<Nums<1, 1>>::type>(); // Nums<1, 1>
print_type<QuickSort<Nums<1, 2, 3>>::type>(); // Nums<1, 2, 3>
print_type<QuickSort<Nums<3, 2, 1>>::type>(); // Nums<1, 2, 3>
print_type<QuickSort<Nums<5, 1, 2, 9, 4, 5, 6, 7, 8>>::type>(); // Nums<1, 2, 4, 5, 5, 6, 7, 8, 9>

结语

完工~ 要不你写个编译期归并排序练练手怎么样 🥵