算法: 寻找第k小值问题的O(n)解法

题目:

给定一个长度为n的数组arr,请输出第k小值,1<=k<=n,当k = 1或n时,就是找最小值或最大值。

思路:

把数组划分为n/5组。每组5个元素。分别找到每组的中间值,然后在n/5个中间值中找中间值mm。用这个中间值mm划分数组,如果mm所在位置等于k,则找到答案。如果,mm所在位置小于k,继续在右半部分找。如果mm所在位置大于k,继续在左半部分找。重复这个过程。直到找到答案。中间给5个元素排序,用的是快速排序代码,给5个元素排序的时间复杂度为O(1)

代码:

class kSmallNumber {
public:
    int kSmallNumb(vector<int>& arr,int k,int left, int right) {
        int n = right - left + 1;
        if (n <= 5) {
            quickSort(arr,left,right);
            return arr[left + k-1];
        }
        vector<int> tmpM;
        for(int i = left;i < arr.size();i+=5) {
            int tempRight = i + 4;
            if (tempRight >= arr.size()) {
                tempRight = arr.size() - 1;
            }
            quickSort(arr, i, tempRight);
            tmpM.emplace_back(arr[i+(tempRight - i)/2]);
        }
        
        int mm = kSmallNumb(tmpM, n/10 + 1, 0, tmpM.size()-1);
        
        int j = partition(arr, mm, left, right);
        
        if (k == j + 1 - left) {
            return arr[j];
        }
        else if (k < j + 1 - left) {
            return kSmallNumb(arr, k, left, j-1);
        }
        return kSmallNumb(arr, k - (j + 1 - left), j+1, right);
    }
    
    int partition(vector<int>& arr, int parti, int left, int right) {
        int origLoc = 0;
        for(int i = 0 ;i < arr.size();i++) {
            if (arr[i] == parti) {
                origLoc = i;
                break;
            }
        }
        int l = left, r = right;
        while (l < r) {
            while (l < r && parti <= arr[r]) {
                r--;
            }
            while (l < r && parti >= arr[l]) {
                l++;
            }
            int tmp = arr[l];
            arr[l] = arr[r];
            arr[r] = tmp;
        }
        if (origLoc > l) {
            if (l+1 < right) {
                int tmp = arr[l+1];
                arr[l+1] = parti;
                arr[origLoc] = tmp;
                return l+1;
            }
        }
        else {
            int tmp = arr[l];
            arr[l] = parti;
            arr[origLoc] = tmp;
        }
        
        return l;
    }
    
    void quickSort(vector<int>& arr,int left, int right) {
        if (left < right) {
            int l = left, r = right;
            while (l < r) {
                while (l < r && arr[left] <= arr[r]) {
                    r--;
                }
                while (l < r && arr[left] >= arr[l]) {
                    l++;
                }
                int tmp = arr[l];
                arr[l] = arr[r];
                arr[r] = tmp;
            }
            int tmp = arr[l];
            arr[l] = arr[left];
            arr[left] = tmp;
            quickSort(arr, left, l - 1);
            quickSort(arr, l+1, right);
        }
    }
};