目錄

廣告 AD

如何從陣列中找出 Kth 大/小的元素?

這個好像是個經典的問題

於是把看過的方法都記錄下來

做個整理

廣告 AD

今天有一個數字陣列,數字是未排序的 (如果排序過就直接回傳就好了),今天要找出這個陣列中第 K 大/小的元素。

舉例:今天是陣列是 [6, 4, 5, 3, 2],那第 3 大的元素就是 4。

由於找第 K 大的跟第 K 小的是很像的,我們下面只舉例找第 K 大的。

以下提供幾種方法。

如果數字是未排序的,那…我們排序不就好了?最簡單的方法就是排序了,不論用 QuickSort、MergeSort、HeapSort 甚至其他的 Sort 都可以。


作法:

  1. 將陣列排序 (由大到小)

  2. 回傳陣列中第 K 個元素

複雜度:

  • 時間複雜度:$O(N \log_2{N})$
    這取決於步驟 1 的排列的時間複雜度,因為步驟 2 的時間複雜度是 $O(1)$,一般來說排序的時間複雜度是 $O(N \log_2{N})$ (不算非比較性的 Radix Sort 或是 Bucket Sort)。
  • 空間複雜度:$O(\log_2{N})$
    一樣是取決於步驟 1 的空間複雜度,如果用的是 QuickSort,由於堆疊呼叫會有 $\log_2{N}$ 層,因此空間複雜度是 $O(\log_2{N})$。
其他重點
要注意使用的排序的方法時間複雜度和空間複雜度,如果知道資料範圍,可以用非比較的排序,或是空間上有強烈要求,那要注意能不能用 in-place 的方法,甚至如果題目有需要 stable 的條件,也要注意排序演算法是不是 stable。

cpp

// k: [1, N]
int kth_element(std::vector<int> &arr, int k){
  std::sort(arr.begin(), arr.end());
  return arr[k-1];
}

這個方法有點類似 QuickSort,在做 QuickSort 的時候,partition 會回傳 pivot 排序好的 index,這樣我們就知道 pivot 是在陣列中排名第幾了,根據名次,我們可以知道答案在 pivot 的左邊還是右邊,以下說明需要先了解 QuickSort。


作法:

  1. 選擇一個元素當作 pivot。
  2. 做 partition (Hoare partition scheme / Lomuto partition scheme)12
  3. 計算 pivot 是當前陣列的第幾大
  4. 假設 pivot 是當前陣列的第 T 大,和 K 比較。
    • T < K : 答案在 pivot 的左邊,去 pivot 左邊的陣列找,回到步驟 1。
    • T = K : 那 pivot 就剛好是答案啦。
    • T > K : 答案在 pivot 的右邊,去 pivot 右邊的陣列找,調整 K 的數值,因為左邊有 T 個元素,因此原始陣列的第 K 個是右邊陣列的第 K - T 個,回到步驟 1。

複雜度:

  • 時間複雜度:$O(N)$ (最佳),$O(N^2)$ (最差) 3
    由於是用 QuickSort 的 partition,因此也會有 QuickSort 選擇 pivot 的問題
  • 空間複雜度:$O(1)$
    雖然底下範例用的是遞迴,所需的空間是 $O(\log_2{N})$ 但由於是 tail recursive,因此完全可以改寫成 iterative 的版本,這樣空間複雜度就是 $O(1)$

cpp

// k: [1, N]
// 這是 Lomuto partition scheme
// 實際的 QuickSelect 是用 Hoare partition scheme
int partition(int start, int end, int k){
  // 取第一個元素當 pivot
  int pivot = table[start];
  int small_idx = start+1;
  // 依照大小,小的或是等於的往左邊丟,大的往右邊丟
  for(int i = start+1; i < end; ++i){
    if(table[i] <= pivot){
        swap(table[small_idx], table[i]);
        small_idx++;
    }
  }
  // 把 pivot 移到正中間
  swap(table[start], table[small_idx-1]);
  // 計算 pivot 是第幾個
  int pivot_rank = small_idx-start-1;

  if(k == pivot_rank){ // 找到目標了
    return pivot;
  }else if(k > pivot_rank){ // 目標要的大於現在 pivot ,所以要往右邊找,k 要扣掉現在的
    return partition(small_idx, end, k-pivot_rank-1);
  }else{ // 找太多了,所以網做邊找,一樣是找第 k 個,k 不用動
    return partition(start, small_idx-1, k);
  }
}

如果你想要比較好理解、時間上又比單純排序還要快的方法的話那就是用 heap 了。我們會維護一個 min-heap,讓這個 min-heap 裡面存放我們目前看過的元素中最大的 K 個,如果我們下一個看的元素比「目前看過的元素中最大的 K 個中最小的」還要大,也就是下一個看的元素可以排入「目前看過的元素中最大的 K 個」,那我們就更新這個 min-heap,直到所有的元素看完,最後這個 min-heap 中最小的就是答案。


作法:

  1. 拿取下一個元素塞入 min-heap。

  2. 如果 min-heap 裡的數量超過 K,則丟掉 top 元素 (當中最小的)。

  3. 如果還沒看完,則回到步驟 1。

  4. 回傳 min-heap 的 top 元素。

複雜度:

  • 時間複雜度:$O(N\log_2{K})$
    因為 heap 的 pop 是 $O(1)$,總共會被 pop $(N-K)$ 次,pop 的時間複雜度是 $O(N)$,heap 的 push 是 $O(\log_2{K})$,push 的時間複雜度是 $O(N\log_2{K})$,總共就是 $O(N\log_2{K})$。
  • 空間複雜度:$O(K)$
    min-heap 最高同時會有 K 個元素。

cpp

// k: [1, N]
int kth_element(std::vector<int> &arr, int k){
  std::priority_queue<int, vector<int>, std::greater<int>> pq;
  for(int i = 0; i < arr.size(); ++i){
    pq.push(arr[i]);
    if(pq.size() > k) pq.pop();
  }
  return pq.top();
}

由於用 QuickSelect 時,pivot 選不好會讓時間複雜度變成 $O(N^2)$,因此有其他改進的方法,詳細的部分可以參考這裡

廣告 AD