1.6. 装饰器深入
装饰器是 Python 中最强大的特性之一,理解其原理和常见模式对于写出优雅代码至关重要。
1.6.1. 装饰器本质
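装饰器本质上是一个高阶可调用对象:它接收一个函数(或类)作为参数,并返回一个用来替换它的对象——通常是包装了原函数的新函数,但也可以是原对象本身(如 1.6.4.1 的 add_method)或其他可调用对象(如 1.6.3.3 的 CountCalls 实例)。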
# Decorator syntax sugar.
# NOTE: `decorator` is a placeholder for any decorator callable; it is not
# defined in this snippet, so the snippet is illustrative only.
@decorator
def func():
    pass

# ...is equivalent to: define the function, then rebind the name to
# whatever the decorator returns.
def func():
    pass
func = decorator(func)
1.6.2. 保持函数元信息
import functools
# ❌ Without @functools.wraps: the wrapper replaces the function's metadata.
def bad_decorator(func):
    # Returns a bare wrapper; nothing copies func's __name__ / __doc__ onto it.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@bad_decorator
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}!"

print(greet.__name__)  # wrapper (wrong! the inner function's name leaks out)
print(greet.__doc__)   # None (the docstring is lost!)
# ✅ With @functools.wraps: the wrapper takes on the wrapped function's metadata.
def good_decorator(func):
    # functools.wraps copies __name__, __doc__, etc. from func onto wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@good_decorator
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}!"

print(greet.__name__)  # greet (correct!)
print(greet.__doc__)   # Return a greeting.
1.6.3. 常见装饰器模式
1.6.3.1. 带参数的装饰器
import functools
def repeat(times):
    """Decorator factory: run the decorated function ``times`` times per call.

    The wrapper hands back whatever the final run returned, or None when
    ``times`` is zero or negative (the function is then never invoked).
    """
    def apply(func):
        @functools.wraps(func)
        def repeated(*args, **kwargs):
            outcome = None
            for _round in range(times):
                outcome = func(*args, **kwargs)
            return outcome
        return repeated
    return apply


@repeat(times=3)
def say_hello():
    print("Hello!")


say_hello()
# Hello!
# Hello!
# Hello!
1.6.3.2. 可选参数的装饰器
import functools
def log(func=None, *, level='INFO'):
    """Print a ``[level] Calling <name>`` line before each call of the function.

    Works both bare and with arguments:

        @log
        def func(): ...

        @log(level='DEBUG')
        def func(): ...
    """
    def attach(target):
        @functools.wraps(target)
        def logged(*args, **kwargs):
            print(f"[{level}] Calling {target.__name__}")
            return target(*args, **kwargs)
        return logged

    # Bare @log hands us the function directly; @log(...) passes nothing
    # positionally, so we return the decorator for Python to apply next.
    return attach if func is None else attach(func)


@log
def func1():
    pass


@log(level='DEBUG')
def func2():
    pass


func1()  # [INFO] Calling func1
func2()  # [DEBUG] Calling func2
1.6.3.3. 带状态的装饰器(类实现)
import functools
class CountCalls:
    """Class-based decorator that counts, and announces, calls to a function.

    The instance replaces the decorated function. ``self.count`` holds the
    number of calls made so far, and each call prints
    ``Call <n> of <name>``. ``functools.update_wrapper`` copies the function's
    ``__name__``/``__doc__`` onto the instance.
    """

    def __init__(self, func):
        functools.update_wrapper(self, func)
        self.func = func
        self.count = 0

    def __call__(self, *args, **kwargs):
        self.count += 1
        print(f"Call {self.count} of {self.func.__name__}")
        return self.func(*args, **kwargs)

    def __get__(self, instance, owner=None):
        # A plain class instance is not a descriptor. Without this hook, a
        # decorated *method* would never bind `self`, so calling it through an
        # instance would be missing its first argument. Accessed on the class
        # itself, return the counter object so `.count` stays reachable.
        if instance is None:
            return self
        return functools.partial(self.__call__, instance)


@CountCalls
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)


# Every recursive call goes through the counter, so one line is printed per call.
result = fibonacci(5)
print(f"fibonacci.count = {fibonacci.count}")
1.6.3.4. 缓存装饰器
import functools
def memoize(func):
    """Simple memoization: cache *func*'s results keyed on its call arguments.

    Positional-only calls use the ``args`` tuple itself as the key, so
    ``wrapper.cache`` stays easy to inspect. Keyword arguments are supported
    too. They are appended to the key after a private marker, sorted by name,
    so ``f(a=1, b=2)`` and ``f(b=2, a=1)`` share an entry. ``f(5)`` and
    ``f(n=5)`` are still cached separately.

    All arguments must be hashable, otherwise TypeError is raised (as for any
    dict key). The cache is unbounded; prefer ``functools.lru_cache`` in real code.
    """
    cache = {}
    kwd_mark = object()  # separates positional args from keyword items in a key

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = args
        if kwargs:
            key = args + (kwd_mark,) + tuple(sorted(kwargs.items()))
        if key not in cache:
            cache[key] = func(*args, **kwargs)
        return cache[key]

    wrapper.cache = cache  # expose the cache for debugging
    wrapper.clear_cache = cache.clear
    return wrapper
@memoize
def slow_function(n):
    import time
    time.sleep(1)  # stand-in for an expensive computation
    doubled = n * 2
    return doubled


print(slow_function(5))  # waits about 1 second
print(slow_function(5))  # returns immediately (served from the cache)
# 使用标准库的 lru_cache(推荐)
from functools import lru_cache
@lru_cache(maxsize=128)
def fibonacci(n):
    # Naive recursion made linear by the cache: each n is computed only once.
    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)


print(fibonacci(100))  # fast to compute
print(fibonacci.cache_info())  # hit/miss statistics
fibonacci.cache_clear()  # drop cached values (statistics are reset too)
1.6.3.5. 类型检查装饰器
import functools
import inspect
def typecheck(func):
    """Check annotated argument and return types of *func* at call time.

    Each call binds the arguments to *func*'s signature (defaults included).
    It then compares every annotated value with ``isinstance``:

    * ``*args: T`` / ``**kwargs: T`` check each collected element, not the
      tuple/dict itself.
    * An annotation of ``None`` means "the value None" (``type(None)``), so
      ``-> None`` functions work.
    * Annotations that ``isinstance`` cannot handle are skipped instead of
      crashing every call. This covers string/postponed annotations and
      parameterised generics such as ``list[int]``.

    Raises:
        TypeError: an argument or the return value does not match its annotation.
    """
    sig = inspect.signature(func)

    def _matches(value, annotation):
        # True when value satisfies annotation, or the annotation is not
        # something isinstance() can evaluate.
        if annotation is None:
            annotation = type(None)
        try:
            return isinstance(value, annotation)
        except TypeError:
            return True

    def _name(annotation):
        # Readable name for error messages; typing constructs may lack __name__.
        if annotation is None:
            return "None"
        return getattr(annotation, "__name__", repr(annotation))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Bind the arguments (this raises TypeError for a bad call shape).
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Check each annotated parameter.
        for name, value in bound.arguments.items():
            param = sig.parameters[name]
            expected = param.annotation
            if expected is inspect.Parameter.empty:
                continue
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                items = value  # tuple of extra positional arguments
            elif param.kind is inspect.Parameter.VAR_KEYWORD:
                items = value.values()  # dict of extra keyword arguments
            else:
                items = (value,)
            for item in items:
                if not _matches(item, expected):
                    raise TypeError(
                        f"Argument '{name}' must be {_name(expected)}, "
                        f"got {type(item).__name__}"
                    )
        result = func(*args, **kwargs)
        # Check the return value.
        expected = sig.return_annotation
        if expected is not inspect.Signature.empty and not _matches(result, expected):
            raise TypeError(f"Return value must be {_name(expected)}")
        return result

    return wrapper


@typecheck
def add(a: int, b: int) -> int:
    return a + b


print(add(1, 2))  # 3
# add("1", 2)  # TypeError: Argument 'a' must be int, got str
1.6.4. 装饰类
import functools
def singleton(cls):
    """Class decorator: every call of the decorated name returns one shared instance.

    The first call constructs the instance with whatever arguments it received.
    Later calls return that same object and ignore their arguments. The
    decorated name becomes a function (with the class's metadata copied by
    functools.wraps). Use ``Name.__wrapped__`` (the original class) for
    ``isinstance`` checks.
    """
    instances = {}

    @functools.wraps(cls)
    def get_instance(*args, **kwargs):
        if cls in instances:
            return instances[cls]
        instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class Database:
    def __init__(self):
        print("Initializing database connection")


db1 = Database()  # Initializing database connection
db2 = Database()  # prints nothing: the cached instance is returned
print(db1 is db2)  # True
1.6.4.1. 添加方法到类
def add_method(cls):
    """Class decorator that attaches a ``describe()`` method and returns the class."""
    def describe(self):
        return f"{self.__class__.__name__}: {vars(self)}"

    setattr(cls, "describe", describe)
    return cls


@add_method
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age


p = Person("Alice", 30)
print(p.describe())  # Person: {'name': 'Alice', 'age': 30}
1.6.4.2. 注册模式
class Registry:
    """Key -> class registry for plugins/handlers, filled in via a decorator.

    Supported forms:
        @registry.register("json")   # explicit key
        @registry.register()         # key defaults to the class name
        @registry.register           # bare form, same as register()
    Registering the same key twice replaces the earlier entry.
    """

    def __init__(self):
        self._registry = {}

    def register(self, name=None):
        """Return a decorator that stores the decorated class under *name*.

        A falsy *name* falls back to the class's ``__name__``. When used bare
        (without parentheses), Python passes the class itself as *name*. In
        that case the class is registered immediately under its own name and
        returned unchanged.
        """
        if callable(name):
            # Bare @registry.register. Without this branch the class would be
            # replaced by the inner decorator function and never registered.
            return self._add(name.__name__, name)

        def decorator(cls):
            return self._add(name or cls.__name__, cls)

        return decorator

    def _add(self, key, cls):
        # Store cls under key and hand it back so decoration leaves it intact.
        self._registry[key] = cls
        return cls

    def get(self, name):
        """Return the class registered under *name*, or None if there is none."""
        return self._registry.get(name)

    def list_all(self):
        """Return the registered keys in registration order."""
        return list(self._registry.keys())


# Usage example
handlers = Registry()


@handlers.register("json")
class JSONHandler:
    def process(self, data):
        return f"JSON: {data}"


@handlers.register("xml")
class XMLHandler:
    def process(self, data):
        return f"XML: {data}"


print(handlers.list_all())  # ['json', 'xml']
handler = handlers.get("json")()
print(handler.process("data"))  # JSON: data
1.6.5. 方法装饰器
1.6.5.1. 静态方法和类方法
class MyClass:
    # Class-level counter shared by all instances: objects constructed so far.
    count = 0

    def __init__(self, value):
        self.value = value
        MyClass.count += 1

    @staticmethod
    def static_method():
        """Use neither instance nor class state."""
        return "I'm static"

    @classmethod
    def class_method(cls):
        """Read class-level state (the instance counter)."""
        return f"Instance count: {cls.count}"

    @classmethod
    def from_string(cls, s):
        """Alternate constructor: build an instance from a numeric string."""
        return cls(int(s))


obj = MyClass.from_string("42")
print(obj.value)  # 42
print(MyClass.class_method())  # Instance count: 1
1.6.5.2. property 装饰器
class Temperature:
    """Temperature stored in degrees Celsius, with a read-only Fahrenheit view.

    Values below absolute zero (-273.15) raise ValueError, both in the
    constructor and on later assignment to ``celsius``.
    """

    def __init__(self, celsius=0):
        # Assign through the property, not self._celsius directly, so the
        # constructor enforces the same absolute-zero check as the setter.
        self.celsius = celsius

    @property
    def celsius(self):
        """Temperature in degrees Celsius."""
        return self._celsius

    @celsius.setter
    def celsius(self, value):
        if value < -273.15:
            raise ValueError("Temperature below absolute zero!")
        self._celsius = value

    @celsius.deleter
    def celsius(self):
        print("Deleting celsius")
        del self._celsius

    @property
    def fahrenheit(self):
        """Fahrenheit equivalent of the stored Celsius value."""
        return self._celsius * 9/5 + 32


t = Temperature(25)
print(t.fahrenheit)  # 77.0
t.celsius = 30
# t.celsius = -300  # ValueError
del t.celsius  # Deleting celsius
1.6.6. 异步装饰器
import asyncio
import functools
def async_timer(func):
    """Wrap a coroutine function and print how long each awaited call took.

    Timing uses time.perf_counter(). Nothing is printed if the call raises.
    """
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        import time
        clock = time.perf_counter
        started = clock()
        value = await func(*args, **kwargs)
        print(f"{func.__name__} took {clock() - started:.4f}s")
        return value

    return wrapper
def async_retry(max_attempts=3, delay=1, *, exceptions=(Exception,)):
    """Decorator factory: retry a coroutine function when it raises.

    Args:
        max_attempts: total number of attempts, the first call included; must be >= 1.
        delay: seconds to wait (via asyncio.sleep) between consecutive attempts.
        exceptions: exception type, or tuple of types, that trigger a retry.
            Anything else propagates immediately. The default ``(Exception,)``
            retries on any ordinary exception.

    Raises:
        ValueError: at decoration time, if ``max_attempts < 1``. Without this
            check no attempt would run and there would be no exception to re-raise.

    When every attempt fails, the exception from the final attempt propagates.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except exceptions:
                    if attempt == max_attempts:
                        raise  # out of attempts: propagate the last error
                # Only reached after a retryable failure with attempts left.
                await asyncio.sleep(delay)

        return wrapper

    return decorator
@async_timer
@async_retry(max_attempts=3)
async def fetch_data(url):
    # Simulated network request: wait half a second, then return a payload.
    await asyncio.sleep(0.5)
    payload = f"Data from {url}"
    return payload
1.6.7. 装饰器执行顺序
import functools  # used by functools.wraps below; lets this snippet run on its own


def decorator_a(func):
    """Outer decorator of the ordering demo: prints around each call."""
    print("Applying A")  # runs once, when the decorator is applied
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("Before A")
        result = func(*args, **kwargs)
        print("After A")
        return result
    return wrapper


def decorator_b(func):
    """Inner decorator of the ordering demo: prints around each call."""
    print("Applying B")  # runs once, when the decorator is applied
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("Before B")
        result = func(*args, **kwargs)
        print("After B")
        return result
    return wrapper


@decorator_a  # outer
@decorator_b  # inner
def hello():
    print("Hello!")

# Output at definition time:
# Applying B   <- decorators are applied bottom-up
# Applying A

hello()
# Output at call time:
# Before A     <- wrappers run outer to inner
# Before B
# Hello!
# After B      <- and return inner to outer
# After A
1.6.8. 最佳实践
设计原则
- 始终使用 @functools.wraps:保留被装饰函数的 __name__、__doc__ 等元信息
- 保持装饰器简单:复杂逻辑抽取到辅助函数
- 处理 *args 和 **kwargs:支持任意签名
- 返回值透明:除非有意修改,否则返回原函数结果
调试技巧
# Inspect the original, undecorated function.
# `decorated_func` is a placeholder for any function wrapped by a decorator that
# uses functools.wraps (which sets __wrapped__); it is not defined here.
decorated_func.__wrapped__
# functools.wraps 自动保留 __wrapped__
import functools
def decorator(func):
    """Minimal pass-through decorator; functools.wraps also records __wrapped__."""
    @functools.wraps(func)
    def passthrough(*args, **kwargs):
        return func(*args, **kwargs)

    return passthrough


@decorator
def original():
    pass


print(original.__wrapped__)  # <function original at ...>