2.5. 并发编程模式

本节介绍一些常用的并发编程模式和最佳实践。

2.5.1. 生产者-消费者模式

2.5.1.1. 异步版本

import asyncio
from dataclasses import dataclass
from typing import Any

@dataclass
class Task:
    """A unit of work handed from the producer to the consumers."""

    id: int  # sequence number of the task
    data: Any  # arbitrary payload carried by the task

async def producer(queue: asyncio.Queue, count: int):
    """Enqueue ``count`` Task objects, then a single None stop sentinel.

    Sleeps 0.1 s after each put to simulate a slow producer.
    """
    for seq in range(count):
        job = Task(id=seq, data=f"data_{seq}")
        await queue.put(job)
        print(f"Produced: {job}")
        await asyncio.sleep(0.1)

    # One None after the last task tells the consumers to stop.
    await queue.put(None)

async def consumer(queue: asyncio.Queue, name: str):
    """Process tasks from *queue* until the None sentinel is received.

    Every item taken with ``get()`` is acknowledged with ``task_done()``,
    including the sentinel and items whose processing raises. This keeps the
    queue's unfinished-task counter consistent for anyone waiting on
    ``queue.join()``. The sentinel is put back before exiting so that other
    consumers sharing the queue also see it and stop.

    Args:
        queue: queue of tasks, terminated by a single None.
        name: label used in the log output.
    """
    while True:
        task = await queue.get()
        try:
            if task is None:
                # Pass the stop signal on to the other consumers. Re-putting
                # before the task_done() in `finally` means the unfinished count
                # never transiently hits zero during the hand-off.
                await queue.put(None)
                break

            print(f"{name} processing: {task}")
            await asyncio.sleep(0.2)  # simulate processing time
        finally:
            # Runs on break and on exceptions too, so no dequeued item is
            # left unacknowledged.
            queue.task_done()

async def main():
    """Run one producer and three consumers until every task is processed.

    Shutdown is driven by the None sentinel. The producer enqueues it after its
    last task, and each consumer puts it back before exiting, so every consumer
    eventually stops on its own. We wait for the consumer tasks themselves
    rather than ``queue.join()``: the copy of the sentinel left in the queue
    after the last consumer exits still counts as an unfinished item, so
    ``join()`` would never return here.
    """
    queue = asyncio.Queue(maxsize=10)

    # Start the consumers first so they are ready when items arrive.
    consumers = [
        asyncio.create_task(consumer(queue, f"Consumer-{i}"))
        for i in range(3)
    ]

    # Returns once all 20 tasks and the sentinel have been enqueued.
    await producer(queue, 20)

    # The sentinel sits behind every real task (FIFO). Each consumer finishes
    # its current task before it reads the sentinel, so when gather() returns
    # all work has been processed.
    await asyncio.gather(*consumers)

asyncio.run(main())

2.5.1.2. 多进程版本

from multiprocessing import Process, Queue, Event
import time

def producer(queue, count, stop_event):
    """Put ``count`` string items on *queue*, then set *stop_event*.

    Items are named ``item_0`` .. ``item_{count-1}``; the producer pauses
    0.1 s after each put.
    """
    labels = (f"item_{k}" for k in range(count))
    for label in labels:
        queue.put(label)
        time.sleep(0.1)
    # Signal the consumers that no more items will follow.
    stop_event.set()

def consumer(queue, name, stop_event):
    """Consume items until the producer is done and the queue looks drained.

    The consumer polls ``queue.get`` with a 0.5 s timeout, so the stop condition
    (``stop_event`` set *and* queue empty) is re-checked regularly.

    Only the "nothing arrived in time" case (``queue.Empty``) is treated as
    expected. Any other error, including KeyboardInterrupt and failures while
    processing an item, propagates instead of being silently retried forever.

    NOTE: ``multiprocessing.Queue.empty()`` is documented as unreliable; here it
    is only a shutdown heuristic once ``stop_event`` is set.
    """
    # The parameter is called `queue`, so import the exception by name.
    from queue import Empty

    while not (stop_event.is_set() and queue.empty()):
        try:
            item = queue.get(timeout=0.5)
        except Empty:
            # Timed out, or another consumer took the item: re-check the stop condition.
            continue
        print(f"{name} processing: {item}")
        time.sleep(0.2)

if __name__ == '__main__':
    queue = Queue()
    stop_event = Event()

    # Build the producer first, then the consumers (same creation order as before).
    producer_proc = Process(target=producer, args=(queue, 20, stop_event))
    consumer_procs = [
        Process(target=consumer, args=(queue, f"Consumer-{i}", stop_event))
        for i in range(3)
    ]

    # Start the producer, then each consumer; join them in the same order.
    everyone = [producer_proc, *consumer_procs]
    for proc in everyone:
        proc.start()
    for proc in everyone:
        proc.join()

2.5.2. 工作池模式

import asyncio
from typing import Callable, Any, List

class AsyncWorkerPool:
    """Fixed-size pool of asyncio worker tasks fed from a shared job queue.

    Jobs are ``(func, args, kwargs)`` triples where ``func`` is a coroutine
    function. Return values are appended to ``self.results`` as each job
    completes. An exception raised by a job is printed and the worker moves on
    to the next job.
    """

    def __init__(self, num_workers: int, max_queue_size: int = 0):
        """
        Args:
            num_workers: number of worker tasks started by start().
            max_queue_size: bound for the job queue. 0 means unbounded
                (asyncio.Queue semantics); otherwise submit() waits for space.
        """
        self.num_workers = num_workers
        self.queue = asyncio.Queue(maxsize=max_queue_size)
        self.workers: List[asyncio.Task] = []
        self.results: List[Any] = []
        self._lock = asyncio.Lock()

    async def _worker(self, worker_id: int):
        """Worker loop: run jobs until a ``(None, None, None)`` sentinel arrives."""
        while True:
            try:
                func, args, kwargs = await self.queue.get()

                if func is None:
                    # Acknowledge the shutdown sentinel too. Otherwise the
                    # queue's unfinished-task count never returns to zero, and
                    # a queue.join() issued after shutdown() would hang.
                    self.queue.task_done()
                    break

                try:
                    result = await func(*args, **kwargs)
                    async with self._lock:
                        self.results.append(result)
                except Exception as e:
                    # Best-effort pool: report the failure and keep serving jobs.
                    print(f"Worker {worker_id} error: {e}")
                finally:
                    self.queue.task_done()
            except asyncio.CancelledError:
                # NOTE(review): cancellation is swallowed here, so a cancelled
                # worker finishes normally rather than as "cancelled". Kept as-is
                # for callers that gather() cancelled workers.
                break

    async def start(self):
        """Create the worker tasks."""
        self.workers = [
            asyncio.create_task(self._worker(i))
            for i in range(self.num_workers)
        ]

    async def submit(self, func: Callable, *args, **kwargs):
        """Queue ``func(*args, **kwargs)``; waits for space if the queue is bounded."""
        await self.queue.put((func, args, kwargs))

    async def shutdown(self):
        """Stop the pool and wait for every worker to exit.

        One sentinel is queued per worker, behind any jobs already submitted,
        so those jobs are still processed before the workers stop.
        """
        for _ in range(self.num_workers):
            await self.queue.put((None, None, None))

        await asyncio.gather(*self.workers)

async def process_item(item):
    """Simulate 0.1 s of asynchronous work, then return ``item * 2``."""
    await asyncio.sleep(0.1)
    doubled = item * 2
    return doubled

async def main():
    """Demo: double 0..19 on a 4-worker pool and print the sorted results."""
    pool = AsyncWorkerPool(num_workers=4)
    await pool.start()

    for value in range(20):
        await pool.submit(process_item, value)

    # Block until every submitted job has been acknowledged with task_done().
    await pool.queue.join()

    # Send the stop sentinels and wait for the workers to exit.
    await pool.shutdown()

    ordered = sorted(pool.results)
    print(f"Results: {ordered}")

asyncio.run(main())

2.5.3. 限流模式

2.5.3.1. 令牌桶

import asyncio
import time

class TokenBucket:
    """Token-bucket rate limiter for asyncio code.

    Tokens are replenished continuously at ``rate`` per second, up to
    ``capacity``. The bucket starts full, so up to ``capacity`` tokens can be
    taken immediately.
    """

    def __init__(self, rate: float, capacity: int):
        """
        Args:
            rate: tokens generated per second.
            capacity: maximum number of tokens the bucket holds (also the
                initial amount).
        """
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        self.last_update = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self, tokens: int = 1) -> bool:
        """Try to take *tokens* tokens without waiting.

        Returns:
            True if the tokens were taken, False if not enough are available.
        """
        async with self._lock:
            now = time.monotonic()
            # Refill for the time elapsed since the previous call, capped at capacity.
            elapsed = now - self.last_update
            self.tokens = min(
                self.capacity,
                self.tokens + elapsed * self.rate
            )
            self.last_update = now

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    async def wait_for_token(self, tokens: int = 1):
        """Wait, polling every 0.1 s, until *tokens* tokens can be taken.

        Raises:
            ValueError: if *tokens* exceeds ``capacity``. The bucket can never
                hold that many tokens, so the request could never succeed and
                the loop would otherwise spin forever.
        """
        if tokens > self.capacity:
            raise ValueError(
                f"requested {tokens} tokens but bucket capacity is {self.capacity}"
            )
        while not await self.acquire(tokens):
            await asyncio.sleep(0.1)

async def api_call(limiter: TokenBucket, request_id: int):
    """Simulated API request: wait for one token from *limiter*, log, then "call"."""
    await limiter.wait_for_token()
    stamp = time.time()
    print(f"Request {request_id} at {stamp:.2f}")
    await asyncio.sleep(0.1)  # stand-in for the real API call

async def main():
    """Demo: fire 20 concurrent api_call()s through a 5-tokens-per-second bucket."""
    # rate == capacity: a burst of 5, then at most 5 requests per second.
    limiter = TokenBucket(rate=5, capacity=5)

    await asyncio.gather(*(api_call(limiter, n) for n in range(20)))

asyncio.run(main())

2.5.3.2. Semaphore 限流

import asyncio

class RateLimiter:
    """Concurrency limiter backed by asyncio.Semaphore.

    Despite its name, it caps how many ``async with`` bodies run at the same
    time (``max_concurrent``). It does not limit requests per unit of time.
    """

    def __init__(self, max_concurrent: int):
        self.semaphore = asyncio.Semaphore(max_concurrent)

    async def __aenter__(self):
        # Wait for a free slot.
        sem = self.semaphore
        await sem.acquire()
        return self

    async def __aexit__(self, *exc_info):
        # Free the slot whether the body succeeded or raised. Returning None
        # lets any exception propagate.
        self.semaphore.release()

async def fetch_url(limiter: RateLimiter, url: str):
    """Simulated fetch of *url* while holding a slot from *limiter* (1 s per request)."""
    async with limiter:
        print(f"Fetching {url}")
        await asyncio.sleep(1)
        payload = f"Data from {url}"
    return payload

async def main():
    """Demo: fetch 10 URLs with at most 3 requests in flight at once."""
    limiter = RateLimiter(max_concurrent=3)

    pending = []
    for i in range(10):
        pending.append(fetch_url(limiter, f"https://api.example.com/{i}"))

    results = await asyncio.gather(*pending)
    print(f"Fetched {len(results)} URLs")

asyncio.run(main())

2.5.4. 重试模式

import asyncio
import random
from functools import wraps

def async_retry(
    max_attempts: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: tuple = (Exception,)
):
    """Decorator factory: retry an async function with exponential backoff.

    Args:
        max_attempts: total number of calls, including the first one.
        delay: seconds to wait before the first retry.
        backoff: multiplier applied to the delay after each retry.
        exceptions: exception class(es) that trigger a retry; any other
            exception propagates immediately.

    Raises:
        ValueError: at decoration time if ``max_attempts < 1``. Without this
            check the wrapper would never call the function and would then
            fail with a TypeError from ``raise None``.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            current_delay = delay

            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_attempts:
                        # Out of attempts: re-raise the last failure as-is.
                        raise
                    print(f"Attempt {attempt} failed: {e}")
                    print(f"Retrying in {current_delay:.1f}s...")
                    await asyncio.sleep(current_delay)
                    current_delay *= backoff
        return wrapper
    return decorator

@async_retry(max_attempts=3, delay=0.5, backoff=2.0)
async def unreliable_api():
    """Simulated flaky endpoint: raises ConnectionError about 70% of the time."""
    roll = random.random()
    if roll >= 0.7:
        return "Success!"
    raise ConnectionError("Network error")

async def main():
    """Demo: call the flaky API through the retry decorator."""
    try:
        outcome = await unreliable_api()
    except ConnectionError:
        print("All retries failed")
    else:
        print(f"Result: {outcome}")

asyncio.run(main())

2.5.5. 超时与取消模式

import asyncio

class TimeoutManager:
    """Runs coroutines under a fixed timeout.

    It can be used with ``async with`` for scoping, but entering and exiting
    do nothing.
    """

    def __init__(self, timeout: float):
        self.timeout = timeout
        # Not read by any method of this class.
        self._task: asyncio.Task = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Returning None: exceptions raised in the body are not suppressed.
        return None

    async def run(self, coro):
        """Await *coro*, giving up after ``self.timeout`` seconds.

        Raises:
            asyncio.TimeoutError: re-raised after printing a message when the
                time limit is exceeded.
        """
        try:
            outcome = await asyncio.wait_for(coro, timeout=self.timeout)
        except asyncio.TimeoutError:
            print(f"Operation timed out after {self.timeout}s")
            raise
        return outcome

async def cancellable_operation():
    """Print progress once per second until cancelled.

    On cancellation it runs a short cleanup step and then re-raises
    CancelledError so the awaiting caller sees the task as cancelled.
    """
    async def _tick():
        print("Working...")
        await asyncio.sleep(1)

    try:
        while True:
            await _tick()
    except asyncio.CancelledError:
        print("Operation cancelled, cleaning up...")
        await asyncio.sleep(0.1)  # cleanup work
        raise  # propagate so the caller knows the task was cancelled

async def main():
    """Demo: let the operation run for 3 s, then cancel it and await the result."""
    worker = asyncio.create_task(cancellable_operation())

    await asyncio.sleep(3)

    # Request cancellation; the task cleans up and re-raises CancelledError.
    worker.cancel()

    try:
        await worker
    except asyncio.CancelledError:
        print("Task was cancelled")

asyncio.run(main())

2.5.6. 熔断器模式

import asyncio
import time
from enum import Enum
from dataclasses import dataclass

class CircuitState(Enum):
    """Circuit-breaker states; each value is the lowercase state name."""

    CLOSED = "closed"  # normal operation: calls pass through
    OPEN = "open"  # tripped: calls are rejected
    HALF_OPEN = "half_open"  # probing: trial calls test whether the service recovered

@dataclass
class CircuitBreakerConfig:
    """Thresholds used by a circuit breaker."""

    # failure count at which the breaker opens
    failure_threshold: int = 5
    # seconds to stay open, measured from the last failure, before probing again
    recovery_timeout: float = 30.0
    # successful half-open calls required before the breaker closes again
    half_open_max_calls: int = 3

class CircuitBreakerOpenError(Exception):
    """Raised by CircuitBreaker.call() when the circuit is OPEN and the call is rejected."""


class CircuitBreaker:
    """Async circuit breaker guarding calls to an unreliable coroutine function.

    States (CircuitState):
      * CLOSED: calls pass through. A success resets the failure count; once
        the count reaches ``failure_threshold`` the circuit opens.
      * OPEN: calls are rejected with CircuitBreakerOpenError until
        ``recovery_timeout`` seconds have passed since the last failure.
      * HALF_OPEN: trial calls pass through (their number is not limited).
        ``half_open_max_calls`` successes close the circuit; any failure
        re-opens it.
    """

    def __init__(self, config: CircuitBreakerConfig = None):
        self.config = config or CircuitBreakerConfig()
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        # time.monotonic() reading of the most recent failure (0 until one occurs).
        self.last_failure_time = 0
        self._lock = asyncio.Lock()

    async def call(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` through the breaker.

        Raises:
            CircuitBreakerOpenError: if the circuit is OPEN. It subclasses
                Exception, so existing ``except Exception`` handlers still catch it.
            Exception: anything ``func`` raises is recorded as a failure and re-raised.
        """
        async with self._lock:
            await self._check_state()

            if self.state == CircuitState.OPEN:
                raise CircuitBreakerOpenError("Circuit breaker is OPEN")

        # The lock is not held while the protected call runs, so slow calls do
        # not serialize each other.
        try:
            result = await func(*args, **kwargs)
        except Exception:
            await self._on_failure()
            raise
        await self._on_success()
        return result

    async def _check_state(self):
        """Move OPEN -> HALF_OPEN once recovery_timeout has elapsed since the last failure.

        Called by call() while holding self._lock.
        """
        if self.state == CircuitState.OPEN:
            # Monotonic clock: wall-clock jumps cannot shorten or extend the timeout.
            if time.monotonic() - self.last_failure_time > self.config.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
                self.success_count = 0

    async def _on_success(self):
        """Record a success: count toward closing in HALF_OPEN, reset failures in CLOSED."""
        async with self._lock:
            if self.state == CircuitState.HALF_OPEN:
                self.success_count += 1
                if self.success_count >= self.config.half_open_max_calls:
                    self.state = CircuitState.CLOSED
                    self.failure_count = 0
            elif self.state == CircuitState.CLOSED:
                self.failure_count = 0

    async def _on_failure(self):
        """Record a failure and open the circuit if needed."""
        async with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.monotonic()

            # Open after too many failures, or immediately if a half-open trial failed.
            if (self.failure_count >= self.config.failure_threshold
                    or self.state == CircuitState.HALF_OPEN):
                self.state = CircuitState.OPEN

async def unreliable_service():
    """Simulated backend that fails about half the time."""
    import random
    failed = random.random() < 0.5
    if failed:
        raise Exception("Service error")
    return "Success"

async def main():
    """Demo: make 20 calls through the breaker, 0.5 s apart, logging each outcome."""
    breaker = CircuitBreaker()

    for attempt in range(20):
        try:
            outcome = await breaker.call(unreliable_service)
        except Exception as err:
            print(f"Call {attempt}: {err}")
        else:
            print(f"Call {attempt}: {outcome}")

        await asyncio.sleep(0.5)

asyncio.run(main())

2.5.7. 最佳实践总结

模式选择

| 场景 | 推荐模式 |
| --- | --- |
| 任务队列处理 | 生产者-消费者 |
| 批量并发请求 | 工作池 + 限流 |
| 不稳定服务调用 | 重试 + 熔断器 |
| 资源访问控制 | Semaphore |
| 长时间操作 | 超时 + 取消 |

关键原则
  1. 优雅降级:服务不可用时有备选方案

  2. 快速失败:尽早检测和处理错误

  3. 资源限制:控制并发数和资源使用

  4. 可观测性:记录关键指标和日志

  5. 可测试性:设计易于测试的接口