如何從陣列中找出 Kth 大/小的元素?

這個好像是個經典的問題
於是把看過的方法都記錄下來
做個整理
問題描述
今天有一個數字陣列,數字是未排序的 (如果排序過就直接回傳就好了),今天要找出這個陣列中第 K 大/小的元素。
舉例:今天是陣列是 [6, 4, 5, 3, 2],那第 3 大的元素就是 4。
由於找第 K 大的跟第 K 小的是很像的,我們下面只舉例找第 K 大的。
解法
以下提供幾種方法。
排序 (Sort)
如果數字是未排序的,那…我們排序不就好了?最簡單的方法就是排序了,不論用 QuickSort、MergeSort、HeapSort 甚至其他的 Sort 都可以。
作法:
將陣列排序 (由大到小)
回傳陣列中第 K 個元素
複雜度:
- 時間複雜度:$O(N \log_2{N})$
這取決於步驟 1 的排列的時間複雜度,因為步驟 2 的時間複雜度是 $O(1)$,一般來說排序的時間複雜度是 $O(N \log_2{N})$ (不算非比較性的 Radix Sort 或是 Bucket Sort)。 - 空間複雜度:$O(\log_2{N})$
一樣是取決於步驟 1 的空間複雜度,如果用的是 QuickSort,由於堆疊呼叫會有 $\log_2{N}$ 層,因此空間複雜度是 $O(\log_2{N})$。
// k: [1, N]
int kth_element(std::vector<int> &arr, int k){
std::sort(arr.begin(), arr.end());
return arr[k-1];
}
快速選擇 (QuickSelect)
這個方法有點類似 QuickSort,在做 QuickSort 的時候,partition 會回傳 pivot 排序好的 index,這樣我們就知道 pivot 是在陣列中排名第幾了,根據名次,我們可以知道答案在 pivot 的左邊還是右邊,以下說明需要先了解 QuickSort。
作法:
- 選擇一個元素當作 pivot。
- 做 partition (Hoare partition scheme / Lomuto partition scheme)12。
- 計算 pivot 是當前陣列的第幾大
- 假設 pivot 是當前陣列的第 T 大,和 K 比較。
- T < K : 答案在 pivot 的左邊,去 pivot 左邊的陣列找,回到步驟 1。
- T = K : 那 pivot 就剛好是答案啦。
- T > K : 答案在 pivot 的右邊,去 pivot 右邊的陣列找,調整 K 的數值,因為左邊有 T 個元素,因此原始陣列的第 K 個是右邊陣列的第 K - T 個,回到步驟 1。
複雜度:
- 時間複雜度:$O(N)$ (最佳),$O(N^2)$ (最差) 3
由於是用 QuickSort 的 partition,因此也會有 QuickSort 選擇 pivot 的問題 - 空間複雜度:$O(1)$
雖然底下範例用的是遞迴,所需的空間是 $O(\log_2{N})$ 但由於是 tail recursive,因此完全可以改寫成 iterative 的版本,這樣空間複雜度就是 $O(1)$
// k: [1, N]
// 這是 Lomuto partition scheme
// 實際的 QuickSelect 是用 Hoare partition scheme
int partition(int start, int end, int k){
// 取第一個元素當 pivot
int pivot = table[start];
int small_idx = start+1;
// 依照大小,小的或是等於的往左邊丟,大的往右邊丟
for(int i = start+1; i < end; ++i){
if(table[i] <= pivot){
swap(table[small_idx], table[i]);
small_idx++;
}
}
// 把 pivot 移到正中間
swap(table[start], table[small_idx-1]);
// 計算 pivot 是第幾個
int pivot_rank = small_idx-start-1;
if(k == pivot_rank){ // 找到目標了
return pivot;
}else if(k > pivot_rank){ // 目標要的大於現在 pivot ,所以要往右邊找,k 要扣掉現在的
return partition(small_idx, end, k-pivot_rank-1);
}else{ // 找太多了,所以網做邊找,一樣是找第 k 個,k 不用動
return partition(start, small_idx-1, k);
}
}
Heap
如果你想要比較好理解、時間上又比單純排序還要快的方法的話那就是用 heap 了。我們會維護一個 min-heap,讓這個 min-heap 裡面存放我們目前看過的元素中最大的 K 個,如果我們下一個看的元素比「目前看過的元素中最大的 K 個中最小的」還要大,也就是下一個看的元素可以排入「目前看過的元素中最大的 K 個」,那我們就更新這個 min-heap,直到所有的元素看完,最後這個 min-heap 中最小的就是答案。
作法:
拿取下一個元素塞入 min-heap。
如果 min-heap 裡的數量超過 K,則丟掉 top 元素 (當中最小的)。
如果還沒看完,則回到步驟 1。
回傳 min-heap 的 top 元素。
複雜度:
- 時間複雜度:$O(N\log_2{K})$
因為 heap 的 pop 是 $O(1)$,總共會被 pop $(N-K)$ 次,pop 的時間複雜度是 $O(N)$,heap 的 push 是 $O(\log_2{K})$,push 的時間複雜度是 $O(N\log_2{K})$,總共就是 $O(N\log_2{K})$。 - 空間複雜度:$O(K)$
min-heap 最高同時會有 K 個元素。
// k: [1, N]
int kth_element(std::vector<int> &arr, int k){
std::priority_queue<int, vector<int>, std::greater<int>> pq;
for(int i = 0; i < arr.size(); ++i){
pq.push(arr[i]);
if(pq.size() > k) pq.pop();
}
return pq.top();
}
分組
由於用 QuickSelect 時,pivot 選不好會讓時間複雜度變成 $O(N^2)$,因此有其他改進的方法,詳細的部分可以參考這裡。
Reference
如果你覺得這篇文章有用 可以考慮贊助飲料給大貓咪