1.6. 装饰器深入

装饰器是 Python 中最强大的特性之一,理解其原理和常见模式对于写出优雅代码至关重要。

1.6.1. 装饰器本质

装饰器本质上是一个高阶函数:它接受一个函数(或类)作为参数,并返回一个用来替代它的可调用对象(通常是一个新函数)。

# 装饰器语法糖
@decorator
def func():
    pass

# 等价于
def func():
    pass
func = decorator(func)

1.6.2. 保持函数元信息

import functools

# ❌ 不使用 @functools.wraps
def bad_decorator(func):
    """Return a pass-through wrapper around ``func`` without copying metadata.

    ``functools.wraps`` is deliberately not applied here, so the returned
    callable keeps its own name (``wrapper``) and has no docstring.
    """
    def wrapper(*args, **kwargs):
        outcome = func(*args, **kwargs)
        return outcome

    return wrapper

@bad_decorator
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}!"

print(greet.__name__)  # wrapper (错误!)
print(greet.__doc__)   # None (文档丢失!)

# ✅ 使用 @functools.wraps
def good_decorator(func):
    """Return a pass-through wrapper that carries ``func``'s metadata.

    The wrapper forwards every positional and keyword argument unchanged and
    returns whatever ``func`` returns.
    """
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    # Applying functools.wraps explicitly is the same as using it as a decorator.
    return functools.wraps(func)(wrapper)

@good_decorator
def greet(name):
    """Return a greeting."""
    return f"Hello, {name}!"

print(greet.__name__)  # greet (正确!)
print(greet.__doc__)   # Return a greeting.

1.6.3. 常见装饰器模式

1.6.3.1. 带参数的装饰器

import functools

def repeat(times):
    """Build a decorator that calls the wrapped function ``times`` times in a row.

    The wrapper returns the result of the last call, or ``None`` when
    ``range(times)`` is empty and the function is never invoked.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            outcomes = [func(*args, **kwargs) for _ in range(times)]
            return outcomes[-1] if outcomes else None

        return functools.wraps(func)(wrapper)

    return decorator

@repeat(times=3)
def say_hello():
    print("Hello!")

say_hello()
# Hello!
# Hello!
# Hello!

1.6.3.2. 可选参数的装饰器

import functools

def log(func=None, *, level='INFO'):
    """Print a ``[level] Calling <name>`` line before each call of the function.

    Both decorator forms are supported:

        @log
        def func(): ...

        @log(level='DEBUG')
        def func(): ...

    Args:
        func: The function being decorated when used as bare ``@log``;
            ``None`` when the decorator is invoked with keyword options.
        level: Label printed inside the square brackets.

    Returns:
        The wrapped function (bare form) or a decorator (called form).

    Raises:
        TypeError: If ``func`` is given but is not callable, which is what
            happens with the mistaken positional form ``@log('DEBUG')``.
    """
    if func is not None and not callable(func):
        # Fail at decoration time instead of wrapping a non-callable and
        # blowing up later when the "decorated function" is first called.
        raise TypeError(
            "log() accepts only the decorated function positionally; "
            f"pass options by keyword, e.g. @log(level={func!r})"
        )

    def decorator(target):
        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            # Not every callable has __name__ (e.g. functools.partial).
            name = getattr(target, "__name__", repr(target))
            print(f"[{level}] Calling {name}")
            return target(*args, **kwargs)
        return wrapper

    # func is None  -> used as @log(level=...): hand back the decorator.
    # func is given -> used as bare @log: decorate immediately.
    return decorator if func is None else decorator(func)

@log
def func1():
    pass

@log(level='DEBUG')
def func2():
    pass

func1()  # [INFO] Calling func1
func2()  # [DEBUG] Calling func2

1.6.3.3. 带状态的装饰器(类实现)

import functools

class CountCalls:
    """Callable wrapper that counts and reports each call to a function."""

    def __init__(self, func):
        # Copy the wrapped function's metadata onto this instance first, then
        # set our own attributes so they take precedence.
        functools.update_wrapper(self, func)
        self.count = 0
        self.func = func

    def __call__(self, *args, **kwargs):
        """Increment the counter, print a progress line and delegate to the function."""
        self.count = self.count + 1
        target = self.func
        print(f"Call {self.count} of {target.__name__}")
        return target(*args, **kwargs)

@CountCalls
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

# 这会显示递归调用次数
result = fibonacci(5)
print(f"fibonacci.count = {fibonacci.count}")

1.6.3.4. 缓存装饰器

import functools

def memoize(func):
    """Cache results of ``func`` keyed by its call arguments.

    Positional-only calls are keyed by the ``args`` tuple itself. Calls that
    use keyword arguments get a composite key, so they are cached too instead
    of being rejected by the wrapper.

    The returned wrapper exposes two debugging helpers:
        ``wrapper.cache``        the underlying dict.
        ``wrapper.clear_cache``  drops every cached entry.

    Raises:
        TypeError: If any argument is unhashable, because it cannot be used
            as part of a dict key.
    """
    cache = {}
    # Unique separator between the positional and keyword parts of a key, so
    # that f(1, a=2) can never collide with some purely positional call.
    kwd_mark = (object(),)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        key = args
        if kwargs:
            # Sort so that f(a=1, b=2) and f(b=2, a=1) share one entry.
            key = args + kwd_mark + tuple(sorted(kwargs.items()))
        try:
            return cache[key]
        except KeyError:
            pass
        # Computed outside the except block so that an exception raised by
        # func is not chained to the internal KeyError.
        result = cache[key] = func(*args, **kwargs)
        return result

    wrapper.cache = cache  # exposed for debugging
    wrapper.clear_cache = cache.clear
    return wrapper

@memoize
def slow_function(n):
    import time
    time.sleep(1)
    return n * 2

print(slow_function(5))  # 等待 1 秒
print(slow_function(5))  # 立即返回(缓存)

# 使用标准库的 lru_cache(推荐)
from functools import lru_cache

@lru_cache(maxsize=128)
def fibonacci(n):
    """Return the n-th Fibonacci number, or ``n`` itself when ``n < 2``.

    Results are memoized by ``lru_cache``, so each distinct ``n`` is computed
    only while it remains in the 128-entry cache.
    """
    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)

print(fibonacci(100))  # 快速计算
print(fibonacci.cache_info())  # 查看缓存统计
fibonacci.cache_clear()  # 清除缓存

1.6.3.5. 类型检查装饰器

import functools
import inspect

def _runtime_type(annotation):
    """Return a value ``isinstance`` accepts for ``annotation``, or ``None`` to skip it."""
    if annotation is None:
        # ``-> None`` / ``x: None`` conventionally mean ``type(None)``.
        return type(None)
    try:
        # Probe once: string annotations, parameterized generics and other
        # typing constructs make isinstance raise TypeError.
        isinstance(None, annotation)
    except TypeError:
        return None
    return annotation


def _type_label(expected):
    """Return a readable name for a type used in an error message."""
    return getattr(expected, "__name__", repr(expected))


def typecheck(func):
    """Check ``func``'s arguments and return value against its annotations at call time.

    Only annotations that ``isinstance`` can evaluate are enforced (classes,
    tuples of classes, ``None``); anything else, such as string annotations
    or parameterized generics, is skipped. For ``*args`` and ``**kwargs``
    parameters the annotation is applied to each collected value.
    Default values are checked as well.

    Raises (from the wrapper):
        TypeError: If an argument or the return value has the wrong type, or
            if the call does not match the function's signature.
    """
    sig = inspect.signature(func)

    # Resolve the checkable annotations once, at decoration time.
    expected_by_name = {}
    for name, param in sig.parameters.items():
        if param.annotation is inspect.Parameter.empty:
            continue
        expected = _runtime_type(param.annotation)
        if expected is not None:
            expected_by_name[name] = expected

    if sig.return_annotation is inspect.Signature.empty:
        return_type = None
    else:
        return_type = _runtime_type(sig.return_annotation)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Map the call's arguments onto parameter names, including defaults.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()

        for name, value in bound.arguments.items():
            expected = expected_by_name.get(name)
            if expected is None:
                continue
            kind = sig.parameters[name].kind
            if kind is inspect.Parameter.VAR_POSITIONAL:
                candidates = value           # tuple of extra positionals
            elif kind is inspect.Parameter.VAR_KEYWORD:
                candidates = value.values()  # dict of extra keywords
            else:
                candidates = (value,)
            for candidate in candidates:
                if not isinstance(candidate, expected):
                    raise TypeError(
                        f"Argument '{name}' must be {_type_label(expected)}, "
                        f"got {type(candidate).__name__}"
                    )

        result = func(*args, **kwargs)

        if return_type is not None and not isinstance(result, return_type):
            raise TypeError(
                f"Return value must be {_type_label(return_type)}"
            )

        return result

    return wrapper

@typecheck
def add(a: int, b: int) -> int:
    return a + b

print(add(1, 2))  # 3
# add("1", 2)  # TypeError: Argument 'a' must be int, got str

1.6.4. 装饰类

import functools

def singleton(cls):
    """Replace ``cls`` with a factory function that always returns one shared instance.

    The first call constructs the instance with the supplied arguments; later
    calls ignore their arguments and return that same object. The decorated
    name is bound to the factory function, which carries the class's metadata.
    """
    created = {}

    def wrapper(*args, **kwargs):
        if cls in created:
            return created[cls]
        created[cls] = cls(*args, **kwargs)
        return created[cls]

    return functools.wraps(cls)(wrapper)

@singleton
class Database:
    def __init__(self):
        print("Initializing database connection")

db1 = Database()  # Initializing database connection
db2 = Database()  # 不会打印(返回同一实例)
print(db1 is db2)  # True

1.6.4.1. 添加方法到类

def add_method(cls):
    """Attach a ``describe`` method to ``cls`` and return the same class.

    ``describe`` returns ``"<ClassName>: <instance attribute dict>"``.
    """
    def describe(self):
        owner = self.__class__.__name__
        return f"{owner}: {vars(self)}"

    setattr(cls, "describe", describe)
    return cls

@add_method
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age

p = Person("Alice", 30)
print(p.describe())  # Person: {'name': 'Alice', 'age': 30}

1.6.4.2. 注册模式

class Registry:
    """Name-to-class lookup table filled through a decorator."""

    def __init__(self):
        # Maps registration key -> registered class, in insertion order.
        self._registry = {}

    def register(self, name=None):
        """Return a class decorator that stores the class under ``name``.

        When ``name`` is falsy the class's own ``__name__`` is used as the key.
        The decorated class is returned unchanged.
        """
        def decorator(cls):
            key = name if name else cls.__name__
            self._registry[key] = cls
            return cls

        return decorator

    def get(self, name):
        """Return the class registered under ``name``, or ``None`` if absent."""
        return self._registry.get(name)

    def list_all(self):
        """Return the registered keys as a new list."""
        return list(self._registry)

# 使用示例
handlers = Registry()

@handlers.register("json")
class JSONHandler:
    def process(self, data):
        return f"JSON: {data}"

@handlers.register("xml")
class XMLHandler:
    def process(self, data):
        return f"XML: {data}"

print(handlers.list_all())  # ['json', 'xml']
handler = handlers.get("json")()
print(handler.process("data"))  # JSON: data

1.6.5. 方法装饰器

1.6.5.1. 静态方法和类方法

class MyClass:
    """Demonstrates a static method, class methods and a classmethod factory."""

    # Number of instances created so far; incremented in __init__.
    count = 0

    def __init__(self, value):
        MyClass.count = MyClass.count + 1
        self.value = value

    @staticmethod
    def static_method():
        """Return a fixed string; touches neither instance nor class state."""
        return "I'm static"

    @classmethod
    def class_method(cls):
        """Return a message reporting the class-level instance counter."""
        return f"Instance count: {cls.count}"

    @classmethod
    def from_string(cls, s):
        """Create an instance from the integer parsed out of ``s``."""
        return cls(int(s))

obj = MyClass.from_string("42")
print(obj.value)  # 42
print(MyClass.class_method())  # Instance count: 1

1.6.5.2. property 装饰器

class Temperature:
    """Temperature stored in degrees Celsius with a read-only Fahrenheit view.

    Raises:
        ValueError: From the constructor or the ``celsius`` setter when the
            value is below absolute zero.
    """

    # Lowest value accepted by the ``celsius`` setter.
    ABSOLUTE_ZERO_CELSIUS = -273.15

    def __init__(self, celsius=0):
        # Assign through the property so the constructor gets the same
        # validation as later assignments (Temperature(-300) must fail).
        self.celsius = celsius

    @property
    def celsius(self):
        """Current temperature in degrees Celsius."""
        return self._celsius

    @celsius.setter
    def celsius(self, value):
        if value < self.ABSOLUTE_ZERO_CELSIUS:
            raise ValueError("Temperature below absolute zero!")
        self._celsius = value

    @celsius.deleter
    def celsius(self):
        print("Deleting celsius")
        del self._celsius

    @property
    def fahrenheit(self):
        """Current temperature converted to degrees Fahrenheit."""
        return self._celsius * 9/5 + 32

t = Temperature(25)
print(t.fahrenheit)  # 77.0
t.celsius = 30
# t.celsius = -300  # ValueError
del t.celsius  # Deleting celsius

1.6.6. 异步装饰器

import asyncio
import functools

def async_timer(func):
    """Wrap a coroutine function so each successful await prints its duration.

    The wrapper awaits ``func``, prints ``"<name> took <seconds>s"`` with four
    decimal places, and returns the awaited result. Nothing is printed when
    ``func`` raises.
    """
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        from time import perf_counter

        began = perf_counter()
        outcome = await func(*args, **kwargs)
        duration = perf_counter() - began
        print(f"{func.__name__} took {duration:.4f}s")
        return outcome

    return wrapper

def async_retry(max_attempts=3, delay=1):
    """Build a decorator that retries a coroutine function when it raises.

    Args:
        max_attempts: Total number of attempts, including the first one.
            Must be at least 1.
        delay: Seconds to sleep (via ``asyncio.sleep``) between attempts.

    Returns:
        A decorator whose wrapper returns the first successful result, or
        re-raises the exception from the final attempt.

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Without this check
            the wrapper would never call the function and would end up
            executing ``raise None``.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception:
                    # Only Exception subclasses are retried; other
                    # BaseException subclasses propagate at once.
                    if attempt == max_attempts:
                        # Out of attempts: surface the last failure.
                        raise
                    await asyncio.sleep(delay)
        return wrapper
    return decorator

@async_timer
@async_retry(max_attempts=3)
async def fetch_data(url):
    # 模拟网络请求
    await asyncio.sleep(0.5)
    return f"Data from {url}"

1.6.7. 装饰器执行顺序

def decorator_a(func):
    """Trace decorator "A": prints when applied and around every call.

    Prints ``Applying A`` once at decoration time; the wrapper prints
    ``Before A`` / ``After A`` around the call and returns the call's result.
    """
    print("Applying A")

    def wrapper(*args, **kwargs):
        print("Before A")
        outcome = func(*args, **kwargs)
        print("After A")
        return outcome

    return functools.wraps(func)(wrapper)

def decorator_b(func):
    """Trace decorator "B": prints when applied and around every call.

    Prints ``Applying B`` once at decoration time; the wrapper prints
    ``Before B`` / ``After B`` around the call and returns the call's result.
    """
    print("Applying B")

    def wrapper(*args, **kwargs):
        print("Before B")
        outcome = func(*args, **kwargs)
        print("After B")
        return outcome

    return functools.wraps(func)(wrapper)

@decorator_a  # 外层
@decorator_b  # 内层
def hello():
    print("Hello!")

# 输出(定义时):
# Applying B  <- 从下到上应用
# Applying A

hello()
# 输出(调用时):
# Before A  <- 从外到内执行
# Before B
# Hello!
# After B  <- 从内到外返回
# After A

1.6.8. 最佳实践

设计原则
  1. 始终使用 @functools.wraps

  2. 保持装饰器简单:复杂逻辑抽取到辅助函数

  3. 处理 *args 和 **kwargs:支持任意签名

  4. 返回值透明:除非有意修改,否则返回原函数结果

调试技巧
# 查看原始函数
decorated_func.__wrapped__

# functools.wraps 自动保留 __wrapped__
import functools

def decorator(func):
    """Return a pass-through wrapper that keeps ``func``'s metadata.

    ``functools.wraps`` also stores the original function on the wrapper as
    ``__wrapped__``, which is what makes it reachable for debugging.
    """
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return functools.wraps(func)(wrapper)

@decorator
def original():
    pass

print(original.__wrapped__)  # <function original at ...>