2.2. 多线程编程

Python 的 threading 模块提供了多线程编程能力。虽然受 GIL 限制(同一时刻只有一个线程执行 Python 字节码),但线程在阻塞 I/O 时会释放 GIL,因此对于 I/O 密集型任务仍然非常有效。

2.2.1. 基础使用

2.2.1.1. 创建线程

import threading
import time

# 方式1:传递函数
def worker(name, delay):
    """Thread target: report start, sleep ``delay`` seconds, report finish."""
    print(f"Worker {name} starting")
    time.sleep(delay)  # stands in for real blocking work
    print(f"Worker {name} finished")

t = threading.Thread(
    target=worker,
    args=("A", 2),
)
t.start()
t.join()  # block until the worker thread has finished

# 方式2:继承 Thread 类
class MyThread(threading.Thread):
    """Thread subclass whose ``run`` reports start, sleeps, then reports finish.

    Args:
        name: Thread name, forwarded to ``threading.Thread.__init__`` so the
            base class sets it, rather than overwriting the inherited
            ``name`` property after construction.
        delay: Seconds to sleep inside ``run``.
    """

    def __init__(self, name, delay):
        # Let the base-class constructor own the thread name.
        super().__init__(name=name)
        self.delay = delay

    def run(self):
        """Body executed in the new thread once ``start()`` is called."""
        print(f"Thread {self.name} starting")
        time.sleep(self.delay)
        print(f"Thread {self.name} finished")

t = MyThread("B", 1)
t.start()  # runs MyThread.run() in a new thread
t.join()   # wait for run() to return

2.2.1.2. 线程属性

import threading
import time

def task():
    """Print the current thread's name and identifier, then sleep 2 seconds."""
    me = threading.current_thread()
    print(f"Thread name: {me.name}")
    print(f"Thread ID: {threading.get_ident()}")
    time.sleep(2)

# Configure the thread's attributes when constructing it.
t = threading.Thread(
    target=task,
    name="MyWorker",
    daemon=True,  # daemon thread: stopped abruptly when the main program exits
)

print(f"Active threads: {threading.active_count()}")
print(f"Thread alive: {t.is_alive()}")  # False: not started yet

t.start()
print(f"Thread alive: {t.is_alive()}")  # True: started and still sleeping

# List every thread that is currently alive.
for thread in threading.enumerate():
    print(f"  {thread.name}")

2.2.2. 线程同步

2.2.2.1. Lock(互斥锁)

import threading
import time

class BankAccount:
    """Account whose balance updates are serialized by a ``threading.Lock``.

    Args:
        balance: Initial balance.
    """

    def __init__(self, balance):
        self.balance = balance
        self.lock = threading.Lock()  # guards every read-modify-write of balance

    def withdraw(self, amount):
        """Withdraw *amount* if the balance covers it.

        The check and the subtraction happen under the same lock, so two
        threads cannot both pass the check and overdraw the account.

        Returns:
            True if the withdrawal happened, False if funds were insufficient.
        """
        with self.lock:
            if self.balance >= amount:
                # Simulate processing time inside the critical section. The
                # `time` import lives at the top of the snippet instead of
                # being re-executed while the lock is held.
                time.sleep(0.001)
                self.balance -= amount
                return True
            return False

    def deposit(self, amount):
        """Add *amount* to the balance while holding the lock."""
        with self.lock:
            self.balance += amount

account = BankAccount(1000)


def make_withdrawals():
    """Attempt 100 withdrawals of 10 from the shared ``account``."""
    for _ in range(100):
        account.withdraw(10)


threads = [threading.Thread(target=make_withdrawals) for _ in range(10)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# 10 threads x 100 attempts x 10 = 10000 requested, more than the 1000
# available, so the lock-protected withdrawals drain the account to exactly 0.
print(f"Final balance: {account.balance}")  # correct: 0

2.2.2.2. RLock(可重入锁)

import threading

class SafeCounter:
    """Counter guarded by a re-entrant lock (``threading.RLock``).

    Because the lock is re-entrant, a method that already holds it may call
    other methods that acquire it again from the same thread.
    """

    def __init__(self):
        self.lock = threading.RLock()  # re-entrant lock
        self.value = 0

    def increment(self):
        """Add one to ``value`` while holding the lock."""
        with self.lock:
            self.value = self.value + 1

    def increment_twice(self):
        """Add two by calling ``increment`` twice inside one critical section."""
        with self.lock:  # outer acquisition
            for _ in range(2):
                self.increment()  # same thread re-acquires the RLock: allowed

# 如果使用普通 Lock,increment_twice 会死锁

2.2.2.3. Semaphore(信号量)

import threading
import time

# Cap how many threads may use the resource at the same time.
# BoundedSemaphore raises ValueError if release() is called more often than
# acquire(), catching bugs that would otherwise silently enlarge the "pool".
connection_pool = threading.BoundedSemaphore(5)

def access_database(thread_id):
    """Hold one pool slot for about a second while simulating a DB call."""
    connection_pool.acquire()
    try:
        print(f"Thread {thread_id} acquired connection")
        time.sleep(1)  # simulated database work
        print(f"Thread {thread_id} released connection")
    finally:
        connection_pool.release()

# Start 10 threads; the semaphore lets at most 5 hold a connection at once.
threads = [
    threading.Thread(target=access_database, args=(tid,)) for tid in range(10)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

2.2.2.4. Event(事件)

import threading
import time

# An Event lets one thread signal any number of waiting threads at once.
start_event = threading.Event()


def worker(name):
    """Block until ``start_event`` is set, then report that work has begun."""
    print(f"Worker {name} waiting...")
    start_event.wait()  # returns immediately if the event is already set
    print(f"Worker {name} started!")

threads = [threading.Thread(target=worker, args=(n,)) for n in range(3)]
for t in threads:
    t.start()

time.sleep(2)  # pause so the workers are (very likely) already blocked in wait()
print("Setting event...")
start_event.set()  # releases every waiting worker at once

for t in threads:
    t.join()

2.2.2.5. Condition(条件变量)

import threading
import time
from collections import deque

class ProducerConsumer:
    """Bounded FIFO buffer coordinated by a single ``threading.Condition``.

    Producers and consumers wait on the *same* condition, so every state
    change is announced with ``notify_all()``. A plain ``notify()`` may wake a
    thread of the wrong kind (a producer when a consumer is needed, or vice
    versa). That thread re-checks its predicate and goes back to sleep, while
    the thread that could make progress stays asleep. With more than one
    producer or consumer this can deadlock.

    Args:
        max_size: Maximum number of buffered items before producers block.

    Raises:
        ValueError: If ``max_size`` is less than 1 (producers would block forever).
    """

    def __init__(self, max_size):
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size!r}")
        self.queue = deque()
        self.max_size = max_size
        self.condition = threading.Condition()

    def produce(self, item):
        """Append *item*, blocking while the buffer holds ``max_size`` items."""
        with self.condition:
            while len(self.queue) >= self.max_size:
                print("Queue full, producer waiting...")
                self.condition.wait()  # wait for a consumer to free a slot

            self.queue.append(item)
            print(f"Produced: {item}, queue size: {len(self.queue)}")
            self.condition.notify_all()  # wake waiting consumers (and any producers)

    def consume(self):
        """Remove and return the oldest item, blocking while the buffer is empty."""
        with self.condition:
            while not self.queue:
                print("Queue empty, consumer waiting...")
                self.condition.wait()  # wait for a producer to add an item

            item = self.queue.popleft()
            print(f"Consumed: {item}, queue size: {len(self.queue)}")
            self.condition.notify_all()  # wake waiting producers (and any consumers)
            return item

pc = ProducerConsumer(5)


def producer():
    """Push the integers 0..9 into ``pc``, pausing 0.1 s after each one."""
    for n in range(10):
        pc.produce(n)
        time.sleep(0.1)


def consumer():
    """Pull ten items from ``pc``, pausing 0.2 s after each one."""
    for _ in range(10):
        pc.consume()
        time.sleep(0.2)


# The consumer is slower than the producer, so the buffer eventually fills up
# and the producer has to wait.
t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start()
t2.start()
t1.join()
t2.join()

2.2.2.6. Barrier(屏障)

import threading
import time

# A Barrier makes a fixed number of threads wait until all of them arrive.
barrier = threading.Barrier(3)


def worker(name):
    """Run phase 1, wait for all three workers at the barrier, then run phase 2."""
    print(f"Worker {name} doing phase 1")
    time.sleep(name)  # name doubles as the delay, so workers arrive at different times

    print(f"Worker {name} waiting at barrier")
    barrier.wait()  # blocks until 3 threads have called wait()

    print(f"Worker {name} doing phase 2")


threads = [threading.Thread(target=worker, args=(n,)) for n in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

2.2.3. 线程池

from concurrent.futures import ThreadPoolExecutor, as_completed
import time

def download(url, delay=1):
    """Simulate downloading *url*.

    Args:
        url: Address to "download"; it is only echoed back.
        delay: Seconds to sleep to mimic network latency (default 1, the
            previous hard-coded value).

    Returns:
        The string ``"Downloaded <url>"``.
    """
    time.sleep(delay)
    return f"Downloaded {url}"

# Before Python 3.11, Future.result(timeout=...) raises
# concurrent.futures.TimeoutError, which is NOT the builtin TimeoutError
# (it became an alias only in 3.11). Importing it explicitly makes the
# except clause below work on every version.
from concurrent.futures import TimeoutError as FutureTimeoutError

urls = [f"https://example.com/page{i}" for i in range(10)]

# Use a thread pool; leaving the with block waits for all submitted work.
with ThreadPoolExecutor(max_workers=5) as executor:
    # Option 1: map -- results are yielded in input order.
    results = executor.map(download, urls)
    for result in results:
        print(result)

    # Option 2: submit + as_completed -- results are yielded in completion order.
    futures = [executor.submit(download, url) for url in urls]
    for future in as_completed(futures):
        print(future.result())

    # Option 3: submit + result with a timeout.
    future = executor.submit(download, "https://example.com")
    try:
        result = future.result(timeout=2)
    except FutureTimeoutError:
        # Only the wait is abandoned; the task itself keeps running in the pool.
        print("Download timed out")

2.2.4. 避免死锁

2.2.4.1. 死锁示例

import threading
import time

lock1 = threading.Lock()
lock2 = threading.Lock()


def thread1():
    """Take lock1, pause, then take lock2 (the opposite order to thread2)."""
    with lock1:
        print("Thread 1 acquired lock1")
        time.sleep(0.1)  # widen the window in which thread2 can grab lock2
        lock2.acquire()  # blocks while thread2 holds lock2
        try:
            print("Thread 1 acquired lock2")
        finally:
            lock2.release()


def thread2():
    """Take lock2, pause, then take lock1 (the opposite order to thread1)."""
    with lock2:
        print("Thread 2 acquired lock2")
        time.sleep(0.1)  # widen the window in which thread1 can grab lock1
        lock1.acquire()  # blocks while thread1 holds lock1
        try:
            print("Thread 2 acquired lock1")
        finally:
            lock1.release()

# Running thread1 and thread2 concurrently can deadlock: thread1 holds lock1
# while waiting for lock2, and thread2 holds lock2 while waiting for lock1.

2.2.4.2. 避免死锁的方法

import threading

# Approach 1: every code path acquires the locks in the same global order.
lock1 = threading.Lock()
lock2 = threading.Lock()

def safe_thread1():
    """Acquire lock1 then lock2 -- the agreed order."""
    with lock1, lock2:
        pass

def safe_thread2():
    """Same order as safe_thread1, so the two can never deadlock each other."""
    with lock1, lock2:
        pass

# Approach 2: give up instead of waiting forever.
def try_lock_with_timeout():
    """Try to take lock1 then lock2, each with a 1-second timeout.

    Returns:
        True if both locks were acquired (and then released), False if either
        acquisition timed out. Any lock that was acquired is always released.
    """
    if not lock1.acquire(timeout=1):
        return False
    try:
        if not lock2.acquire(timeout=1):
            return False
        try:
            pass  # critical section that needs both locks
        finally:
            lock2.release()
    finally:
        lock1.release()
    return True

# Approach 3: acquire any number of locks in one consistent order via ExitStack.
from contextlib import ExitStack, contextmanager


@contextmanager  # without this, the generator below cannot be used in `with`
def acquire_locks(*locks):
    """Hold all *locks* for the duration of a ``with`` body.

    Locks are acquired in ascending ``id()`` order, so every caller that uses
    this helper agrees on one global order and cannot deadlock against
    another caller doing the same. ``ExitStack`` releases whatever was
    acquired, in reverse order, even if the body or a later acquisition
    raises. A lock passed more than once is acquired only once, so a plain
    ``Lock`` listed twice does not deadlock its own holder.

    Usage::

        with acquire_locks(lock_a, lock_b):
            ...
    """
    unique = {id(lock): lock for lock in locks}
    with ExitStack() as stack:
        for _, lock in sorted(unique.items()):  # sort by id for a global order
            stack.enter_context(lock)
        yield

2.2.5. 线程安全的数据结构

from queue import Queue, LifoQueue, PriorityQueue
import threading

# Queue: first in, first out.
q = Queue(maxsize=10)
q.put("item")
item = q.get()
q.task_done()  # mark the fetched item as processed

# LifoQueue: last in, first out (a stack).
stack = LifoQueue()
stack.put("first")
stack.put("second")
print(stack.get())  # second

# PriorityQueue: the smallest entry comes out first; tuples compare
# element by element, so the leading number acts as the priority.
pq = PriorityQueue()
pq.put((2, "medium"))
pq.put((1, "high"))
pq.put((3, "low"))
print(pq.get())  # (1, 'high')

# Producer/consumer pattern
def producer(q):
    """Put 0..9 on *q*, followed by a ``None`` sentinel marking end of input."""
    for value in [*range(10), None]:  # the trailing None tells the consumer to stop
        q.put(value)

def consumer(q):
    """Process items from *q* until the ``None`` sentinel arrives.

    Every item fetched with ``get()`` -- including the sentinel -- is marked
    with ``task_done()``, so a ``q.join()`` elsewhere returns once everything
    has been handled. The ``finally`` also keeps the count correct if
    processing an item raises.
    """
    while True:
        item = q.get()
        try:
            if item is None:  # sentinel: producer has finished
                break
            print(f"Processing {item}")
        finally:
            q.task_done()

q = Queue()
t1 = threading.Thread(target=producer, args=(q,))
t2 = threading.Thread(target=consumer, args=(q,))
t1.start()
t2.start()
# The consumer returns once it receives the producer's None sentinel.
t1.join()
t2.join()

2.2.6. 最佳实践

何时使用线程
  1. I/O 密集型任务:网络请求、文件操作

  2. GUI 应用:保持界面响应

  3. 简单并发:不需要复杂的进程间通信

注意事项
  1. 使用 with 语句获取锁:确保即使发生异常也能释放锁

  2. 设置合理的超时:避免无限等待

  3. 使用线程池:避免创建过多线程

  4. 优先使用队列:而非共享状态加锁

常见错误
# NOTE: illustrative fragments -- assumes `work` (a callable) and `lock`
# (a threading.Lock) are defined elsewhere; neither is defined here.

# ❌ Forgetting join
t = threading.Thread(target=work)
t.start()
# Code after start() keeps running while the worker is still busy, so the
# main thread may move on (e.g. read results) before the worker has finished.
# NOTE: the interpreter itself still waits for non-daemon threads at exit;
# only daemon threads are stopped abruptly.

# ✅ Wait for the thread to finish
t = threading.Thread(target=work)
t.start()
t.join()

# ❌ Checking the lock's state outside the lock
if not lock.locked():  # another thread may grab the lock between the check and acquire()
    lock.acquire()

# ✅ Use `with`, or act on acquire()'s return value
with lock:
    pass