Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 110 additions & 86 deletions lib/sort.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
* This performs n*log2(n) + 0.37*n + o(n) comparisons on average,
* and 1.5*n*log2(n) + O(n) in the (very contrived) worst case.
*
* Quicksort manages n*log2(n) - 1.26*n for random inputs (1.63*n
* Glibc qsort() manages n*log2(n) - 1.26*n for random inputs (1.63*n
* better) at the expense of stack usage and much larger code to avoid
* quicksort's O(n^2) worst case.
*/

#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt

#include <linux/types.h>
#include <linux/export.h>
#include <linux/sort.h>
#include <linux/log2.h>
#include <linux/sched.h>

/**
* is_aligned - is this pointer & size okay for word-wide copying?
Expand Down Expand Up @@ -186,18 +190,35 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size)
return i / 2;
}

#include <linux/sched.h>

static void __sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func,
swap_r_func_t swap_func,
const void *priv,
bool may_schedule)
/**
* sort_r - sort an array of elements
* @base: pointer to data to sort
* @num: number of elements
* @size: size of each element
* @cmp_func: pointer to comparison function
* @swap_func: pointer to swap function or NULL
* @priv: third argument passed to comparison function
* @may_schedule: whether to call cond_resched() periodically
*
* This function does a heapsort on the given array. You may provide
* a swap_func function if you need to do something more than a memory
* copy (e.g. fix up pointers or auxiliary data), but the built-in swap
* avoids a slow retpoline and so is significantly faster.
*
* Sorting time is O(n log n) both on average and worst-case. While
* quicksort is slightly faster on average, it suffers from exploitable
* O(n*n) worst-case behavior and extra memory requirements that make
* it less suitable for kernel use.
*/
void sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func,
swap_r_func_t swap_func,
const void *priv,
bool may_schedule)
{
/* pre-scale counters for performance */
size_t n = num * size, a = (num/2) * size;
const unsigned int lsbit = size & -size; /* Used to find parent */
size_t shift = 0;

if (!a) /* num < 2 || size == 0 */
return;
Expand Down Expand Up @@ -225,18 +246,12 @@ static void __sort_r(void *base, size_t num, size_t size,
for (;;) {
size_t b, c, d;

if (a) /* Building heap: sift down a */
a -= size << shift;
else if (n > 3 * size) { /* Sorting: Extract two largest elements */
n -= size;
if (a) /* Building heap: sift down --a */
a -= size;
else if (n -= size) /* Sorting: Extract root to --n */
do_swap(base, base + n, size, swap_func, priv);
shift = do_cmp(base + size, base + 2 * size, cmp_func, priv) <= 0;
a = size << shift;
n -= size;
do_swap(base + a, base + n, size, swap_func, priv);
} else { /* Sort complete */
else /* Sort complete */
break;
}

/*
* Sift element at "a" down into heap. This is the
Expand All @@ -251,7 +266,7 @@ static void __sort_r(void *base, size_t num, size_t size,
* average, 3/4 worst-case.)
*/
for (b = a; c = 2*b + size, (d = c + size) < n;)
b = do_cmp(base + c, base + d, cmp_func, priv) > 0 ? c : d;
b = do_cmp(base + c, base + d, cmp_func, priv) >= 0 ? c : d;
if (d == n) /* Special case last leaf with no sibling */
b = c;

Expand All @@ -263,70 +278,18 @@ static void __sort_r(void *base, size_t num, size_t size,
b = parent(b, lsbit, size);
do_swap(base + b, base + c, size, swap_func, priv);
}

if (may_schedule)
cond_resched();
}

n -= size;
do_swap(base, base + n, size, swap_func, priv);
if (n == size * 2 && do_cmp(base, base + size, cmp_func, priv) > 0)
do_swap(base, base + size, size, swap_func, priv);
}

/**
* sort_r - sort an array of elements
* @base: pointer to data to sort
* @num: number of elements
* @size: size of each element
* @cmp_func: pointer to comparison function
* @swap_func: pointer to swap function or NULL
* @priv: third argument passed to comparison function
*
* This function does a heapsort on the given array. You may provide
* a swap_func function if you need to do something more than a memory
* copy (e.g. fix up pointers or auxiliary data), but the built-in swap
* avoids a slow retpoline and so is significantly faster.
*
* The comparison function must adhere to specific mathematical
* properties to ensure correct and stable sorting:
* - Antisymmetry: cmp_func(a, b) must return the opposite sign of
* cmp_func(b, a).
* - Transitivity: if cmp_func(a, b) <= 0 and cmp_func(b, c) <= 0, then
* cmp_func(a, c) <= 0.
*
* Sorting time is O(n log n) both on average and worst-case. While
* quicksort is slightly faster on average, it suffers from exploitable
* O(n*n) worst-case behavior and extra memory requirements that make
* it less suitable for kernel use.
*/
void sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func,
swap_r_func_t swap_func,
const void *priv)
{
__sort_r(base, num, size, cmp_func, swap_func, priv, false);
}
EXPORT_SYMBOL(sort_r);

/**
* sort_r_nonatomic - sort an array of elements, with cond_resched
* @base: pointer to data to sort
* @num: number of elements
* @size: size of each element
* @cmp_func: pointer to comparison function
* @swap_func: pointer to swap function or NULL
* @priv: third argument passed to comparison function
*
* Same as sort_r, but preferred for larger arrays as it does a periodic
* cond_resched().
*/
void sort_r_nonatomic(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func,
swap_r_func_t swap_func,
const void *priv)
cmp_r_func_t cmp_func,
swap_r_func_t swap_func,
const void *priv)
{
__sort_r(base, num, size, cmp_func, swap_func, priv, true);
sort_r(base, num, size, cmp_func, swap_func, priv, true);
}
EXPORT_SYMBOL(sort_r_nonatomic);

Expand All @@ -338,20 +301,81 @@ void sort(void *base, size_t num, size_t size,
.cmp = cmp_func,
.swap = swap_func,
};

return __sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w, false);
sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w, false);
}
EXPORT_SYMBOL(sort);

void sort_nonatomic(void *base, size_t num, size_t size,
cmp_func_t cmp_func,
swap_func_t swap_func)
cmp_func_t cmp_func,
swap_func_t swap_func)
{
struct wrapper w = {
.cmp = cmp_func,
.swap = swap_func,
};

return __sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w, true);
struct wrapper w = {
.cmp = cmp_func,
.swap = swap_func,
};
sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w, true);
}
EXPORT_SYMBOL(sort_nonatomic);

#define INTRO_SORT_THRESHOLD 16

static void insertion_sort(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func, swap_r_func_t swap_func, const void *priv, bool may_schedule)
{
size_t i, j;
for (i = size; i < num * size; i += size) {
for (j = i; j >= size && cmp_func(base + j - size, base + j, priv) > 0; j -= size) {
swap_func(base + j, base + j - size, size, priv);
}
if (may_schedule)
cond_resched();
}
}

static void introsort_r(void *base, size_t num, size_t size, int depth_limit,
cmp_r_func_t cmp_func, swap_r_func_t swap_func, const void *priv, bool may_schedule)
{
if (num < 2)
return;
if (num < INTRO_SORT_THRESHOLD) {
insertion_sort(base, num, size, cmp_func, swap_func, priv, may_schedule);
return;
}
if (depth_limit == 0) {
// Fallback to heapsort for worst-case
sort_r(base, num, size, cmp_func, swap_func, priv, may_schedule);
return;
}

// Lomuto partition
size_t i = 0, j;
void *pivot = base + (num - 1) * size;
for (j = 0; j < num - 1; j++) {
if (cmp_func(base + j * size, pivot, priv) < 0) {
if (i != j)
swap_func(base + i * size, base + j * size, size, priv);
i++;
}
if (may_schedule)
cond_resched();
}
swap_func(base + i * size, pivot, size, priv);

introsort_r(base, i, size, depth_limit - 1, cmp_func, swap_func, priv, may_schedule);
introsort_r(base + (i + 1) * size, num - i - 1, size, depth_limit - 1, cmp_func, swap_func, priv, may_schedule);
}

void sort_r_hybrid(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func, swap_r_func_t swap_func, const void *priv, bool may_schedule)
{
int depth_limit = 2 * ilog2(num);
introsort_r(base, num, size, depth_limit, cmp_func, swap_func, priv, may_schedule);
}
EXPORT_SYMBOL(sort_r_hybrid);

void sort_r_hybrid_nonatomic(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func, swap_r_func_t swap_func, const void *priv)
{
sort_r_hybrid(base, num, size, cmp_func, swap_func, priv, true);
}
EXPORT_SYMBOL(sort_r_hybrid_nonatomic);