强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧。
文档目录

HTTP/2 与 RPC 精讲教程 / 10 - gRPC 高级特性

第 10 章:gRPC 高级特性

从可用到好用——拦截器、负载均衡、重试、元数据与截止时间


10.1 拦截器(Interceptor)

拦截器是 gRPC 中最强大的中间件机制,允许在 RPC 调用前后插入自定义逻辑。

10.1.1 拦截器类型

类型 作用域 用途
Unary Server 服务端一元 RPC 日志、认证、限流
Stream Server 服务端流式 RPC 日志、认证
Unary Client 客户端一元 RPC 重试、追踪、超时
Stream Client 客户端流式 RPC 追踪

10.1.2 服务端一元拦截器

package main

import (
	"context"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

// loggingInterceptor is a unary server interceptor that logs the start and
// the outcome of every RPC.
//
// Before invoking handler it logs the full method name, the peer address
// ("unknown" when the context carries no peer) and the first value of the
// "x-request-id" incoming metadata key (empty when absent). After the handler
// returns it logs the status code and message derived from the handler's
// error together with the elapsed time. The handler's response and error are
// returned unchanged.
func loggingInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	begin := time.Now()

	// Resolve the caller's address, falling back to a placeholder.
	remote := "unknown"
	if pr, _ := peer.FromContext(ctx); pr != nil {
		remote = pr.Addr.String()
	}

	// Pick up the request ID sent by the client, if any.
	var reqID string
	incoming, _ := metadata.FromIncomingContext(ctx)
	if ids := incoming.Get("x-request-id"); len(ids) != 0 {
		reqID = ids[0]
	}

	log.Printf("[gRPC] %s %s <- %s (req_id: %s)",
		info.FullMethod, "START", remote, reqID)

	// Run the actual handler.
	resp, err := handler(ctx, req)

	// Convert the error into a gRPC status so code and message can be logged.
	elapsed := time.Since(begin)
	st, _ := status.FromError(err)
	log.Printf("[gRPC] %s %s %v (duration: %v)",
		info.FullMethod, st.Code(), st.Message(), elapsed)

	return resp, err
}

// authInterceptor is a unary server interceptor that authenticates requests
// using the "authorization" incoming metadata key.
//
// Calls to the health-check method bypass authentication. For every other
// method the first "authorization" value is passed to validateToken; when the
// metadata is missing, the key is absent, or validation fails, the call is
// rejected with codes.Unauthenticated. On success the user ID returned by
// validateToken is stored in the context under the key "user_id" before the
// handler runs.
func authInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	// Health checks are exempt from authentication.
	const healthCheckMethod = "/grpc.health.v1.Health/Check"
	if info.FullMethod == healthCheckMethod {
		return handler(ctx, req)
	}

	// The token travels in the request metadata.
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Error(codes.Unauthenticated, "缺少元数据")
	}

	authValues := md.Get("authorization")
	if len(authValues) == 0 {
		return nil, status.Error(codes.Unauthenticated, "缺少认证信息")
	}

	// Only the first value is validated.
	userID, err := validateToken(authValues[0])
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "认证失败")
	}

	// Expose the authenticated user to the handler.
	// NOTE(review): a plain string context key can collide with other
	// packages; an unexported key type would be safer, but readers of
	// ctx.Value("user_id") would have to change with it.
	return handler(context.WithValue(ctx, "user_id", userID), req)
}

// rateLimitInterceptor is a unary server interceptor that consults
// limiter.Allow before each call. When the limiter admits the request the
// handler runs; otherwise the call is rejected with codes.ResourceExhausted.
func rateLimitInterceptor(
	ctx context.Context,
	req interface{},
	info *grpc.UnaryServerInfo,
	handler grpc.UnaryHandler,
) (interface{}, error) {
	if limiter.Allow() {
		return handler(ctx, req)
	}
	return nil, status.Error(codes.ResourceExhausted, "请求过于频繁")
}

// main builds a gRPC server whose unary calls pass through the logging,
// authentication and rate-limiting interceptors, chained in that order.
// The server is only constructed here; it is not started.
func main() {
	chain := grpc.ChainUnaryInterceptor(
		loggingInterceptor,
		authInterceptor,
		rateLimitInterceptor,
	)

	server := grpc.NewServer(chain)
	_ = server
}

10.1.3 客户端拦截器

// retryInterceptor is a unary client interceptor that retries failed calls.
//
// The call is attempted once and then retried up to three more times when it
// fails with codes.Unavailable, codes.DeadlineExceeded or
// codes.ResourceExhausted; any other error is returned immediately. Before
// retry n (1-based) it waits 100ms * 2^(n-1), i.e. 100ms, 200ms, 400ms. The
// wait is abandoned as soon as ctx is done, in which case the error of the
// most recent attempt is returned. It returns nil on the first successful
// attempt and the last error when every attempt fails.
func retryInterceptor(
	ctx context.Context,
	method string,
	req, reply interface{},
	cc *grpc.ClientConn,
	invoker grpc.UnaryInvoker,
	opts ...grpc.CallOption,
) error {
	const (
		maxRetries  = 3
		baseBackoff = 100 * time.Millisecond
	)
	var lastErr error

	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff: the delay doubles on every retry.
			backoff := baseBackoff << (attempt - 1)

			// Wait on a timer instead of time.Sleep so that a cancelled or
			// expired context ends the retry loop right away.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return lastErr
			case <-timer.C:
			}
			log.Printf("重试 %d/%d: %s", attempt, maxRetries, method)
		}

		lastErr = invoker(ctx, method, req, reply, cc, opts...)
		if lastErr == nil {
			return nil
		}

		// Only transient failures are worth retrying.
		st, _ := status.FromError(lastErr)
		switch st.Code() {
		case codes.Unavailable, codes.DeadlineExceeded, codes.ResourceExhausted:
			continue // retryable
		default:
			return lastErr // not retryable
		}
	}

	return lastErr
}

// tracingInterceptor is a unary client interceptor that tags each outgoing
// call with a trace ID. It obtains the ID from generateTraceID, appends it to
// the outgoing metadata under "x-trace-id", invokes the call, and logs the
// method, trace ID, elapsed time and resulting error. The invoker's error is
// returned unchanged.
func tracingInterceptor(
	ctx context.Context,
	method string,
	req, reply interface{},
	cc *grpc.ClientConn,
	invoker grpc.UnaryInvoker,
	opts ...grpc.CallOption,
) error {
	id := generateTraceID()
	tracedCtx := metadata.AppendToOutgoingContext(ctx, "x-trace-id", id)

	began := time.Now()
	err := invoker(tracedCtx, method, req, reply, cc, opts...)
	elapsed := time.Since(began)

	log.Printf("[Trace] %s trace_id=%s duration=%v err=%v",
		method, id, elapsed, err)

	return err
}

// main dials localhost:50051 without transport security and installs the
// retry and tracing interceptors, chained in that order, on the client
// connection. It exits via log.Fatalf when dialing fails and closes the
// connection on return.
func main() {
	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithChainUnaryInterceptor(
			retryInterceptor,
			tracingInterceptor,
		),
	}

	conn, err := grpc.Dial("localhost:50051", dialOpts...)
	if err != nil {
		log.Fatalf("连接失败: %v", err)
	}
	defer conn.Close()
}

10.2 元数据(Metadata)

10.2.1 元数据基础

// sendWithMetadata shows the client side of metadata (approach 1): it builds
// a metadata set from a map, attaches it to a fresh outgoing context and
// calls GetUser for user 1 with that context. The response is discarded and
// only the call's error is returned.
func sendWithMetadata(client pb.UserServiceClient) error {
	outgoing := metadata.New(map[string]string{
		"authorization": "Bearer eyJhbG...",
		"x-request-id":  "req-12345",
		"x-client":      "mobile-app",
	})
	ctx := metadata.NewOutgoingContext(context.Background(), outgoing)

	_, err := client.GetUser(ctx, &pb.GetUserRequest{Id: 1})
	return err
}

// 方式 2:追加元数据
func appendMetadata(ctx context.Context) context.Context {
	return metadata.AppendToOutgoingContext(ctx,
		"authorization", "Bearer token",
		"x-request-id", generateID(),
	)
}

// GetUser shows the server side of metadata: it reads the incoming request
// metadata and sends response metadata back as a header and a trailer.
//
// It returns codes.InvalidArgument when the context carries no incoming
// metadata. Failures to set the header or trailer are logged and do not fail
// the call, since the response itself is still valid. On success it returns a
// user with the requested ID and the name "Alice".
func (s *server) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.User, error) {
	// Read the request metadata.
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return nil, status.Error(codes.InvalidArgument, "缺少元数据")
	}

	// Look up a specific key; only the first value is used.
	if authTokens := md.Get("authorization"); len(authTokens) > 0 {
		token := authTokens[0]
		_ = token
	}

	// Response metadata sent as a header.
	header := metadata.Pairs("x-response-id", "resp-123")
	if err := grpc.SetHeader(ctx, header); err != nil {
		log.Printf("设置响应 Header 失败: %v", err)
	}

	// Response metadata sent as a trailer.
	trailer := metadata.Pairs("x-processing-time", "42ms")
	if err := grpc.SetTrailer(ctx, trailer); err != nil {
		log.Printf("设置响应 Trailer 失败: %v", err)
	}

	return &pb.User{Id: req.Id, Name: "Alice"}, nil
}

10.2.2 元数据规范

键名前缀 保留方 说明
grpc- gRPC 框架 gRPC 内部保留
无前缀 应用层 双向传递,小写
-bin 后缀 二进制值 键名以 -bin 结尾时值为二进制数据,传输时以 base64 编码

10.3 截止时间与超时(Deadline & Timeout)

10.3.1 截止时间传播

服务调用链中的截止时间传播:

客户端 A              服务 B              服务 C
  │                     │                   │
  │ deadline=5s        │                   │
  │─── GetUser ────→  │                   │
  │                     │ deadline=4.8s    │
  │                     │── Validate ────→│
  │                     │                  │ 处理...
  │                     │  ←── Response ──│
  │  ←── Response ────│                   │
  │                     │                   │

关键:截止时间随 ctx 沿整条链路传播(服务端需把收到的 ctx 继续传给下游调用)
     如果链路上任何节点超时,整个调用链取消
// callWithDeadline shows approach 1, an absolute deadline: it calls GetUser
// for user 1 with a context that expires five seconds from now. When the call
// fails with codes.DeadlineExceeded a timeout message is logged; any error is
// returned to the caller. The response is discarded.
func callWithDeadline(client pb.UserServiceClient) error {
	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
	defer cancel()

	_, err := client.GetUser(ctx, &pb.GetUserRequest{Id: 1})
	if err == nil {
		return nil
	}

	if status.Code(err) == codes.DeadlineExceeded {
		log.Println("调用超时")
	}
	return err
}

// callWithTimeout shows approach 2, a relative timeout: it calls GetUser for
// user 1 with a context that is cancelled after five seconds. The response is
// discarded and the call's error is returned.
func callWithTimeout(client pb.UserServiceClient) error {
	const timeout = 5 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	_, err := client.GetUser(ctx, &pb.GetUserRequest{Id: 1})
	return err
}

// GetUser shows a server-side deadline check. When the context has a
// deadline, the remaining time is logged and the call is rejected with
// codes.DeadlineExceeded if less than 100ms is left. Otherwise a user with the
// requested ID is returned.
func (s *server) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.User, error) {
	// Minimum time budget required to start the work.
	const minBudget = 100 * time.Millisecond

	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		left := time.Until(deadline)
		log.Printf("截止时间剩余: %v", left)

		if left < minBudget {
			return nil, status.Error(codes.DeadlineExceeded, "时间不足以完成操作")
		}
	}

	// The time-consuming work would run here.
	return &pb.User{Id: req.Id}, nil
}

10.4 负载均衡

10.4.1 客户端负载均衡

package main

import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/balancer/roundrobin"
)

// main dials the user service through the dns resolver scheme (approach 1:
// DNS-based service discovery) without transport security and selects
// round_robin load balancing through the default service config. It panics
// when dialing fails and closes the connection on return.
func main() {
	const (
		target   = "dns:///user-service.default.svc.cluster.local:50051"
		lbConfig = `{"loadBalancingConfig": [{"round_robin":{}}]}`
	)

	conn, err := grpc.Dial(
		target,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(lbConfig),
	)
	if err != nil {
		panic(err)
	}
	defer conn.Close()
}

10.4.2 自定义解析器

package main

import (
	"google.golang.org/grpc/resolver"
)

// staticResolver is a custom service-discovery resolver that resolves every
// target to a fixed list of addresses.
//
// The same type is used in two roles: the value passed to resolver.Register
// acts as the builder and only carries addresses, while each value returned by
// Build is bound to one ClientConn through cc.
type staticResolver struct {
	addresses []resolver.Address
	cc        resolver.ClientConn
}

// Build returns a resolver bound to cc and immediately reports the static
// address list to it. target and opts are not used.
func (r *staticResolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
	// The registered builder is shared by every ClientConn that uses this
	// scheme. Storing cc on it would let a later Build overwrite the cc of an
	// earlier connection, so each call gets its own instance.
	res := &staticResolver{addresses: r.addresses, cc: cc}

	// NOTE(review): the result of UpdateState is not checked, as in the
	// original snippet; confirm whether a rejected state should fail Build.
	res.cc.UpdateState(resolver.State{Addresses: res.addresses})
	return res, nil
}

// Scheme returns the URL scheme this resolver is registered under.
func (r *staticResolver) Scheme() string { return "static" }

// ResolveNow is a no-op: the address list never changes.
func (r *staticResolver) ResolveNow(resolver.ResolveNowOptions) {}

// Close is a no-op: the resolver holds nothing that needs releasing.
func (r *staticResolver) Close() {}

func init() {
	// 注册自定义解析器
	resolver.Register(&staticResolver{
		addresses: []resolver.Address{
			{Addr: "10.0.0.1:50051"},
			{Addr: "10.0.0.2:50051"},
			{Addr: "10.0.0.3:50051"},
		},
	})
}

10.5 重试策略

10.5.1 服务配置式重试

// dialWithRetry dials localhost:50051 without transport security and enables
// automatic retries through the default service config. The retry policy
// applies to the service example.UserService: up to 4 attempts, backoff
// starting at 0.1s, multiplied by 2.0 and capped at 1s, for the status codes
// UNAVAILABLE and DEADLINE_EXCEEDED.
func dialWithRetry() (*grpc.ClientConn, error) {
	const retryServiceConfig = `{
		"methodConfig": [{
			"name": [{"service": "example.UserService"}],
			"retryPolicy": {
				"maxAttempts": 4,
				"initialBackoff": "0.1s",
				"maxBackoff": "1s",
				"backoffMultiplier": 2.0,
				"retryableStatusCodes": ["UNAVAILABLE", "DEADLINE_EXCEEDED"]
			}
		}]
	}`

	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(retryServiceConfig),
	}
	return grpc.Dial("localhost:50051", dialOpts...)
}

10.5.2 等待就绪(Wait-for-Ready)

// callWithWaitForReady calls GetUser for user 1 with a 30-second timeout and
// the WaitForReady call option, so that a connection that is not ready yet is
// waited for instead of failing the call immediately. The response is
// discarded and the call's error is returned.
func callWithWaitForReady(client pb.UserServiceClient) error {
	const timeout = 30 * time.Second

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	request := &pb.GetUserRequest{Id: 1}
	_, err := client.GetUser(ctx, request, grpc.WaitForReady(true))
	return err
}

10.6 健康检查

import (
	"google.golang.org/grpc/health"
	"google.golang.org/grpc/health/grpc_health_v1"
)

// setupHealthCheck registers a health service on server and marks
// example.UserService as SERVING.
func setupHealthCheck(server *grpc.Server) {
	const serviceName = "example.UserService"

	hs := health.NewServer()
	grpc_health_v1.RegisterHealthServer(server, hs)

	// Publish the current status of the service.
	hs.SetServingStatus(serviceName, grpc_health_v1.HealthCheckResponse_SERVING)

	// When the service becomes unavailable:
	// hs.SetServingStatus(serviceName,
	//     grpc_health_v1.HealthCheckResponse_NOT_SERVING)
}
# 测试健康检查
grpcurl -plaintext localhost:50051 grpc.health.v1.Health/Check

# 输出:
# {
#   "status": "SERVING"
# }

10.7 性能优化

10.7.1 连接池

package main

import (
	"fmt"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// ConnPool holds a fixed set of gRPC client connections and hands them out
// in rotation through Get.
type ConnPool struct {
	// conns are the pooled connections, filled by NewConnPool.
	conns []*grpc.ClientConn
	// mu is locked by Get while it reads and advances idx.
	mu    sync.Mutex
	// idx is the rotation counter Get uses to pick the next connection.
	idx   int
	// size is the number of connections requested from NewConnPool.
	size  int
}

// NewConnPool dials target size times and returns a pool holding the
// resulting connections. Every connection uses insecure transport credentials
// and the round_robin load-balancing policy.
//
// It returns an error when size is not positive or when any dial fails. On a
// dial failure the connections opened so far are closed before returning, so
// nothing is leaked.
func NewConnPool(target string, size int) (*ConnPool, error) {
	// A non-positive size would panic in make or produce a pool whose Get
	// divides by zero, so reject it up front.
	if size <= 0 {
		return nil, fmt.Errorf("connection pool size must be positive, got %d", size)
	}

	pool := &ConnPool{
		conns: make([]*grpc.ClientConn, 0, size),
		size:  size,
	}

	for i := 0; i < size; i++ {
		conn, err := grpc.Dial(target,
			grpc.WithTransportCredentials(insecure.NewCredentials()),
			grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`),
		)
		if err != nil {
			// Best-effort cleanup of the connections already opened; the dial
			// error is the one worth reporting.
			for _, opened := range pool.conns {
				_ = opened.Close()
			}
			return nil, err
		}
		pool.conns = append(pool.conns, conn)
	}

	return pool, nil
}

// Get returns the next connection of the pool in rotation. It locks the pool
// for the duration of the call, so it may be called from multiple goroutines.
func (p *ConnPool) Get() *grpc.ClientConn {
	p.mu.Lock()
	defer p.mu.Unlock()

	i := p.idx % p.size
	// Keep the counter inside [0, size) so it can never overflow into a
	// negative value, which would make the index above negative.
	p.idx = (i + 1) % p.size
	return p.conns[i]
}

// Close closes every connection held by the pool. Errors returned by the
// individual connections are not reported.
func (p *ConnPool) Close() {
	for i := range p.conns {
		p.conns[i].Close()
	}
}

10.7.2 压缩

// Enable compression on the server by passing a gzip compressor and a gzip
// decompressor as server options.
// NOTE(review): RPCCompressor/RPCDecompressor and the NewGZIP* helpers appear
// to be deprecated in recent grpc-go releases in favour of the encoding/gzip
// package — confirm against the grpc-go version in use.
server := grpc.NewServer(
	grpc.RPCCompressor(grpc.NewGZIPCompressor()),
	grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
)

// Request compression from the client for a single call through the
// UseCompressor call option.
// NOTE(review): gzip.Name presumably refers to
// google.golang.org/grpc/encoding/gzip, which no visible snippet imports —
// confirm the import exists where this is used.
resp, err := client.GetUser(ctx,
	&pb.GetUserRequest{Id: 1},
	grpc.UseCompressor(gzip.Name),
)

10.8 注意事项

⚠️ 拦截器顺序

  • 链式拦截器按添加顺序执行
  • 通常顺序:日志 → 认证 → 限流 → 业务(与 10.1.2 的示例一致;日志放在最外层才能记录被认证或限流拒绝的请求)

⚠️ 截止时间传播

  • 截止时间随 ctx 在整条调用链中传播;调用下游时必须传入收到的 ctx,不要改用 context.Background()
  • 服务端应检查上下文的截止时间
  • 适当预留时间给下游调用

⚠️ 连接管理

  • 长期不用的连接可能被中间设备断开
  • 启用 keepalive 机制
  • 合理设置连接池大小

10.9 扩展阅读


第 09 章 - gRPC 流式通信 | 第 11 章 - REST vs gRPC 选型