From PyTorch to MLIR: Building a TorchDynamo-Based Compiler Frontend
Most ML compiler projects start with the backend — instruction selection, register allocation, code generation. But the frontend is equally critical: if you can’t efficiently capture and lower PyTorch models, your compiler doesn’t matter.
We built a pure-Python frontend that integrates with PyTorch’s torch.compile infrastructure via TorchDynamo. Here’s the architecture, the tricky parts, and what we learned.
The Capture Problem
PyTorch models are imperative Python code. To compile them, you need to capture their computational graph. There are several approaches:
| Approach | Pros | Cons |
|---|---|---|
| Tracing (torch.jit.trace) | Simple | Misses control flow |
| Scripting (torch.jit.script) | Handles control flow | Requires type annotations, limited Python |
| TorchDynamo (torch.compile) | Handles real Python, guards on shapes | More complex integration |
| Export (torch.export) | Clean graph guarantee | Stricter than Dynamo |
We chose TorchDynamo because it handles real-world PyTorch code — dynamic control flow, third-party library calls, and all the Python patterns that researchers actually use.
Architecture Overview
```
PyTorch Model
      ↓
TorchDynamo (FX graph capture)
      ↓
Graph IR (our intermediate representation)
      ↓
TOSA / Linalg MLIR generation
      ↓
Bufferization (tensor → memref)
      ↓
Midend Pipeline (subprocess)
      ↓
Assembly Output
```
Three Compilation Modes
Our frontend exposes three modes through a single FTMCompiler class:
```python
compiler = FTMCompiler()

# Mode 1: Inspect MLIR (debugging)
mlir_str = compiler.compile_to_mlir(model, *sample_inputs)

# Mode 2: Compile to assembly
compiler.compile_to_lasm(model, *sample_inputs, output_path="kernel.lan")

# Mode 3: JIT execution on CPU (validation)
result = compiler.jit_run(model, *sample_inputs)
```
Mode 3 uses MLIR’s ExecutionEngine to run on CPU — invaluable for validating correctness before running on target hardware.
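In practice, the Mode 3 result is compared against PyTorch eager output under a small numeric tolerance. A minimal sketch of such a check, in plain Python over flat float lists — `assert_close` and its default tolerances are our own illustrative names, not part of the real FTMCompiler API:

```python
# Hypothetical tolerance check for comparing compiled vs. eager outputs.
# The name assert_close and the rtol/atol defaults are illustrative only.
def assert_close(compiled, eager, rtol=1e-5, atol=1e-6):
    """Element-wise |a - b| <= atol + rtol * |b| check over flat float lists."""
    assert len(compiled) == len(eager), "output length mismatch"
    for a, b in zip(compiled, eager):
        if abs(a - b) > atol + rtol * abs(b):
            raise AssertionError(f"mismatch: {a} vs {b}")
    return True
```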
The Graph IR Layer
Between FX capture and MLIR generation, we maintain our own Graph IR. This might seem redundant, but it serves three purposes:
- Decouples FX changes from MLIR generation — FX’s internal representation evolves across PyTorch versions
- Carries tensor metadata — dtype, shape, stride information propagated from PyTorch
- Enables pre-MLIR transformations — op fusion, quantization handling, weight format conversion
Each node in our Graph IR holds:
```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Node:
    op_type: str          # "conv2d", "relu", "matmul", ...
    inputs: List[Edge]    # input edges with tensor metadata
    outputs: List[Edge]   # output edges
    attributes: Dict      # kernel_size, stride, padding, ...
    metadata: TensorMeta  # dtype, shape, strides
```
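Because each node only references its inputs, lowering can walk the graph in topological order. A sketch of that traversal, using a simplified stand-in node (this `SimpleNode` is illustrative, not the article's real Node/Edge classes):

```python
# Sketch: topological walk over a Graph-IR-like structure before lowering.
# SimpleNode is an illustrative stand-in for the real Node/Edge classes.
from dataclasses import dataclass, field
from typing import List

@dataclass
class SimpleNode:
    op_type: str
    inputs: List["SimpleNode"] = field(default_factory=list)

def topo_order(outputs):
    """Post-order DFS: every node appears after all of its inputs."""
    seen, order = set(), []
    def visit(n):
        if id(n) in seen:
            return
        seen.add(id(n))
        for inp in n.inputs:
            visit(inp)
        order.append(n)
    for out in outputs:
        visit(out)
    return order

x = SimpleNode("input")
w = SimpleNode("weight")
mm = SimpleNode("matmul", [x, w])
act = SimpleNode("relu", [mm])
order = [n.op_type for n in topo_order([act])]  # inputs before users
```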
Op Lowering: 350+ Operations
The largest component (3,400+ lines) is operation.py — the registry mapping PyTorch ops to MLIR dialect operations. We support two target dialects:
TOSA Ops (Tensor Operator Set Architecture)
For standard neural network operations:
```python
# torch.nn.functional.relu → tosa.clamp
def lower_relu(node):
    return tosa.ClampOp(
        input=node.inputs[0],
        min_fp=0.0,
        max_fp=float("inf"),
    )
```
Linalg Ops (for compute-intensive kernels)
For operations that need tiling and vectorization:
```python
# torch.matmul → linalg.matmul
def lower_matmul(node):
    return linalg.MatmulOp(
        inputs=[node.inputs[0], node.inputs[1]],
        outputs=[init_tensor],
    )
```
The choice between TOSA and Linalg depends on the downstream optimization: TOSA ops have well-defined semantics but limited optimization opportunity. Linalg ops expose loop structure for our cost-model-driven tiling.
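A registry like operation.py typically maps op names to lowering functions via a decorator, so per-op lowerings stay self-contained. A sketch of that pattern — the registry, op names, and string outputs here are illustrative, not the project's real API:

```python
# Sketch of a decorator-based lowering registry (illustrative, not the real API).
LOWERINGS = {}

def register(op_type):
    """Decorator: record a lowering function under its PyTorch op name."""
    def wrap(fn):
        LOWERINGS[op_type] = fn
        return fn
    return wrap

@register("relu")
def lower_relu(node):
    return f"tosa.clamp({node})"   # stand-in for emitting a TOSA op

@register("matmul")
def lower_matmul(node):
    return f"linalg.matmul({node})"  # stand-in for emitting a Linalg op

def lower(op_type, node):
    if op_type not in LOWERINGS:
        raise NotImplementedError(f"no lowering registered for {op_type}")
    return LOWERINGS[op_type](node)
```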
Handling Conv2d: The Im2Col Approach
Convolution is the most complex lowering. Rather than emitting a monolithic linalg.conv_2d, we decompose convolutions using the im2col pattern:
```
Input [N, C, H, W]
  → im2col  → Matrix [N*OH*OW, C*KH*KW]
  → matmul with Weight [C*KH*KW, OC]
  → reshape → Output [N, OC, OH, OW]
```
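The shape flow above can be checked numerically: an im2col-plus-matmul convolution must match a naive direct convolution. A NumPy sketch (stride 1, no padding — simplifying assumptions, not the midend's actual implementation):

```python
# Numeric check of the im2col decomposition against a naive direct convolution.
# Assumes stride 1 and no padding for brevity.
import numpy as np

def conv2d_naive(x, w):
    N, C, H, W = x.shape
    OC, _, KH, KW = w.shape
    OH, OW = H - KH + 1, W - KW + 1
    out = np.zeros((N, OC, OH, OW))
    for n in range(N):
        for oc in range(OC):
            for i in range(OH):
                for j in range(OW):
                    out[n, oc, i, j] = np.sum(x[n, :, i:i+KH, j:j+KW] * w[oc])
    return out

def conv2d_im2col(x, w):
    N, C, H, W = x.shape
    OC, _, KH, KW = w.shape
    OH, OW = H - KH + 1, W - KW + 1
    # im2col: one row per output position, one column per (c, kh, kw) tap
    cols = np.empty((N * OH * OW, C * KH * KW))
    r = 0
    for n in range(N):
        for i in range(OH):
            for j in range(OW):
                cols[r] = x[n, :, i:i+KH, j:j+KW].ravel()
                r += 1
    mat = cols @ w.reshape(OC, -1).T            # [N*OH*OW, OC]
    return mat.reshape(N, OH, OW, OC).transpose(0, 3, 1, 2)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 6))
w = rng.standard_normal((4, 3, 3, 3))
assert np.allclose(conv2d_naive(x, w), conv2d_im2col(x, w))
```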
This decomposition is done in a dedicated midend pass (ConvertConvToIm2Col) rather than the frontend, because:
- The frontend emits a clean linalg.conv_2d that's easy to validate
- The midend can choose whether to decompose based on kernel size and stride
- Small convolutions (1x1) are better handled as a direct matmul without im2col overhead
The Pipeline Subprocess Pattern
After generating MLIR, the midend runs as a subprocess:
```python
import subprocess

class FTMPipeline:
    def run(self, mlir_path, output_path):
        optimized_path = mlir_path + ".opt.mlir"
        llvm_ir_path = mlir_path + ".ll"

        # Step 1: MLIR optimization
        subprocess.run([
            "ftm-opt",
            "--ftm-pipeline",  # chains all 18 passes
            mlir_path,
            "-o", optimized_path,
        ], check=True)

        # Step 2: MLIR → LLVM IR
        subprocess.run([
            "ftm-translate",
            "--mlir-to-llvmir",
            optimized_path,
            "-o", llvm_ir_path,
        ], check=True)

        # Step 3: LLVM IR → Assembly
        subprocess.run([
            "llc",
            "-mtriple=matrix",
            llvm_ir_path,
            "-o", output_path,
        ], check=True)
```
Why subprocess instead of in-process? Two reasons:
- Isolation: A crash in the C++ compiler doesn’t kill the Python process
- No C++ Python bindings needed for the midend: The midend is a standalone CMake project. Embedding it in Python would require maintaining pybind11 wrappers for every pass
We do offer Python bindings (via upstream MLIR’s Python API) for users who want tighter integration, but the subprocess path is the default.
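The isolation benefit is easy to see in miniature: a tool that crashes or exits non-zero becomes an observable result in Python rather than taking down the host process. A sketch using only the standard library (`run_tool` is an illustrative helper, not part of the real pipeline; the failing child is simulated with a Python subprocess):

```python
# Sketch of subprocess isolation: a failing compiler step becomes a
# catchable result instead of killing the Python host process.
import subprocess
import sys

def run_tool(cmd):
    """Run one pipeline step; report (ok, stderr) instead of crashing."""
    proc = subprocess.run(cmd, capture_output=True, text=True)
    return proc.returncode == 0, proc.stderr

# Simulate a step that dies (here: a Python child exiting non-zero).
ok, err = run_tool([sys.executable, "-c", "import sys; sys.exit(1)"])
assert not ok  # the failure is observed, but this process keeps running
```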
Integration with torch.compile
The cleanest integration point is torch.compile:
```python
import torch
from frontend import TorchCompileBackend

backend = TorchCompileBackend(output_dir="./compiled")

@torch.compile(backend=backend)
def my_model(x):
    return torch.matmul(x, weight) + bias

# First call compiles; subsequent calls reuse the cached result
result = my_model(input_tensor)
```
Under the hood, TorchCompileBackend wraps our DynamoCompiler, handles graph breaks (when Dynamo can’t capture a subgraph), and caches compiled results.
Lessons Learned
1. Start with TOSA, add Linalg incrementally
TOSA provides a clean, well-specified starting point. You can lower most ops to TOSA first, validate correctness, then selectively lower compute-intensive ops to Linalg for optimization.
2. The op registry is a maintenance burden
350+ ops means constant work tracking PyTorch API changes. We mitigate this with:
- Decomposition: complex ops are decomposed into primitives (e.g., batch_norm → multiply + add)
- AOT Autograd: let PyTorch decompose backward ops before we see them
- Exhaustive testing: every op has a round-trip test (PyTorch → MLIR → execute → compare)
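The batch_norm decomposition works because, at inference time, y = gamma * (x - mean) / sqrt(var + eps) + beta collapses to a single per-channel scale and shift, y = a*x + b with a = gamma / sqrt(var + eps) and b = beta - a * mean. A small numeric sketch of that folding (scalar values chosen for illustration):

```python
import math

# Inference-time batch_norm folds into y = a*x + b (per channel):
#   a = gamma / sqrt(var + eps),  b = beta - a * mean
gamma, beta = 1.5, 0.25
mean, var, eps = 2.0, 4.0, 1e-5

a = gamma / math.sqrt(var + eps)
b = beta - a * mean

for x in [-1.0, 0.0, 2.0, 5.0]:
    bn = gamma * (x - mean) / math.sqrt(var + eps) + beta
    folded = a * x + b
    assert math.isclose(bn, folded, rel_tol=1e-9)
```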
3. Subprocess isolation is worth the overhead
The ~50ms subprocess startup cost is negligible compared to compilation time. And the first time the C++ compiler segfaults during development (it will), you’ll be grateful your Python process survived.
4. Keep the Graph IR thin
Our initial Graph IR was too rich — it tried to encode optimization hints, target constraints, and scheduling decisions. This created coupling between frontend and midend. We stripped it back to pure data flow + tensor metadata. Let the midend make optimization decisions.
5. Validate at every boundary
- FX → Graph IR: check shapes match
- Graph IR → MLIR: round-trip through mlir-opt --verify-each
- MLIR → Assembly: run on CPU via ExecutionEngine, compare with PyTorch eager output
Three validation points catch bugs at the layer where they originate, rather than surfacing as mysterious assembly-level failures.
What’s Next
The frontier is dynamic shapes. TorchDynamo supports shape guards, but our midend pipeline assumes static shapes for tiling and DMA insertion. Bridging this gap — compiling shape-polymorphic kernels with specialization at runtime — is the next major challenge.