Go 语言完全指南 / 13 - 泛型:类型参数、约束、泛型函数、泛型数据结构
13 - 泛型(Generics)
13.1 泛型简介
Go 1.18 引入了泛型,允许编写适用于多种类型的通用代码。
// Go 1.18 之前:需要为每种类型写重复代码
func sumInts(nums []int) int {
total := 0
for _, n := range nums {
total += n
}
return total
}
func sumFloats(nums []float64) float64 {
total := 0.0
for _, n := range nums {
total += n
}
return total
}
// Go 1.18+:用泛型统一处理
func sum[T Number](nums []T) T {
var total T
for _, n := range nums {
total += n
}
return total
}
13.2 类型参数和约束
package main
import (
"fmt"
)
// 内置约束
// any - 任意类型
// comparable - 可比较类型(支持 == 和 !=)
// ~int - 底层类型为 int 的所有类型
// int | string - int 或 string
// 自定义约束
type Number interface {
~int | ~int8 | ~int16 | ~int32 | ~int64 |
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~float32 | ~float64
}
type Signed interface {
~int | ~int8 | ~int16 | ~int32 | ~int64
}
type Unsigned interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}
// 约束中嵌入接口
type Ordered interface {
Number | ~string
}
func main() {
fmt.Println(sum([]int{1, 2, 3, 4, 5})) // 15
fmt.Println(sum([]float64{1.1, 2.2, 3.3})) // 6.6
fmt.Println(sum([]int64{100, 200, 300})) // 600
}
13.3 泛型函数
package main
import (
"fmt"
"strings"
)
// 基本泛型函数
func max[T Ordered](a, b T) T {
if a > b {
return a
}
return b
}
func min[T Ordered](a, b T) T {
if a < b {
return a
}
return b
}
// 多类型参数
func contains[T comparable](slice []T, target T) bool {
for _, v := range slice {
if v == target {
return true
}
}
return false
}
// Map 函数
func Map[T, U any](slice []T, fn func(T) U) []U {
result := make([]U, len(slice))
for i, v := range slice {
result[i] = fn(v)
}
return result
}
// Filter 函数
func Filter[T any](slice []T, predicate func(T) bool) []T {
var result []T
for _, v := range slice {
if predicate(v) {
result = append(result, v)
}
}
return result
}
// Reduce 函数
func Reduce[T, U any](slice []T, initial U, fn func(U, T) U) U {
result := initial
for _, v := range slice {
result = fn(result, v)
}
return result
}
func main() {
fmt.Println(max(3, 5)) // 5
fmt.Println(max(3.14, 2.71)) // 3.14
fmt.Println(max("apple", "banana")) // banana
nums := []int{1, 2, 3, 4, 5}
fmt.Println(contains(nums, 3)) // true
fmt.Println(contains(nums, 6)) // false
// Map
doubled := Map(nums, func(n int) int { return n * 2 })
fmt.Println(doubled) // [2 4 6 8 10]
strings := Map(nums, func(n int) string { return fmt.Sprintf("#%d", n) })
fmt.Println(strings) // [#1 #2 #3 #4 #5]
// Filter
evens := Filter(nums, func(n int) bool { return n%2 == 0 })
fmt.Println(evens) // [2 4]
// Reduce
sum := Reduce(nums, 0, func(acc, n int) int { return acc + n })
fmt.Println(sum) // 15
}
13.4 泛型类型
package main
import "fmt"
// 泛型栈
type Stack[T any] struct {
items []T
}
func NewStack[T any]() *Stack[T] {
return &Stack[T]{}
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
var zero T
if len(s.items) == 0 {
return zero, false
}
item := s.items[len(s.items)-1]
s.items = s.items[:len(s.items)-1]
return item, true
}
func (s *Stack[T]) Peek() (T, bool) {
var zero T
if len(s.items) == 0 {
return zero, false
}
return s.items[len(s.items)-1], true
}
func (s *Stack[T]) Len() int {
return len(s.items)
}
// 泛型队列
type Queue[T any] struct {
items []T
}
func (q *Queue[T]) Enqueue(item T) {
q.items = append(q.items, item)
}
func (q *Queue[T]) Dequeue() (T, bool) {
var zero T
if len(q.items) == 0 {
return zero, false
}
item := q.items[0]
q.items = q.items[1:]
return item, true
}
func main() {
// 字符串栈
stack := NewStack[string]()
stack.Push("a")
stack.Push("b")
stack.Push("c")
for v, ok := stack.Pop(); ok; v, ok = stack.Pop() {
fmt.Println(v) // c, b, a
}
// 整数队列
queue := &Queue[int]{}
queue.Enqueue(1)
queue.Enqueue(2)
queue.Enqueue(3)
for v, ok := queue.Dequeue(); ok; v, ok = queue.Dequeue() {
fmt.Println(v) // 1, 2, 3
}
}
13.5 泛型接口和约束
package main
import (
"fmt"
"sort"
)
// 容器约束
type Container[T any] interface {
Add(item T)
Get(index int) (T, bool)
Len() int
}
// 可排序约束
type Sortable[T any] interface {
Len() int
Less(i, j int) bool
Swap(i, j int)
}
// 泛型排序
func Sort[T Ordered](slice []T) {
sort.Slice(slice, func(i, j int) bool {
return slice[i] < slice[j]
})
}
// 泛型集合
type Set[T comparable] struct {
items map[T]struct{}
}
func NewSet[T comparable](items ...T) *Set[T] {
s := &Set[T]{items: make(map[T]struct{})}
for _, item := range items {
s.Add(item)
}
return s
}
func (s *Set[T]) Add(item T) {
s.items[item] = struct{}{}
}
func (s *Set[T]) Remove(item T) {
delete(s.items, item)
}
func (s *Set[T]) Contains(item T) bool {
_, ok := s.items[item]
return ok
}
func (s *Set[T]) Len() int {
return len(s.items)
}
func (s *Set[T]) Slice() []T {
result := make([]T, 0, len(s.items))
for item := range s.items {
result = append(result, item)
}
return result
}
// 集合运算
func Union[T comparable](a, b *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range a.items {
result.Add(item)
}
for item := range b.items {
result.Add(item)
}
return result
}
func Intersect[T comparable](a, b *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range a.items {
if b.Contains(item) {
result.Add(item)
}
}
return result
}
func main() {
s1 := NewSet(1, 2, 3, 4)
s2 := NewSet(3, 4, 5, 6)
union := Union(s1, s2)
fmt.Println("并集:", union.Slice())
intersect := Intersect(s1, s2)
fmt.Println("交集:", intersect.Slice())
// 字符串集合
tags := NewSet("go", "rust", "python")
fmt.Println("包含 go:", tags.Contains("go"))
nums := []int{5, 3, 1, 4, 2}
Sort(nums)
fmt.Println("排序:", nums) // [1 2 3 4 5]
}
13.6 类型推断
func main() {
// 编译器可以推断类型参数
fmt.Println(max(3, 5)) // T 推断为 int
fmt.Println(max(3.14, 2.71)) // T 推断为 float64
// 显式指定
fmt.Println(max[int](3, 5))
fmt.Println(max[float64](3.14, 2.71))
// 无法推断时必须显式指定
stack := NewStack[int]() // 必须指定
stack.Push(1)
}
13.7 泛型限制
// ❌ 不支持的特性
// 1. 方法不能有额外的类型参数
// func (s Stack[T]) Map[U any](fn func(T) U) Stack[U] { } // 编译错误
// 2. 不能用类型参数做类型断言
// func foo[T any](x any) T { return x.(T) } // 编译错误
// 3. 不能用类型参数创建复合字面量
// func foo[T any]() []T { return []T{0} } // 如果 T 不是数值类型会出错
// 4. 不能用类型参数做指针操作
// func foo[T any](x T) *T { return &x } // OK,但有限制
// ✅ 工作方案
type SliceFuncs[T any] struct {
data []T
}
func NewSliceFuncs[T any](data []T) SliceFuncs[T] {
return SliceFuncs[T]{data: data}
}
func (sf SliceFuncs[T]) Map(fn func(T) T) SliceFuncs[T] {
result := make([]T, len(sf.data))
for i, v := range sf.data {
result[i] = fn(v)
}
return SliceFuncs[T]{data: result}
}
13.8 性能考虑
// 泛型在编译时实例化,运行时无额外开销
// 但会导致二进制文件增大(每种类型生成一份代码)
import "testing"
func BenchmarkSumGeneric(b *testing.B) {
nums := make([]int, 10000)
for i := range nums {
nums[i] = i
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
sum(nums)
}
}
func BenchmarkSumSpecific(b *testing.B) {
nums := make([]int, 10000)
for i := range nums {
nums[i] = i
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
sumInts(nums) // 专用函数
}
}
// 泛型 vs interface{} 的性能对比
func BenchmarkContainsGeneric(b *testing.B) {
nums := make([]int, 10000)
for i := range nums { nums[i] = i }
b.ResetTimer()
for i := 0; i < b.N; i++ {
contains(nums, 9999)
}
}
func BenchmarkContainsInterface(b *testing.B) {
nums := make([]any, 10000)
for i := range nums { nums[i] = i }
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, v := range nums {
if v.(int) == 9999 { break }
}
}
}
🏢 业务场景
- 通用数据结构:泛型栈、队列、集合、链表
- 工具函数:Map/Filter/Reduce 等函数式操作
- API 响应:泛型 Response 包装不同类型数据
- 缓存系统:泛型 Cache 支持任意键值类型
- 仓储模式:泛型 Repository 接口