# 装饰器深入 装饰器是 Python 中最强大的特性之一,理解其原理和常见模式对于写出优雅代码至关重要。 ## 装饰器本质 装饰器本质上是一个接受函数并返回新函数的高阶函数。 ```python # 装饰器语法糖 @decorator def func(): pass # 等价于 def func(): pass func = decorator(func) ``` ## 保持函数元信息 ```python import functools # ❌ 不使用 @functools.wraps def bad_decorator(func): def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper @bad_decorator def greet(name): """Return a greeting.""" return f"Hello, {name}!" print(greet.__name__) # wrapper (错误!) print(greet.__doc__) # None (文档丢失!) # ✅ 使用 @functools.wraps def good_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper @good_decorator def greet(name): """Return a greeting.""" return f"Hello, {name}!" print(greet.__name__) # greet (正确!) print(greet.__doc__) # Return a greeting. ``` ## 常见装饰器模式 ### 带参数的装饰器 ```python import functools def repeat(times): """重复执行装饰器""" def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): result = None for _ in range(times): result = func(*args, **kwargs) return result return wrapper return decorator @repeat(times=3) def say_hello(): print("Hello!") say_hello() # Hello! # Hello! # Hello! ``` ### 可选参数的装饰器 ```python import functools def log(func=None, *, level='INFO'): """ 支持两种用法: @log def func(): ... @log(level='DEBUG') def func(): ... 
""" def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): print(f"[{level}] Calling {func.__name__}") return func(*args, **kwargs) return wrapper if func is None: # @log(level='DEBUG') 方式调用 return decorator else: # @log 方式调用 return decorator(func) @log def func1(): pass @log(level='DEBUG') def func2(): pass func1() # [INFO] Calling func1 func2() # [DEBUG] Calling func2 ``` ### 带状态的装饰器(类实现) ```python import functools class CountCalls: """统计函数调用次数""" def __init__(self, func): functools.update_wrapper(self, func) self.func = func self.count = 0 def __call__(self, *args, **kwargs): self.count += 1 print(f"Call {self.count} of {self.func.__name__}") return self.func(*args, **kwargs) @CountCalls def fibonacci(n): if n < 2: return n return fibonacci(n - 1) + fibonacci(n - 2) # 这会显示递归调用次数 result = fibonacci(5) print(f"fibonacci.count = {fibonacci.count}") ``` ### 缓存装饰器 ```python import functools def memoize(func): """简单的缓存装饰器""" cache = {} @functools.wraps(func) def wrapper(*args): if args not in cache: cache[args] = func(*args) return cache[args] wrapper.cache = cache # 暴露缓存以便调试 wrapper.clear_cache = cache.clear return wrapper @memoize def slow_function(n): import time time.sleep(1) return n * 2 print(slow_function(5)) # 等待 1 秒 print(slow_function(5)) # 立即返回(缓存) # 使用标准库的 lru_cache(推荐) from functools import lru_cache @lru_cache(maxsize=128) def fibonacci(n): if n < 2: return n return fibonacci(n - 1) + fibonacci(n - 2) print(fibonacci(100)) # 快速计算 print(fibonacci.cache_info()) # 查看缓存统计 fibonacci.cache_clear() # 清除缓存 ``` ### 类型检查装饰器 ```python import functools import inspect def typecheck(func): """运行时类型检查""" sig = inspect.signature(func) @functools.wraps(func) def wrapper(*args, **kwargs): # 绑定参数 bound = sig.bind(*args, **kwargs) bound.apply_defaults() # 检查每个参数 for name, value in bound.arguments.items(): param = sig.parameters[name] if param.annotation != inspect.Parameter.empty: expected_type = param.annotation if not isinstance(value, expected_type): 
raise TypeError( f"Argument '{name}' must be {expected_type.__name__}, " f"got {type(value).__name__}" ) result = func(*args, **kwargs) # 检查返回值 if sig.return_annotation != inspect.Signature.empty: if not isinstance(result, sig.return_annotation): raise TypeError( f"Return value must be {sig.return_annotation.__name__}" ) return result return wrapper @typecheck def add(a: int, b: int) -> int: return a + b print(add(1, 2)) # 3 # add("1", 2) # TypeError: Argument 'a' must be int, got str ``` ## 装饰类 ```python import functools def singleton(cls): """单例装饰器""" instances = {} @functools.wraps(cls) def wrapper(*args, **kwargs): if cls not in instances: instances[cls] = cls(*args, **kwargs) return instances[cls] return wrapper @singleton class Database: def __init__(self): print("Initializing database connection") db1 = Database() # Initializing database connection db2 = Database() # 不会打印(返回同一实例) print(db1 is db2) # True ``` ### 添加方法到类 ```python def add_method(cls): """为类添加方法""" def describe(self): return f"{self.__class__.__name__}: {vars(self)}" cls.describe = describe return cls @add_method class Person: def __init__(self, name, age): self.name = name self.age = age p = Person("Alice", 30) print(p.describe()) # Person: {'name': 'Alice', 'age': 30} ``` ### 注册模式 ```python class Registry: """插件/处理器注册表""" def __init__(self): self._registry = {} def register(self, name=None): def decorator(cls): key = name or cls.__name__ self._registry[key] = cls return cls return decorator def get(self, name): return self._registry.get(name) def list_all(self): return list(self._registry.keys()) # 使用示例 handlers = Registry() @handlers.register("json") class JSONHandler: def process(self, data): return f"JSON: {data}" @handlers.register("xml") class XMLHandler: def process(self, data): return f"XML: {data}" print(handlers.list_all()) # ['json', 'xml'] handler = handlers.get("json")() print(handler.process("data")) # JSON: data ``` ## 方法装饰器 ### 静态方法和类方法 ```python class MyClass: count = 0 def 
__init__(self, value): self.value = value MyClass.count += 1 @staticmethod def static_method(): """不访问实例或类状态""" return "I'm static" @classmethod def class_method(cls): """访问类状态""" return f"Instance count: {cls.count}" @classmethod def from_string(cls, s): """工厂方法""" value = int(s) return cls(value) obj = MyClass.from_string("42") print(obj.value) # 42 print(MyClass.class_method()) # Instance count: 1 ``` ### property 装饰器 ```python class Temperature: def __init__(self, celsius=0): self._celsius = celsius @property def celsius(self): return self._celsius @celsius.setter def celsius(self, value): if value < -273.15: raise ValueError("Temperature below absolute zero!") self._celsius = value @celsius.deleter def celsius(self): print("Deleting celsius") del self._celsius @property def fahrenheit(self): return self._celsius * 9/5 + 32 t = Temperature(25) print(t.fahrenheit) # 77.0 t.celsius = 30 # t.celsius = -300 # ValueError del t.celsius # Deleting celsius ``` ## 异步装饰器 ```python import asyncio import functools def async_timer(func): """异步函数计时器""" @functools.wraps(func) async def wrapper(*args, **kwargs): import time start = time.perf_counter() result = await func(*args, **kwargs) elapsed = time.perf_counter() - start print(f"{func.__name__} took {elapsed:.4f}s") return result return wrapper def async_retry(max_attempts=3, delay=1): """异步重试装饰器""" def decorator(func): @functools.wraps(func) async def wrapper(*args, **kwargs): last_exception = None for attempt in range(max_attempts): try: return await func(*args, **kwargs) except Exception as e: last_exception = e if attempt < max_attempts - 1: await asyncio.sleep(delay) raise last_exception return wrapper return decorator @async_timer @async_retry(max_attempts=3) async def fetch_data(url): # 模拟网络请求 await asyncio.sleep(0.5) return f"Data from {url}" ``` ## 装饰器执行顺序 ```python import functools def decorator_a(func): print("Applying A") @functools.wraps(func) def wrapper(*args, **kwargs): print("Before A") result = func(*args, **kwargs)
print("After A") return result return wrapper def decorator_b(func): print("Applying B") @functools.wraps(func) def wrapper(*args, **kwargs): print("Before B") result = func(*args, **kwargs) print("After B") return result return wrapper @decorator_a # 外层 @decorator_b # 内层 def hello(): print("Hello!") # 输出(定义时): # Applying B <- 从下到上应用 # Applying A hello() # 输出(调用时): # Before A <- 从外到内执行 # Before B # Hello! # After B <- 从内到外返回 # After A ``` ## 最佳实践 ::::{grid} 1 :gutter: 2 :::{grid-item-card} 设计原则 1. **始终使用 `@functools.wraps`**:保留被装饰函数的 `__name__`、`__doc__` 等元信息,并自动设置 `__wrapped__` 2. **保持装饰器简单**:复杂逻辑抽取到辅助函数 3. **处理 `*args` 和 `**kwargs`**:支持任意签名 4. **返回值透明**:除非有意修改,否则返回原函数结果 ::: :::{grid-item-card} 调试技巧 ```python # 查看原始函数 decorated_func.__wrapped__ # functools.wraps 自动保留 __wrapped__ import functools def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper @decorator def original(): pass print(original.__wrapped__) # <function original at 0x...> ``` ::: ::::