Algorithm & Data Structure - Quick select

Problem statement

Find the kth smallest/largest element in an array.

Approach 1

Sort the array and pick the kth element directly from the sorted array. This takes O(N log N).
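A minimal sketch of the sorting approach, for illustration only. The class name `KthBySort` is hypothetical, and `k` is counted from 1 here (an assumption; the select code later in this post treats `k` as a 0-based index). The input is copied so the caller's array is left untouched.

```java
import java.util.Arrays;

class KthBySort {
    // Returns the k-th smallest element, with k counted from 1 (assumption).
    static int kthSmallest(int[] nums, int k) {
        if (k < 1 || k > nums.length) {
            throw new IllegalArgumentException("k out of range");
        }
        int[] copy = Arrays.copyOf(nums, nums.length); // keep the input intact
        Arrays.sort(copy);                             // O(N log N)
        return copy[k - 1];
    }
}
```

For the kth largest element, the same idea works by reading `copy[copy.length - k]` instead.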

Approach 2

Use a heap to store the k smallest/largest elements (a max-heap for the k smallest, a min-heap for the k largest). Whenever the heap size exceeds k, pop an element. After looping through all elements, the top element of the heap is the kth element. This approach takes O(N log K).
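The heap approach can be sketched as follows for the kth smallest case. The class name `KthByHeap` is hypothetical and `k` is counted from 1 (an assumption). A max-heap keeps the k smallest values seen so far, so its top is the largest of them, which is the kth smallest overall.

```java
import java.util.Collections;
import java.util.PriorityQueue;

class KthByHeap {
    // Returns the k-th smallest element, with k counted from 1 (assumption).
    static int kthSmallest(int[] nums, int k) {
        if (k < 1 || k > nums.length) {
            throw new IllegalArgumentException("k out of range");
        }
        // Max-heap: the largest of the retained elements sits on top.
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int x : nums) {
            maxHeap.offer(x);
            if (maxHeap.size() > k) {
                maxHeap.poll(); // drop the largest, keeping the k smallest
            }
        }
        return maxHeap.peek();
    }
}
```

For the kth largest element, a plain `PriorityQueue<Integer>` (a min-heap) would be used in the same way.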

Approach 3

Quick select solves this problem in expected linear time. It repeatedly partitions the array around a pivot and keeps only the side that contains the kth element, until the pivot lands at the kth position.

Partition


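    import java.util.Random;

    // Shared generator, so a new Random is not created on every call.
    private static final Random RAND = new Random();

    // Partitions nums[left..right] around a random pivot and
    // returns the pivot's final index.
    private int partition(int left, int right, int[] nums) {
        if (left == right) return left;

        // Randomly pick an element as the pivot and move it to the front.
        // Note this is important to speed up the select.
        int id = RAND.nextInt(right + 1 - left) + left;
        swap(nums, left, id);

        int i = left, j = right + 1;
        int v = nums[left];
        while (true) {
            // Scan from the left for an element >= pivot.
            while (nums[++i] < v) if (i == right) break;
            // Scan from the right for an element <= pivot.
            while (v < nums[--j]) if (j == left) break;

            if (i >= j) break;
            swap(nums, i, j);
        }
        // Put the pivot between the two parts.
        swap(nums, left, j);
        return j;
    }

    private void swap(int[] nums, int a, int b) {
        int tmp = nums[a];
        nums[a] = nums[b];
        nums[b] = tmp;
    }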

Select

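    // Returns the element that would be at index k (0-based) if nums were sorted.
    // Note that nums is reordered in place.
    private int kthSmallestElement(int[] nums, int k) {
        int left = 0, right = nums.length - 1;

        while (left <= right) {
            int mid = partition(left, right, nums);
            if (k == mid) {
                return nums[mid];
            } else if (k > mid) {
                // The target lies to the right of the pivot.
                left = mid + 1;
            } else {
                // The target lies to the left of the pivot.
                right = mid - 1;
            }
        }
        return -1; // k is out of range
    }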

Picking the pivot at random in the partition step is very important for speed; otherwise the select could take O(N^2) operations. Imagine a sorted array: if the first element is always picked as the pivot, every partition splits the range into the pivot alone and the remaining n-1 elements.
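As a worked illustration of that worst case: if every partition removes only the pivot, the ranges scanned have sizes N, N-1, N-2, and so on down to 1, so the total work is

\[N + (N-1) + (N-2) + \cdots + 1 = \frac{N(N+1)}{2} = O(N^2)\]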

Quick select takes O(N) on average because, ideally, each partition splits the range into two parts of equal length, so the number of elements scanned in successive rounds becomes

\[N, N/2, N/4, N/8, N/16, \ldots\]

The total is \(N \times (1 + 1/2 + 1/4 + 1/8 + 1/16 + \cdots)\), which converges to 2N.

Written on May 19, 2021