强曰为道

与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

06 - 函数

第 06 章:函数

函数是 Python 的一等公民,掌握参数、返回值、装饰器、闭包和作用域。


6.1 函数基础

6.1.1 定义函数

def greet(name: str) -> str:
    """生成问候语。

    Args:
        name: 用户姓名。

    Returns:
        问候字符串。
    """
    return f"Hello, {name}!"

print(greet("Alice"))  # Hello, Alice!

6.1.2 函数是对象

# 函数可以赋值给变量
say_hello = greet
print(say_hello("Bob"))  # Hello, Bob!

# 函数可以作为参数传递
def apply(func, value):
    return func(value)

print(apply(greet, "Charlie"))  # Hello, Charlie!

# 函数可以存储在容器中
operations = {
    "greet": greet,
    "upper": str.upper,
}
print(operations["greet"]("Dave"))

6.2 参数

6.2.1 参数类型

def func(a, b, *args, c, d, **kwargs):
    pass
# a, b: 位置参数(positional)
# *args: 可变位置参数
# c, d: 仅关键字参数(keyword-only)
# **kwargs: 可变关键字参数

6.2.2 位置参数与关键字参数

def create_user(name: str, age: int, city: str = "北京") -> dict:
    return {"name": name, "age": age, "city": city}

# 位置调用
create_user("Alice", 30, "上海")

# 关键字调用
create_user(name="Alice", age=30, city="上海")

# 混合调用(关键字参数后不能再有位置参数)
create_user("Alice", 30, city="上海")

6.2.3 默认参数

# ✅ 正确:不可变对象作为默认值
def append_to(element: int, target: list[int] | None = None) -> list[int]:
    if target is None:
        target = []
    target.append(element)
    return target

# ❌ 错误:可变对象作为默认值
def bad_append(element: int, target: list[int] = []) -> list[int]:
    target.append(element)
    return target

print(bad_append(1))  # [1]
print(bad_append(2))  # [1, 2] ← 默认列表被共享!

🔴 注意:默认参数在函数定义时求值,而非调用时。永远不要使用可变对象作为默认值。

6.2.4 *args 和 **kwargs

def summary(*args: int, **kwargs: str) -> None:
    print(f"位置参数: {args}")
    print(f"关键字参数: {kwargs}")

summary(1, 2, 3, name="Alice", city="北京")
# 位置参数: (1, 2, 3)
# 关键字参数: {'name': 'Alice', 'city': '北京'}

6.2.5 参数解包

def add(a: int, b: int, c: int) -> int:
    return a + b + c

# 列表解包
nums = [1, 2, 3]
print(add(*nums))  # 6

# 字典解包
params = {"a": 1, "b": 2, "c": 3}
print(add(**params))  # 6

6.2.6 仅限位置参数(Python 3.8+)

def func(a, b, /, c, *, d):
    pass

# a, b: 仅限位置参数(只能用位置传递)
# c: 位置或关键字
# d: 仅限关键字参数

func(1, 2, 3, d=4)      # ✅
func(1, 2, c=3, d=4)    # ✅
# func(a=1, b=2, c=3, d=4)  # ❌ TypeError

6.3 返回值

6.3.1 单个返回值

def square(x: int) -> int:
    return x ** 2

6.3.2 多个返回值(元组解包)

def min_max(numbers: list[int]) -> tuple[int, int]:
    return min(numbers), max(numbers)

lo, hi = min_max([3, 1, 4, 1, 5, 9])
print(f"最小: {lo}, 最大: {hi}")  # 最小: 1, 最大: 9

6.3.3 提前返回

def validate_age(age: int) -> str:
    if age < 0:
        return "年龄不能为负数"
    if age > 150:
        return "年龄不合理"
    if age < 18:
        return "未成年"
    return "成年"

6.3.4 无显式返回

def log_message(message: str) -> None:
    print(f"[LOG] {message}")
    # 隐式返回 None

result = log_message("test")
print(result)  # None

6.4 Lambda 表达式

# 语法:lambda arguments: expression
square = lambda x: x ** 2
print(square(5))  # 25

# 常见用法:作为排序的 key
students = [("Alice", 85), ("Bob", 92), ("Charlie", 78)]
students.sort(key=lambda s: s[1], reverse=True)

# 与 map/filter 配合
numbers = [1, 2, 3, 4, 5]
squared = list(map(lambda x: x**2, numbers))
evens = list(filter(lambda x: x % 2 == 0, numbers))

# 推荐:使用列表推导替代 map/filter
squared = [x**2 for x in numbers]
evens = [x for x in numbers if x % 2 == 0]

6.5 作用域(LEGB 规则)

# LEGB: Local → Enclosing → Global → Built-in

x = "全局"

def outer():
    x = "外层"
    
    def inner():
        x = "内层"
        print(x)  # 内层
    
    inner()
    print(x)  # 外层

outer()
print(x)  # 全局

6.5.1 global 和 nonlocal

count = 0

def increment():
    global count  # 声明使用全局变量
    count += 1

increment()
print(count)  # 1

def outer():
    value = 0
    
    def inner():
        nonlocal value  # 声明使用外层变量
        value += 1
    
    inner()
    print(value)  # 1

outer()

6.6 闭包(Closure)

def make_multiplier(factor: int):
    """返回一个乘法函数。"""
    def multiplier(x: int) -> int:
        return x * factor
    return multiplier

double = make_multiplier(2)
triple = make_multiplier(3)

print(double(5))  # 10
print(triple(5))  # 15

# 闭包捕获的是变量的引用,而非值
def make_counter():
    count = 0
    def counter():
        nonlocal count
        count += 1
        return count
    return counter

c = make_counter()
print(c())  # 1
print(c())  # 2
print(c())  # 3

📌 业务场景:带缓存的函数工厂

def make_validator(min_val: float, max_val: float, field_name: str = "value"):
    """创建带范围验证的函数。"""
    def validate(x: float) -> float:
        if not min_val <= x <= max_val:
            raise ValueError(
                f"{field_name} 必须在 {min_val}{max_val} 之间,收到 {x}"
            )
        return x
    return validate

validate_age = make_validator(0, 150, "年龄")
validate_score = make_validator(0, 100, "分数")

print(validate_age(25))    # 25
print(validate_score(85))  # 85
# validate_age(200)  # ValueError: 年龄必须在 0 和 150 之间

6.7 装饰器(Decorator)

6.7.1 基本装饰器

import time
from functools import wraps

def timer(func):
    """计时装饰器。"""
    @wraps(func)  # 保留原函数的元数据
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"{func.__name__} 耗时: {elapsed:.4f} 秒")
        return result
    return wrapper

@timer
def slow_function():
    time.sleep(1)
    return "完成"

result = slow_function()
# slow_function 耗时: 1.0012 秒

6.7.2 带参数的装饰器

def retry(max_attempts: int = 3, delay: float = 1.0):
    """重试装饰器。"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_attempts:
                        raise
                    print(f"尝试 {attempt} 失败: {e}{delay}秒后重试...")
                    time.sleep(delay)
        return wrapper
    return decorator

@retry(max_attempts=3, delay=0.5)
def fetch_data(url: str) -> dict:
    import requests
    response = requests.get(url, timeout=5)
    response.raise_for_status()
    return response.json()

6.7.3 类装饰器

class Cache:
    """简单缓存装饰器类。"""
    def __init__(self, func):
        self.func = func
        self.cache = {}

    def __call__(self, *args):
        if args not in self.cache:
            self.cache[args] = self.func(*args)
        return self.cache[args]

@Cache
def expensive_calculation(n: int) -> int:
    print(f"计算 {n}...")
    return sum(i**2 for i in range(n))

print(expensive_calculation(1000))  # 计算 1000...
print(expensive_calculation(1000))  # 直接返回缓存

6.7.4 functools 常用装饰器

from functools import lru_cache, cache, partial

# LRU 缓存
@lru_cache(maxsize=128)
def fibonacci(n: int) -> int:
    if n < 2:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

# Python 3.9+ 无限缓存
@cache
def expensive(n: int) -> int:
    return sum(range(n))

# partial:部分应用
def power(base: int, exp: int) -> int:
    return base ** exp

square = partial(power, exp=2)
cube = partial(power, exp=3)
print(square(5))  # 25
print(cube(3))    # 27

6.8 类型注解

from typing import Callable, TypeVar, ParamSpec

# 函数类型注解
def apply(func: Callable[[int], int], x: int) -> int:
    return func(x)

# 泛型函数
T = TypeVar("T")

def first(items: list[T]) -> T | None:
    return items[0] if items else None

# Python 3.10+ 更简洁的写法
def process(data: list[dict[str, int]]) -> dict[str, float]:
    return {"avg": sum(d["value"] for d in data) / len(data)}

6.9 注意事项

🔴 注意

  • 默认参数只在定义时求值一次,不要使用可变对象作为默认值
  • 使用 @wraps(func) 保留被装饰函数的元数据
  • 递归函数注意递归深度限制(默认 1000 层)
  • globalnonlocal 少用为佳,优先使用参数传递

💡 提示

  • 函数签名是文档的一部分,使用类型注解
  • 短小的单行函数可以使用 lambda,复杂逻辑请用 def
  • functools.lru_cache 可以显著加速重复计算
  • 闭包是实现装饰器、工厂模式、回调的基础

📌 业务场景

# API 请求的认证装饰器
from functools import wraps
from typing import Callable, Any

def require_auth(get_token: Callable[[], str]):
    """为 API 请求添加认证头。"""
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            token = get_token()
            kwargs["headers"] = {**kwargs.get("headers", {}), "Authorization": f"Bearer {token}"}
            return func(*args, **kwargs)
        return wrapper
    return decorator

# 使用
get_current_token = lambda: "eyJhbGciOiJIUzI1NiIs..."

@require_auth(get_current_token)
def fetch_orders(url: str, **kwargs) -> dict:
    import requests
    return requests.get(url, **kwargs).json()

6.10 扩展阅读