4.2. 性能陷阱

了解这些性能陷阱,写出更高效的 Python 代码。

4.2.1. 循环优化

4.2.1.1. 避免在循环中重复计算

# ❌ 每次迭代都计算长度
items = list(range(10000))
for i in range(len(items)):  # len() 在每次迭代时调用
    process(items[i])

# ✅ 缓存长度
length = len(items)
for i in range(length):
    process(items[i])

# ✅ 更好:直接迭代
for item in items:
    process(item)

# ❌ 循环中查找方法
for item in items:
    result.append(item)  # 每次查找 append 方法

# ✅ 缓存方法引用
append = result.append
for item in items:
    append(item)

4.2.1.2. 列表推导 vs 循环

import timeit

# ❌ 传统循环
def with_loop():
    result = []
    for i in range(1000):
        result.append(i ** 2)
    return result

# ✅ 列表推导(更快)
def with_comprehension():
    return [i ** 2 for i in range(1000)]

# 性能对比
print(timeit.timeit(with_loop, number=10000))
print(timeit.timeit(with_comprehension, number=10000))
# 列表推导通常快 20-30%

4.2.1.3. 使用生成器节省内存

# ❌ 创建完整列表
def get_squares_list(n):
    return [i ** 2 for i in range(n)]

# 10 million items = ~80MB 内存
squares = get_squares_list(10_000_000)

# ✅ 使用生成器
def get_squares_gen(n):
    return (i ** 2 for i in range(n))

# 几乎不占内存
squares = get_squares_gen(10_000_000)

4.2.2. 数据结构选择

4.2.2.1. 列表 vs 集合查找

import timeit

items_list = list(range(10000))
items_set = set(range(10000))

# ❌ 列表查找 O(n)
def search_list():
    return 9999 in items_list

# ✅ 集合查找 O(1)
def search_set():
    return 9999 in items_set

print(timeit.timeit(search_list, number=10000))  # 慢
print(timeit.timeit(search_set, number=10000))    # 快 100+ 倍

4.2.2.2. 合适的数据结构

from collections import deque, Counter, defaultdict
import heapq

# 需要频繁在两端插入/删除 → deque
queue = deque()
queue.append(1)      # 右端添加 O(1)
queue.appendleft(0)  # 左端添加 O(1)
queue.pop()          # 右端删除 O(1)
queue.popleft()      # 左端删除 O(1)

# 需要计数 → Counter
words = ['apple', 'banana', 'apple', 'cherry', 'banana', 'apple']
counts = Counter(words)
print(counts.most_common(2))  # [('apple', 3), ('banana', 2)]

# 需要默认值 → defaultdict
groups = defaultdict(list)
for item in items:
    groups[item.category].append(item)

# 需要保持排序取最小/最大 → heapq
numbers = [3, 1, 4, 1, 5, 9, 2, 6]
heapq.heapify(numbers)
print(heapq.heappop(numbers))  # 1 (最小值)

4.2.2.3. 字符串拼接

import timeit

# ❌ 用 + 拼接(O(n²))
def concat_plus():
    result = ""
    for i in range(1000):
        result += str(i)
    return result

# ✅ 用 join(O(n))
def concat_join():
    return "".join(str(i) for i in range(1000))

# ✅ 用 f-string(适合少量)
def concat_fstring():
    a, b, c = "hello", "world", "!"
    return f"{a} {b}{c}"

print(timeit.timeit(concat_plus, number=1000))  # 慢
print(timeit.timeit(concat_join, number=1000))  # 快很多

4.2.3. 函数调用开销

4.2.3.1. 避免不必要的函数调用

import timeit

# ❌ 频繁调用小函数
def is_even(n):
    return n % 2 == 0

def filter_evens_func(numbers):
    return [n for n in numbers if is_even(n)]

# ✅ 内联简单操作
def filter_evens_inline(numbers):
    return [n for n in numbers if n % 2 == 0]

numbers = list(range(10000))
print(timeit.timeit(lambda: filter_evens_func(numbers), number=1000))
print(timeit.timeit(lambda: filter_evens_inline(numbers), number=1000))

4.2.3.2. 使用内置函数

import timeit

numbers = list(range(10000))

# ❌ 自己实现
def my_sum(numbers):
    total = 0
    for n in numbers:
        total += n
    return total

# ✅ 使用内置函数(C 实现,更快)
# sum(), max(), min(), len(), sorted(), ...

print(timeit.timeit(lambda: my_sum(numbers), number=1000))
print(timeit.timeit(lambda: sum(numbers), number=1000))
# 内置 sum() 快 5-10 倍

4.2.4. I/O 优化

4.2.4.1. 批量读写

# ❌ 逐行写入
with open('output.txt', 'w') as f:
    for line in lines:
        f.write(line + '\n')  # 每行一次 I/O

# ✅ 批量写入
with open('output.txt', 'w') as f:
    f.writelines(line + '\n' for line in lines)

# ✅ 或使用缓冲
with open('output.txt', 'w', buffering=1024*1024) as f:  # 1MB 缓冲
    for line in lines:
        f.write(line + '\n')

4.2.4.2. 使用 pickle 的正确方式

import pickle

# ❌ 使用默认协议
with open('data.pkl', 'wb') as f:
    pickle.dump(data, f)

# ✅ 使用最高协议(更快、更小)
with open('data.pkl', 'wb') as f:
    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

4.2.5. 数值计算

4.2.5.1. 使用 NumPy

import numpy as np
import timeit

# ❌ 纯 Python
def python_sum_squares():
    return sum(x ** 2 for x in range(1000000))

# ✅ NumPy(快 10-100 倍)
arr = np.arange(1000000)
def numpy_sum_squares():
    return np.sum(arr ** 2)

print(timeit.timeit(python_sum_squares, number=10))
print(timeit.timeit(numpy_sum_squares, number=10))

4.2.5.2. 避免不必要的类型转换

import numpy as np

# ❌ 在 NumPy 和 Python 之间转换
arr = np.array([1, 2, 3, 4, 5])
result = []
for x in arr:  # 每次迭代都进行类型转换
    result.append(x * 2)

# ✅ 保持在 NumPy 中操作
result = arr * 2

4.2.6. 缓存

4.2.6.1. 使用 functools.lru_cache

from functools import lru_cache
import timeit

# ❌ 重复计算
def fibonacci(n):
    if n < 2:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

# ✅ 使用缓存
@lru_cache(maxsize=128)
def fibonacci_cached(n):
    if n < 2:
        return n
    return fibonacci_cached(n-1) + fibonacci_cached(n-2)

# fibonacci(35) 需要几秒
# fibonacci_cached(35) 几乎瞬间

4.2.6.2. 自定义缓存

from functools import wraps
from time import time

def timed_cache(seconds):
    """带过期时间的缓存"""
    def decorator(func):
        cache = {}
        
        @wraps(func)
        def wrapper(*args):
            now = time()
            if args in cache:
                result, timestamp = cache[args]
                if now - timestamp < seconds:
                    return result
            
            result = func(*args)
            cache[args] = (result, now)
            return result
        
        wrapper.cache_clear = cache.clear
        return wrapper
    return decorator

@timed_cache(seconds=60)
def expensive_operation(x):
    # 复杂计算
    return x ** 2

4.2.7. 性能分析工具

4.2.7.1. 使用 cProfile

import cProfile
import pstats

# 分析函数
cProfile.run('my_function()', 'output.prof')

# 读取并排序结果
stats = pstats.Stats('output.prof')
stats.sort_stats('cumulative')
stats.print_stats(10)  # 打印前 10 个

# 或使用装饰器
def profile(func):
    def wrapper(*args, **kwargs):
        profiler = cProfile.Profile()
        try:
            return profiler.runcall(func, *args, **kwargs)
        finally:
            stats = pstats.Stats(profiler)
            stats.sort_stats('cumulative')
            stats.print_stats(20)
    return wrapper

4.2.7.2. 使用 line_profiler

# pip install line_profiler

@profile
def slow_function():
    result = []
    for i in range(10000):
        result.append(i ** 2)
    return sum(result)

# 运行: kernprof -l -v script.py

4.2.8. 最佳实践

优化原则
  1. 先测量,再优化:使用 profiler 找出真正的瓶颈

  2. 选择正确的数据结构:影响大于代码微优化

  3. 使用内置函数和库:它们是 C 实现的

  4. 减少函数调用:在关键路径上

  5. 使用缓存:避免重复计算

  6. 批量操作:减少 I/O 次数

常见错误
# ❌ 过早优化
# 在没有测量的情况下进行复杂优化

# ❌ 优化错误的地方
# 花大量时间优化只占 1% 时间的代码

# ❌ 牺牲可读性
# 除非必要,代码清晰比微小性能提升更重要