强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

LLVM 开发指南 / 第 17 章:MLIR 多级 IR

第 17 章:MLIR 多级 IR

“MLIR 不只是一个 IR,而是一个构建 IR 的框架。”


17.1 MLIR 概述

MLIR(Multi-Level Intermediate Representation) 是 LLVM 项目中的通用编译器基础设施,旨在解决传统编译器中的"方言碎片化"问题。

17.1.1 为什么需要 MLIR?

传统编译器的问题:
  每个领域 (深度学习、HPC、量子计算) 都有自己的 IR
  无法共享优化和基础设施
  重复劳动

MLIR 的解决方案:
  提供统一的可扩展框架
  通过 Dialect 支持任意抽象级别
  渐进式降低 (Progressive Lowering)

17.1.2 MLIR 应用场景

场景 说明 用户
深度学习 TensorFlow → MLIR → GPU/TPU TensorFlow, PyTorch
HPC 编译 循环优化、向量化 Polygeist, Triton
硬件加速器 FPGA/ASIC 编译 CIRCT, HEDA
领域特定语言 DSL 编译器 数学、物理引擎
源到源翻译 代码变换 MLIR-to-MLIR

17.2 核心概念

17.2.1 Dialect(方言)

Dialect 是 MLIR 中组织 Operation 和 Type 的命名空间:

┌──────────────────────────────────────┐
│              MLIR                     │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │ arith    │  │  func    │         │
│  │ 算术运算  │  │ 函数定义  │         │
│  └──────────┘  └──────────┘         │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │  linalg  │  │ tensor   │         │
│  │ 线性代数  │  │ 张量操作  │         │
│  └──────────┘  └──────────┘         │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │   scf    │  │  memref  │         │
│  │ 结构控制流│  │ 内存引用  │         │
│  └──────────┘  └──────────┘         │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │  llvm    │  │  gpu     │         │
│  │ LLVM 方言│  │ GPU 操作  │         │
│  └──────────┘  └──────────┘         │
└──────────────────────────────────────┘

17.2.2 Operation(操作)

// MLIR 操作格式:
%result = dialect.operation_name %operand1, %operand2 {attr = value} : (type1, type2) -> type3

// 示例:
%sum = arith.addi %a, %b : i32                    // 整数加法
%prod = arith.mulf %x, %y : f32                   // 浮点乘法
%cmp = arith.cmpi eq, %a, %b : i32                // 整数比较

// 函数定义
func.func @add(%arg0: i32, %arg1: i32) -> i32 {
  %result = arith.addi %arg0, %arg1 : i32
  return %result : i32
}

// 控制流
scf.for %i = %lb to %ub step %step {
  %val = memref.load %arr[%i] : memref<10xi32>
  scf.yield
}

// 张量操作
%c = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
                   outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>

17.2.3 Block 和 Region

// Region 包含 Block
// Block 包含 Operation
// Operation 可以包含 Region (嵌套)

func.func @example(%cond: i1) -> i32 {
  // 这里是一个 Region
  // 包含一个 Block (entry block)

  cf.cond_br %cond, ^bb1, ^bb2  // 条件分支

^bb1:
  %x = arith.constant 42 : i32
  cf.br ^bb3(%x : i32)

^bb2:
  %y = arith.constant 0 : i32
  cf.br ^bb3(%y : i32)

^bb3(%result: i32):
  return %result : i32
}

17.3 MLIR 类型系统

// 基础类型
i32                     // 32位整数
f64                     // 64位浮点
index                   // 索引类型 (平台相关)

// 张量类型
tensor<4xf32>           // 1D张量,4个f32
tensor<3x4xf32>         // 2D张量
tensor<?x?xf32>         // 动态形状张量
tensor<4x?x8xi32>       // 部分动态

// 内存引用类型
memref<10xf32>          // 1D内存,10个f32
memref<3x4xf32>         // 2D内存
memref<?xf32>           // 动态大小
memref<10xf32, affine_map<(d0) -> (d0)>, 1>  // 带布局和地址空间

// 向量类型
vector<4xf32>           // SIMD向量
vector<4x8xf32>         // 2D向量

// 函数类型
(i32, i32) -> i32       // 两参数一返回值
(i32) -> ()             // 一参数无返回值

// 索引映射类型
affine_map<(d0, d1) -> (d0, d1)>    // 仿射映射

17.4 常用 Dialect

Dialect 用途 抽象级别
func 函数定义 高层
arith 算术运算 中层
math 数学函数 中层
tensor 张量操作 高层
linalg 线性代数 高层
memref 内存操作 中层
scf 结构化控制流 中层
cf 非结构化控制流 低层
affine 仿射循环 高层
llvm LLVM 方言 最低层
gpu GPU 操作 高层
spirv SPIR-V GPU后端
tosa Tensor 操作 AI 专用

17.5 渐进式降低 (Progressive Lowering)

高级别表示:
  %c = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
                     outs(%c : tensor<4x16xf32>)
        ↓ tensor → memref
中级别表示:
  linalg.matmul ins(%a, %b : memref<4x8xf32>, memref<8x16xf32>)
                outs(%c : memref<4x16xf32>)
        ↓ linalg → scf + arith
循环表示:
  scf.for %i = ... {
    scf.for %j = ... {
      scf.for %k = ... {
        %a_val = memref.load %a[%i, %k]
        %b_val = memref.load %b[%k, %j]
        %c_val = memref.load %c[%i, %j]
        %prod = arith.mulf %a_val, %b_val
        %sum = arith.addf %c_val, %prod
        memref.store %sum, %c[%i, %j]
      }
    }
  }
        ↓ scf → cf + memref → llvm
LLVM IR 表示:
  llvm.func @matmul(...) {
    llvm.br ^bb1
    ...
  }
        ↓ LLVM 后端
机器码

17.6 Dialect 变换 Pass

17.6.1 内置 Pass

# 使用 mlir-opt 运行变换
mlir-opt input.mlir -pass-name -o output.mlir

# 常用变换
mlir-opt input.mlir \
  -linalg-bufferize \
  -tensor-bufferize \
  -func-bufferize \
  -convert-linalg-to-loops \
  -convert-scf-to-cf \
  -convert-memref-to-llvm \
  -convert-arith-to-llvm \
  -convert-func-to-llvm \
  -reconcile-unrealized-casts \
  -o output.mlir

# 一键降级到 LLVM 方言
mlir-opt input.mlir \
  -one-shot-bufferize \
  -convert-linalg-to-loops \
  -lower-affine \
  -convert-scf-to-cf \
  -finalize-memref-to-llvm \
  -convert-func-to-llvm \
  -convert-arith-to-llvm \
  -reconcile-unrealized-casts \
  -o output.mlir

17.6.2 编写自定义变换

// MyTransformPass.cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

struct MyPattern : public OpRewritePattern<arith::AddIOp> {
    using OpRewritePattern::OpRewritePattern;

    LogicalResult matchAndRewrite(arith::AddIOp op,
                                   PatternRewriter &rewriter) const override {
        // 匹配 add x, 0 → x
        auto rhs = op.getRhs().getDefiningOp<arith::ConstantOp>();
        if (rhs && rhs.getValue().cast<IntegerAttr>().getInt() == 0) {
            rewriter.replaceOp(op, op.getLhs());
            return success();
        }
        return failure();
    }
};

struct MyPass : public PassWrapper<MyPass, OperationPass<ModuleOp>> {
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyPass)

    StringRef getArgument() const override { return "my-pass"; }
    StringRef getDescription() const override { return "My custom pass"; }

    void runOnOperation() override {
        MLIRContext *ctx = &getContext();
        RewritePatternSet patterns(ctx);
        patterns.add<MyPattern>(ctx);
        
        if (failed(applyPatternsAndFoldGreedily(
                getOperation(), std::move(patterns)))) {
            signalPassFailure();
        }
    }
};

// 注册 Pass
void registerMyPass() {
    PassRegistration<MyPass>();
}

17.7 MLIR 工具链

工具 功能
mlir-opt 运行变换 Pass
mlir-translate IR 翻译(MLIR ↔ LLVM IR)
mlir-cpu-runner CPU 上执行 MLIR
mlir-tblgen Dialect/TableGen 代码生成
mlir-lsp-server 语言服务器
# MLIR → LLVM IR
mlir-translate --mlir-to-llvmir input.mlir -o output.ll

# LLVM IR → MLIR
mlir-translate --import-llvm input.ll -o output.mlir

# 执行 MLIR
mlir-cpu-runner -e main -entry-point-result=i32 \
  -shared-libs=/usr/lib/libmlir_runner_utils.so input.mlir

17.8 本章小结

概念 说明
Dialect 操作和类型的命名空间
Operation IR 的基本单元
Region/Block 嵌套结构
Progression Lowering 逐步降低抽象级别
Bufferization Tensor → Memref 转换

扩展阅读

  1. MLIR 官方文档 — MLIR 项目主页
  2. MLIR Tutorial — Toy 语言教程
  3. MLIR Dialects — 方言文档
  4. MLIR Language Reference — 语言参考

下一章: 第 18 章:开发环境与构建系统 — 学习 LLVM 的 Docker 开发环境和 CMake 集成。