4.2. 性能陷阱
了解这些性能陷阱,写出更高效的 Python 代码。
4.2.1. 循环优化
4.2.1.1. 避免在循环中重复计算
# ❌ 每次迭代都计算长度
items = list(range(10000))
for i in range(len(items)): # len() 在每次迭代时调用
process(items[i])
# ✅ 缓存长度
length = len(items)
for i in range(length):
process(items[i])
# ✅ 更好:直接迭代
for item in items:
process(item)
# ❌ 循环中查找方法
for item in items:
result.append(item) # 每次查找 append 方法
# ✅ 缓存方法引用
append = result.append
for item in items:
append(item)
4.2.1.2. 列表推导 vs 循环
import timeit
# ❌ 传统循环
def with_loop():
result = []
for i in range(1000):
result.append(i ** 2)
return result
# ✅ 列表推导(更快)
def with_comprehension():
return [i ** 2 for i in range(1000)]
# 性能对比
print(timeit.timeit(with_loop, number=10000))
print(timeit.timeit(with_comprehension, number=10000))
# 列表推导通常快 20-30%
4.2.1.3. 使用生成器节省内存
# ❌ 创建完整列表
def get_squares_list(n):
return [i ** 2 for i in range(n)]
# 10 million items = ~80MB 内存
squares = get_squares_list(10_000_000)
# ✅ 使用生成器
def get_squares_gen(n):
return (i ** 2 for i in range(n))
# 几乎不占内存
squares = get_squares_gen(10_000_000)
4.2.2. 数据结构选择
4.2.2.1. 列表 vs 集合查找
import timeit
items_list = list(range(10000))
items_set = set(range(10000))
# ❌ 列表查找 O(n)
def search_list():
return 9999 in items_list
# ✅ 集合查找 O(1)
def search_set():
return 9999 in items_set
print(timeit.timeit(search_list, number=10000)) # 慢
print(timeit.timeit(search_set, number=10000)) # 快 100+ 倍
4.2.2.2. 合适的数据结构
from collections import deque, Counter, defaultdict
import heapq
# 需要频繁在两端插入/删除 → deque
queue = deque()
queue.append(1) # 右端添加 O(1)
queue.appendleft(0) # 左端添加 O(1)
queue.pop() # 右端删除 O(1)
queue.popleft() # 左端删除 O(1)
# 需要计数 → Counter
words = ['apple', 'banana', 'apple', 'cherry', 'banana', 'apple']
counts = Counter(words)
print(counts.most_common(2)) # [('apple', 3), ('banana', 2)]
# 需要默认值 → defaultdict
groups = defaultdict(list)
for item in items:
groups[item.category].append(item)
# 需要保持排序取最小/最大 → heapq
numbers = [3, 1, 4, 1, 5, 9, 2, 6]
heapq.heapify(numbers)
print(heapq.heappop(numbers)) # 1 (最小值)
4.2.2.3. 字符串拼接
import timeit
# ❌ 用 + 拼接(O(n²))
def concat_plus():
result = ""
for i in range(1000):
result += str(i)
return result
# ✅ 用 join(O(n))
def concat_join():
return "".join(str(i) for i in range(1000))
# ✅ 用 f-string(适合少量)
def concat_fstring():
a, b, c = "hello", "world", "!"
return f"{a} {b}{c}"
print(timeit.timeit(concat_plus, number=1000)) # 慢
print(timeit.timeit(concat_join, number=1000)) # 快很多
4.2.3. 函数调用开销
4.2.3.1. 避免不必要的函数调用
import timeit
# ❌ 频繁调用小函数
def is_even(n):
return n % 2 == 0
def filter_evens_func(numbers):
return [n for n in numbers if is_even(n)]
# ✅ 内联简单操作
def filter_evens_inline(numbers):
return [n for n in numbers if n % 2 == 0]
numbers = list(range(10000))
print(timeit.timeit(lambda: filter_evens_func(numbers), number=1000))
print(timeit.timeit(lambda: filter_evens_inline(numbers), number=1000))
4.2.3.2. 使用内置函数
import timeit
numbers = list(range(10000))
# ❌ 自己实现
def my_sum(numbers):
total = 0
for n in numbers:
total += n
return total
# ✅ 使用内置函数(C 实现,更快)
# sum(), max(), min(), len(), sorted(), ...
print(timeit.timeit(lambda: my_sum(numbers), number=1000))
print(timeit.timeit(lambda: sum(numbers), number=1000))
# 内置 sum() 快 5-10 倍
4.2.4. I/O 优化
4.2.4.1. 批量读写
# ❌ 逐行写入
with open('output.txt', 'w') as f:
for line in lines:
f.write(line + '\n') # 每行一次 I/O
# ✅ 批量写入
with open('output.txt', 'w') as f:
f.writelines(line + '\n' for line in lines)
# ✅ 或使用缓冲
with open('output.txt', 'w', buffering=1024*1024) as f: # 1MB 缓冲
for line in lines:
f.write(line + '\n')
4.2.4.2. 使用 pickle 的正确方式
import pickle
# ❌ 使用默认协议
with open('data.pkl', 'wb') as f:
pickle.dump(data, f)
# ✅ 使用最高协议(更快、更小)
with open('data.pkl', 'wb') as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
4.2.5. 数值计算
4.2.5.1. 使用 NumPy
import numpy as np
import timeit
# ❌ 纯 Python
def python_sum_squares():
return sum(x ** 2 for x in range(1000000))
# ✅ NumPy(快 10-100 倍)
arr = np.arange(1000000)
def numpy_sum_squares():
return np.sum(arr ** 2)
print(timeit.timeit(python_sum_squares, number=10))
print(timeit.timeit(numpy_sum_squares, number=10))
4.2.5.2. 避免不必要的类型转换
import numpy as np
# ❌ 在 NumPy 和 Python 之间转换
arr = np.array([1, 2, 3, 4, 5])
result = []
for x in arr: # 每次迭代都进行类型转换
result.append(x * 2)
# ✅ 保持在 NumPy 中操作
result = arr * 2
4.2.6. 缓存
4.2.6.1. 使用 functools.lru_cache
from functools import lru_cache
import timeit
# ❌ 重复计算
def fibonacci(n):
if n < 2:
return n
return fibonacci(n-1) + fibonacci(n-2)
# ✅ 使用缓存
@lru_cache(maxsize=128)
def fibonacci_cached(n):
if n < 2:
return n
return fibonacci_cached(n-1) + fibonacci_cached(n-2)
# fibonacci(35) 需要几秒
# fibonacci_cached(35) 几乎瞬间
4.2.6.2. 自定义缓存
from functools import wraps
from time import time
def timed_cache(seconds):
"""带过期时间的缓存"""
def decorator(func):
cache = {}
@wraps(func)
def wrapper(*args):
now = time()
if args in cache:
result, timestamp = cache[args]
if now - timestamp < seconds:
return result
result = func(*args)
cache[args] = (result, now)
return result
wrapper.cache_clear = cache.clear
return wrapper
return decorator
@timed_cache(seconds=60)
def expensive_operation(x):
# 复杂计算
return x ** 2
4.2.7. 性能分析工具
4.2.7.1. 使用 cProfile
import cProfile
import pstats
# 分析函数
cProfile.run('my_function()', 'output.prof')
# 读取并排序结果
stats = pstats.Stats('output.prof')
stats.sort_stats('cumulative')
stats.print_stats(10) # 打印前 10 个
# 或使用装饰器
def profile(func):
def wrapper(*args, **kwargs):
profiler = cProfile.Profile()
try:
return profiler.runcall(func, *args, **kwargs)
finally:
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)
return wrapper
4.2.7.2. 使用 line_profiler
# pip install line_profiler
@profile
def slow_function():
result = []
for i in range(10000):
result.append(i ** 2)
return sum(result)
# 运行: kernprof -l -v script.py
4.2.8. 最佳实践
优化原则
先测量,再优化:使用 profiler 找出真正的瓶颈
选择正确的数据结构:影响大于代码微优化
使用内置函数和库:它们是 C 实现的
减少函数调用:在关键路径上
使用缓存:避免重复计算
批量操作:减少 I/O 次数
常见错误
# ❌ 过早优化
# 在没有测量的情况下进行复杂优化
# ❌ 优化错误的地方
# 花大量时间优化只占 1% 时间的代码
# ❌ 牺牲可读性
# 除非必要,代码清晰比微小性能提升更重要