2.3. 多进程编程
多进程是绕过 GIL 实现真正并行计算的主要方式,适用于 CPU 密集型任务。
2.3.1. 基础使用
2.3.1.1. 创建进程
from multiprocessing import Process
import os

def worker(name):
    print(f"Worker {name}, PID: {os.getpid()}, Parent PID: {os.getppid()}")

if __name__ == '__main__':  # Windows 必须
    processes = []
    for i in range(4):
        p = Process(target=worker, args=(i,))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    print(f"Main process PID: {os.getpid()}")
警告
在 Windows(以及 Python 3.8+ 默认使用 spawn 的 macOS)上,创建进程的代码必须放在 if __name__ == '__main__': 块中:spawn 启动方式会在子进程中重新导入主模块,没有这层保护就会递归地创建新进程(multiprocessing 会抛出 RuntimeError 加以阻止)。
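这一行为与进程的启动方式(start method)有关:Linux 默认使用 fork,Windows 和 Python 3.8+ 的 macOS 默认使用 spawn,而 spawn 会在子进程中重新导入主模块。下面是一个显式指定启动方式的简单示意(启动方式的选择仅作示例):

import multiprocessing as mp

def worker():
    print(f"child pid: {mp.current_process().pid}")

if __name__ == '__main__':
    # set_start_method 只能在主模块中调用一次;也可以用 mp.get_context('spawn') 创建独立上下文
    mp.set_start_method('spawn')
    print(mp.get_start_method())  # spawn
    p = mp.Process(target=worker)
    p.start()
    p.join()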
2.3.1.2. 进程池
from multiprocessing import Pool

def cpu_intensive(n):
    """CPU 密集型任务"""
    total = 0
    for i in range(n):
        total += i * i
    return total

def add(a, b):
    # starmap 用到的函数必须定义在模块顶层,否则 spawn 模式下无法 pickle
    return a + b

if __name__ == '__main__':
    # 使用进程池
    with Pool(4) as pool:
        # map:阻塞,保持输入顺序
        results = pool.map(cpu_intensive, [1000000] * 8)
        print(f"map results: {len(results)}")

        # map_async:非阻塞,稍后通过 get() 取结果
        async_result = pool.map_async(cpu_intensive, [1000000] * 8)
        results = async_result.get(timeout=10)

        # apply:单次阻塞调用
        result = pool.apply(cpu_intensive, (1000000,))

        # apply_async:单次非阻塞调用
        async_result = pool.apply_async(cpu_intensive, (1000000,))
        result = async_result.get()

        # starmap:自动解包参数元组
        results = pool.starmap(add, [(1, 2), (3, 4), (5, 6)])
2.3.1.3. ProcessPoolExecutor(推荐)
from concurrent.futures import ProcessPoolExecutor, as_completed
import time

def process_data(data):
    """处理数据"""
    time.sleep(0.1)
    return data * 2

if __name__ == '__main__':
    data_list = list(range(100))
    with ProcessPoolExecutor(max_workers=4) as executor:
        # 方式1:map,按输入顺序返回结果
        results = list(executor.map(process_data, data_list))

        # 方式2:submit + as_completed,按完成先后顺序处理
        futures = [executor.submit(process_data, d) for d in data_list]
        for future in as_completed(futures):
            result = future.result()
            print(f"Got result: {result}")
2.3.2. 进程间通信
2.3.2.1. Queue(队列)
from multiprocessing import Process, Queue
import time

def producer(q):
    for i in range(10):
        q.put(f"item_{i}")
        time.sleep(0.1)
    q.put(None)  # 哨兵

def consumer(q):
    while True:
        item = q.get()
        if item is None:
            break
        print(f"Consumed: {item}")

if __name__ == '__main__':
    q = Queue()
    p1 = Process(target=producer, args=(q,))
    p2 = Process(target=consumer, args=(q,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
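上面的例子只有一个消费者。如果有多个消费者,生产者需要为每个消费者各放入一个哨兵,否则部分消费者会永远阻塞在 q.get() 上。下面是按这一思路扩展的简单示意(消费者数量等均为示例取值):

from multiprocessing import Process, Queue

def producer(q, n_consumers):
    for i in range(10):
        q.put(f"item_{i}")
    for _ in range(n_consumers):  # 每个消费者一个哨兵
        q.put(None)

def consumer(q, cid):
    while True:
        item = q.get()
        if item is None:
            break
        print(f"Consumer {cid} got {item}")

if __name__ == '__main__':
    q = Queue()
    n = 3
    consumers = [Process(target=consumer, args=(q, i)) for i in range(n)]
    for c in consumers:
        c.start()
    producer_proc = Process(target=producer, args=(q, n))
    producer_proc.start()
    producer_proc.join()
    for c in consumers:
        c.join()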
2.3.2.2. Pipe(管道)
from multiprocessing import Process, Pipe

def sender(conn):
    conn.send("Hello from sender")
    conn.send([1, 2, 3])
    conn.close()

def receiver(conn):
    print(conn.recv())  # Hello from sender
    print(conn.recv())  # [1, 2, 3]
    conn.close()

if __name__ == '__main__':
    parent_conn, child_conn = Pipe()
    p1 = Process(target=sender, args=(child_conn,))
    p2 = Process(target=receiver, args=(parent_conn,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()
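Pipe 默认是双向的(duplex=True),两端都可以收发;此外 Connection 对象提供 poll(timeout),可以在不阻塞(或限时阻塞)的情况下检查是否有数据可读。下面是一个简单示意(等待时间仅作示例):

from multiprocessing import Process, Pipe
import time

def slow_sender(conn):
    time.sleep(1)
    conn.send("data after 1s")
    conn.close()

if __name__ == '__main__':
    parent_conn, child_conn = Pipe()
    p = Process(target=slow_sender, args=(child_conn,))
    p.start()
    # poll 返回 True 表示有数据可读,最多等待 2 秒
    if parent_conn.poll(2):
        print(parent_conn.recv())
    else:
        print("no data yet")
    p.join()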
2.3.2.3. 共享内存
from multiprocessing import Process, Value, Array, Lock

def increment_counter(counter, lock):
    for _ in range(10000):
        with lock:
            counter.value += 1

def modify_array(arr):
    for i in range(len(arr)):
        arr[i] = arr[i] * 2

if __name__ == '__main__':
    # 共享值
    counter = Value('i', 0)  # 'i' = int
    lock = Lock()
    processes = [
        Process(target=increment_counter, args=(counter, lock))
        for _ in range(4)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    print(f"Counter: {counter.value}")  # 40000

    # 共享数组
    arr = Array('d', [1.0, 2.0, 3.0, 4.0])  # 'd' = double
    p = Process(target=modify_array, args=(arr,))
    p.start()
    p.join()
    print(f"Array: {list(arr)}")  # [2.0, 4.0, 6.0, 8.0]
2.3.2.4. Manager(共享复杂对象)
from multiprocessing import Process, Manager

def worker(shared_dict, shared_list, worker_id):
    shared_dict[worker_id] = f"result_{worker_id}"
    shared_list.append(worker_id)

if __name__ == '__main__':
    with Manager() as manager:
        # Manager 支持复杂数据类型
        shared_dict = manager.dict()
        shared_list = manager.list()
        processes = [
            Process(target=worker, args=(shared_dict, shared_list, i))
            for i in range(4)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()
        print(f"Dict: {dict(shared_dict)}")
        print(f"List: {list(shared_list)}")
2.3.3. 进程同步
2.3.3.1. Lock 和 RLock
from multiprocessing import Process, Lock, Value

def safe_increment(counter, lock):
    for _ in range(10000):
        with lock:
            counter.value += 1

if __name__ == '__main__':
    counter = Value('i', 0)
    lock = Lock()
    processes = [
        Process(target=safe_increment, args=(counter, lock))
        for _ in range(4)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    print(f"Final counter: {counter.value}")  # 正确: 40000
2.3.3.2. Semaphore
from multiprocessing import Process, Semaphore
import time

def limited_resource(sem, process_id):
    with sem:
        print(f"Process {process_id} acquired resource")
        time.sleep(1)
        print(f"Process {process_id} released resource")

if __name__ == '__main__':
    # 最多 3 个进程同时访问
    sem = Semaphore(3)
    processes = [
        Process(target=limited_resource, args=(sem, i))
        for i in range(10)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
2.3.3.3. Event
from multiprocessing import Process, Event
import time

def wait_for_event(event, process_id):
    print(f"Process {process_id} waiting...")
    event.wait()
    print(f"Process {process_id} got signal!")

if __name__ == '__main__':
    event = Event()
    processes = [
        Process(target=wait_for_event, args=(event, i))
        for i in range(3)
    ]
    for p in processes:
        p.start()
    time.sleep(2)
    print("Setting event...")
    event.set()
    for p in processes:
        p.join()
2.3.4. 实用模式
2.3.4.1. 工作池模式
from multiprocessing import Pool
import os

def init_worker():
    """初始化每个 worker 进程"""
    print(f"Worker {os.getpid()} initialized")

def process_item(item):
    return item * 2

if __name__ == '__main__':
    with Pool(4, initializer=init_worker) as pool:
        results = pool.map(process_item, range(10))
        print(results)
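initializer 常与 initargs 搭配,用来为每个 worker 进程准备一次性的全局资源(如配置、数据库连接、大型只读数据),避免在每个任务中重复传递。下面是一个简单示意,其中 _config、factor 等名称只是示例:

from multiprocessing import Pool

_config = None  # 每个 worker 进程内部的"全局"资源

def init_worker(config):
    global _config
    _config = config  # 仅在 worker 启动时设置一次

def process_item(item):
    return item * _config["factor"]  # 任务函数直接使用 worker 级资源

if __name__ == '__main__':
    with Pool(4, initializer=init_worker, initargs=({"factor": 10},)) as pool:
        print(pool.map(process_item, range(5)))  # [0, 10, 20, 30, 40]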
2.3.4.2. 分块处理
from multiprocessing import Pool

def process_chunk(chunk):
    """处理数据块"""
    return [x * 2 for x in chunk]

if __name__ == '__main__':
    data = list(range(1000))
    with Pool(4) as pool:
        # 手动分块和 chunksize 参数都可以减少进程间通信开销
        results = pool.map(
            process_chunk,
            [data[i:i + 100] for i in range(0, len(data), 100)],
            chunksize=2,
        )
        # 展平结果
        flat_results = [item for chunk in results for item in chunk]
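除了手动切块,也可以直接使用 Pool.imap / imap_unordered,由进程池按 chunksize 自动分批派发任务,结果以迭代器形式逐个返回,适合数据量大、不要求保持顺序的场景。下面是一个简单示意(chunksize 取值仅作示例):

from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == '__main__':
    with Pool(4) as pool:
        total = 0
        # imap_unordered:谁先算完先返回谁;chunksize 控制每次派发的任务数
        for result in pool.imap_unordered(square, range(1000), chunksize=50):
            total += result
        print(total)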
2.3.4.3. 超时处理
from multiprocessing import Pool, TimeoutError
import time

def slow_function(x):
    time.sleep(x)
    return x

if __name__ == '__main__':
    with Pool(2) as pool:
        async_result = pool.apply_async(slow_function, (10,))
        try:
            result = async_result.get(timeout=3)
        except TimeoutError:
            print("Task timed out!")
            pool.terminate()  # 强制终止
2.3.4.4. 错误处理
from concurrent.futures import ProcessPoolExecutor, as_completed

def risky_function(x):
    if x == 5:
        raise ValueError(f"Bad value: {x}")
    return x * 2

if __name__ == '__main__':
    with ProcessPoolExecutor(max_workers=4) as executor:
        futures = {executor.submit(risky_function, i): i for i in range(10)}
        for future in as_completed(futures):
            x = futures[future]
            try:
                result = future.result()
                print(f"Input {x} -> Result {result}")
            except Exception as e:
                print(f"Input {x} raised {e}")
2.3.5. 注意事项
进程 vs 线程
| 特性 | 进程 | 线程 |
|---|---|---|
| 创建开销 | 大 | 小 |
| 内存共享 | 需要特殊机制 | 自动共享 |
| GIL 影响 | 不受影响 | 受影响 |
| 通信方式 | Queue/Pipe/共享内存 | 直接共享 |
| 适用场景 | CPU 密集型 | I/O 密集型 |
常见陷阱
# ❌ 在模块顶层(main 块之外)直接创建 Pool(Windows 下会出错)
pool = Pool(4)  # 错误!

# ✅ 在 main 块中创建
if __name__ == '__main__':
    with Pool(4) as pool:
        pass

# ❌ 向子进程传递不可 pickle 的对象
Process(target=worker, args=(lambda x: x,))  # lambda 不可 pickle

# ✅ 改用模块级普通函数
def my_callback(x):
    return x
Process(target=worker, args=(my_callback,))

# ❌ 在模块顶层创建全局锁
lock = Lock()  # spawn 模式下这个锁不会在子进程间共享

# ✅ 在 main 块中创建锁,并通过参数传递
if __name__ == '__main__':
    lock = Lock()
    Process(target=worker, args=(lock,)).start()
性能建议
- 减少进程间通信:进程间传递数据需要 pickle 序列化,开销远大于线程间共享
- 使用合适的通信机制:Queue / Pipe / 共享内存各有适用场景
- 批量处理:合并小任务,减少任务分发和函数调用次数
- 进程数量:通常设为 CPU 核心数,见下面的示例
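下面是按 os.cpu_count() 设置进程池大小的简单示意(不传参数时 Pool() 默认也会使用 CPU 核心数):

import os
from multiprocessing import Pool

def work(x):
    return x * x

if __name__ == '__main__':
    n_workers = os.cpu_count() or 4  # cpu_count() 可能返回 None,给一个兜底值
    with Pool(n_workers) as pool:
        print(pool.map(work, range(10)))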