目錄

廣告 AD

Thread pool:簡單的 C 語言實作

通常在跑程式的時候,如果速度不夠快,試試看平行處理或許是個好選擇

開起多個 Thread 一起來處理各個任務

做完任務後再回到 Queue 等待下個任務執行

廣告 AD

這次我們是用 C 來實作,因此要先來建立 Queue

一共實作了下方幾種功能:

  • Queue_Init:初始化 Queue,並給予 Queue 最大的大小。
  • Queue_Destroy:不使用 Queue,釋放 Queue 的空間。
  • Queue_Empty:是否 Queue 是空的。
  • Queue_Full:是否 Queue 是滿的。
  • Queue_Push:放元素到 Queue 裡面。
  • Queue_Front:得到 Queue 的開頭元素。
  • Queue_Pop:丟到 Queue 的開頭元素。

C

#ifndef __QUEUE_H__
#define __QUEUE_H__

typedef struct
{
  int *arr;
  unsigned max_size, size, start;
} Queue;

void Queue_Init(Queue *q, unsigned max_size);
void Queue_Destroy(Queue *q);
int Queue_Empty(Queue *q);
int Queue_Full(Queue *q);
int Queue_Push(Queue *q, int element);
int Queue_Front(Queue *q, int *element);
void Queue_Pop(Queue *q);

#endif

Queue 這邊並不是用 Linked-List 實作的,而是簡單用 Array 做出 ring buffer,缺點是容量沒辦法動態調整。

如果容量滿了,不是覆蓋掉舊資料,而是返回錯誤 (1),代表執行錯誤。

C

#include <stdlib.h>
#include "queue.h"

void Queue_Init(Queue *q, unsigned max_size)
{
  q->max_size = max_size;
  q->arr = (int *)malloc(sizeof(int) * max_size);
  q->start = q->size = 0;
}

void Queue_Destroy(Queue *q)
{
  free(q->arr);
}

int Queue_Empty(Queue *q)
{
  return q->size == 0;
}

int Queue_Full(Queue *q)
{
  return q->size == q->max_size;
}

int Queue_Push(Queue *q, int element)
{
  if (Queue_Full(q))
    return 1;
  unsigned add_idx = (q->start + q->size < q->max_size ? (q->start + q->size) : (q->start + q->size - q->max_size));
  q->arr[add_idx] = element;
  q->size += 1;
  return 0;
}

int Queue_Front(Queue *q, int *element)
{
  if (Queue_Empty(q))
    return 1;
  *element = q->arr[q->start];
  return 0;
}

void Queue_Pop(Queue *q)
{
  if (Queue_Empty(q))
    return;
  q->size -= 1;
  q->start += 1;
  if (q->start == q->max_size)
    q->start = 0;
}

Thread 這邊用 POSIX 的 thread: pthread,Thread_Pool 提供了以下函式:

  • Thread_Pool_Init:初始化 Thread_Pool。
  • Thread_Pool_Destroy:不使用 Thread_Pool Thread_Pool 的空間。
  • Thread_Pool_Get:從 Thread_Pool 取得空閒的 Thread。
  • Thread_Pool_Add:把完成任務的 Thread 放回 Thread_Pool,等待執行其他任務。

C

#ifndef __THREAD_POOL_H__
#define __THREAD_POOL_H__

#include <pthread.h>
#include "queue.h"

typedef struct
{
  Queue q;
  pthread_t *threads;
  pthread_mutex_t q_mtx;
  pthread_cond_t q_cond;
  unsigned size;
} Thread_Pool;

void Thread_Pool_Init(Thread_Pool *pool, unsigned num_of_thread);
void Thread_Pool_Destroy(Thread_Pool *pool);
int Thread_Pool_Get(Thread_Pool *pool);
void Thread_Pool_Add(Thread_Pool *pool, int thread_id);
#endif

由於上面的 Queue 並沒有保證 Thread safe,因此這邊加上使用了 mutex 來保證 Queue 裡面的資料是正確的。

如果當前的 Queue 沒有任何空閒的 Thread,則會需要等待正在執行任務的 Thread 完成,因此在 Thread_Pool_Get 裡面,如果當前的 Queue 已經空了,則會進入等待 (pthread_cond_wait),在等待的時候會釋放 mutex (q_mtx),直到有人使用 pthread_cond_signal 呼叫 condition variable (q_cond),告訴其他人說有 Thread 被釋出了,接著重新獲得 mutex (q_mtx),拿到空閒的 Thread 來執行任務。

C

#include <stdlib.h>
#include <pthread.h>
#include "thread_pool.h"

void Thread_Pool_Init(Thread_Pool *pool, unsigned num_of_thread)
{
  pool->size = num_of_thread;
  Queue_Init(&(pool->q), num_of_thread);
  for (int i = 0; i < num_of_thread; ++i)
    Queue_Push(&(pool->q), i);
  pool->threads = (pthread_t *)malloc(sizeof(pthread_t) * num_of_thread);
  pthread_mutex_init(&(pool->q_mtx), NULL);
  pthread_cond_init(&(pool->q_cond), NULL);
}

void Thread_Pool_Destroy(Thread_Pool *pool)
{
  Queue_Destroy(&(pool->q));
  free(pool->threads);
  pthread_mutex_destroy(&(pool->q_mtx));
  pthread_cond_destroy(&(pool->q_cond));
}

int Thread_Pool_Get(Thread_Pool *pool)
{
  pthread_mutex_lock(&(pool->q_mtx));
  while (Queue_Empty(&(pool->q)))
  {
    pthread_cond_wait(&(pool->q_cond), &(pool->q_mtx));
  }
  int thread_id;
  Queue_Front(&(pool->q), &thread_id);
  Queue_Pop(&(pool->q));
  pthread_mutex_unlock(&(pool->q_mtx));
  return thread_id;
}

void Thread_Pool_Add(Thread_Pool *pool, int thread_id)
{
  pthread_mutex_lock(&(pool->q_mtx));
  Queue_Push(&(pool->q), thread_id);
  pthread_mutex_unlock(&(pool->q_mtx));
  pthread_cond_signal(&(pool->q_cond));
}

下面寫了一個小小的 Example 來使用 Thread Pool,總共會執行 NUM_OF_TASK 個任務,每個任務都會傳入參數,並回傳任務完成時間,任務內容就簡單放個睡覺。

在每次取得新的 Thread 之後都執行 pthread_join 取得之前的回傳值並儲存下來,如果每個任務都分配完了,則用 pthread_join 來等待全部的 Thread 完成並取得回傳值。

最後就把回傳值都印出來,接著銷毀 Thread Pool。

C

#include <pthread.h>
#include <unistd.h>
#include <stdio.h>
#include <time.h>
#include <stdlib.h>
#include "thread_pool.h"

#define NUM_OF_THREAD 4
#define NUM_OF_TASK 20
#define SLEEP_SECOND 3

// task args
typedef struct
{
  Thread_Pool *pool;
  int thread_id;
  int task_id;
} Task_Args;

// task return values
typedef struct
{
  int task_id;
  time_t finish_time;
} Task_Return_Data;

// task function
void *task(void *args)
{
  Task_Args *data = (Task_Args *)args;

  // sleep for a while
  printf("Task %d in Thread %d\n", data->task_id, data->thread_id);
  fflush(stdout);
  sleep(SLEEP_SECOND);

  // prepare return values
  Task_Return_Data *ret_data = (Task_Return_Data *)malloc(sizeof(Task_Return_Data));
  ret_data->task_id = data->task_id;
  ret_data->finish_time = time(NULL);

  Thread_Pool_Add(data->pool, data->thread_id);
  free(args);
  pthread_exit((void *)ret_data);
}

int main()
{
  Thread_Pool pool;
  Thread_Pool_Init(&pool, NUM_OF_THREAD);

  Task_Return_Data *data[NUM_OF_TASK];
  void *ret;

  // create tasks
  for (int i = 0; i < NUM_OF_TASK; ++i)
  {
    // get available thread id
    int id = Thread_Pool_Get(&pool);
    // check the return value
    pthread_join(pool.threads[id], &ret);
    if (ret != NULL)
    {
      Task_Return_Data *ret_data = ((Task_Return_Data *)ret);
      data[ret_data->task_id - 1] = ret_data;
    }
    // prepare the parameters
    Task_Args *args = (Task_Args *)malloc(sizeof(Task_Args));
    args->pool = &pool;
    args->thread_id = id;
    args->task_id = i + 1;
    // create the thread
    pthread_create(&(pool.threads[id]), NULL, task, args);
  }

  // wait for finishing all tasks
  for (int i = 0; i < pool.size; ++i)
  {
    pthread_join(pool.threads[i], &ret);
    if (ret != NULL)
    {
      Task_Return_Data *ret_data = ((Task_Return_Data *)ret);
      data[ret_data->task_id - 1] = ret_data;
    }
  }

  // deal with return values
  for (int i = 0; i < NUM_OF_TASK; ++i)
  {
    struct tm *stamp = localtime(&(data[i]->finish_time));
    printf("Task %d Finish At %04d-%02d-%02d %02d:%02d:%02d\n",
           data[i]->task_id,
           stamp->tm_year + 1900, stamp->tm_mon + 1, stamp->tm_mday,
           stamp->tm_hour, stamp->tm_min, stamp->tm_sec);
    fflush(stdout);
  }

  // free
  for (int i = 0; i < pool.size; ++i)
  {
    free(data[i]);
  }
  Thread_Pool_Destroy(&pool);

  return 0;
}

廣告 AD