Benchmarking Analysis of PyTorch Models using torch.compile()
Accelerating PyTorch with torch.compile(): When and Why It Works
PyTorch 2.0 introduced torch.compile(), a powerful feature aimed at accelerating PyTorch models with minimal code changes. In this post, I’ll analyze benchmark results across different model architectures to understand when torch.compile() delivers significant speedups and why its performance improvements vary dramatically across different model types.
Benchmark Results
I ran comprehensive benchmarks on various model architectures to measure the impact of torch.compile(). Here’s what I found:
| Model Name | Original Time (ms) | Compiled Time (ms) | Speedup | Compile Time (s) | Improvement (%) |
|---|---|---|---|---|---|
| Simple Linear (Batch=32) | 0.76 | 0.83 | 0.92x | 15.33 | -8.28% |
| Large Linear (Batch=64) | 5.55 | 5.75 | 0.97x | 3.29 | -3.60% |
| ConvNet (224x224) | 1557.36 | 787.21 | 1.98x | 14.37 | 49.45% |
| Transformer Block | 58.59 | 57.93 | 1.01x | 5.94 | 1.14% |
| Complex Mixed Model | 35.92 | 6.87 | 5.23x | 5.51 | 80.88% |
Key Observations
The results reveal a striking pattern:
- Simple models got slower: The simple linear models experienced a slight performance degradation with torch.compile(), with the smallest model slowing down by 8.28%.
- Complex models shine: The ConvNet and Complex Mixed Model showed dramatic improvements, with the latter achieving a remarkable 5.23x speedup (an 80.88% improvement).
- Medium-complexity models saw minimal benefit: The Transformer Block showed only a marginal 1.14% improvement.
- Compilation overhead varies: The compilation time doesn’t directly correlate with model size or complexity, ranging from 3.29s to 15.33s.
Why Does torch.compile() Work Better for Complex Models?
The Technical Underpinnings
torch.compile() works by leveraging a technique called “just-in-time” (JIT) compilation through PyTorch’s TorchDynamo and TorchInductor backends. Here’s why it performs differently across model types:
1. Optimization Opportunities
Simple models like linear layers present fewer optimization opportunities. With just a matrix multiplication and bias addition, there’s little room for the compiler to improve upon PyTorch’s already optimized implementations. The compilation overhead actually becomes counterproductive in these cases.
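For context, the simple linear benchmarks boil down to compiling a model that is little more than a single nn.Linear. The exact configuration isn’t shown in the benchmark code; a plausible stand-in (the layer sizes here are assumptions, not the actual benchmark setup) looks like this:

```python
import torch
import torch.nn as nn

# Hypothetical "Simple Linear" model: one matmul plus a bias add.
# There is almost nothing for the compiler to fuse or reorder, so the
# fixed guard/dispatch overhead of the compiled path can make it slower.
simple_linear = nn.Linear(512, 512)
x = torch.randn(32, 512)  # Batch=32, as in the benchmark table

compiled_simple = torch.compile(simple_linear)
_ = compiled_simple(x)  # first call triggers the actual compilation
```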
2. Graph Optimization
For complex models with many operations, torch.compile() can (a small fusion sketch follows this list):
- Fuse operations that would otherwise require multiple GPU kernels
- Eliminate redundant memory accesses
- Optimize data movement between CPU and GPU
- Perform more aggressive constant folding and dead code elimination
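As a small, concrete illustration of fusion, consider a function that mixes a matmul with a chain of pointwise operations. Under the default Inductor backend, the pointwise tail can typically be collapsed into a single generated kernel rather than several separate launches (the exact fusion decisions depend on the backend and hardware):

```python
import torch
import torch.nn.functional as F

def linear_gelu_residual(x, weight, bias):
    # One matmul followed by pointwise ops; eager mode runs each op as
    # its own kernel, while the compiled version can fuse the pointwise
    # tail (bias add, GELU, residual add) into one kernel.
    y = x @ weight + bias
    y = F.gelu(y)
    return y + x

compiled_fn = torch.compile(linear_gelu_residual)

x = torch.randn(64, 256)
w = torch.randn(256, 256)
b = torch.randn(256)
out = compiled_fn(x, w, b)  # first call compiles; later calls reuse cached kernels
```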
3. Memory Access Patterns
The Complex Mixed Model likely contains diverse operations with complex memory access patterns. torch.compile() can substantially reorganize these to minimize memory transfers and maximize cache utilization.
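A minimal sketch of the memory-traffic angle: in eager mode each pointwise op below launches separately and materializes a full intermediate tensor, whereas the compiled version can fuse the chain so the input is read once and the output written once (how much of this actually happens depends on the backend):

```python
import torch

def pointwise_chain(x):
    # Three pointwise ops; eager mode writes two intermediate tensors to
    # memory, while a fused compiled kernel can keep them in registers.
    return ((x.sin() * 2.0) + 1.0).relu()

compiled_chain = torch.compile(pointwise_chain)
out = compiled_chain(torch.randn(1024, 1024))
```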
Code Analysis: How torch.compile() Transforms Models
Let’s examine a simplified version of the code used to create this benchmark:
```python
# Setting up torch.compile() for a model
import time

import torch


def benchmark_model(model, input_data, name, num_warmup=10, num_iter=50):
    # Measure original performance
    model(input_data)  # Initial warmup

    # Benchmark original model
    start = time.time()
    for _ in range(num_iter):
        model(input_data)
    original_time = (time.time() - start) * 1000 / num_iter  # ms per iteration

    # Compile the model and measure compilation time
    print(f"Compiling {name}...")
    compile_start = time.time()
    compiled_model = torch.compile(model)

    # Warmup runs for the compiled model -- torch.compile() is lazy, so the
    # actual compilation work happens during these first calls
    for _ in range(num_warmup):
        compiled_model(input_data)
    compile_time = time.time() - compile_start  # one-time compilation cost

    # Benchmark compiled model
    start = time.time()
    for _ in range(num_iter):
        compiled_model(input_data)
    compiled_time = (time.time() - start) * 1000 / num_iter  # ms per iteration

    # Calculate improvements
    speedup = original_time / compiled_time
    improvement_pct = (1 - compiled_time / original_time) * 100

    return {
        "model_name": name,
        "original_time_ms": round(original_time, 2),
        "compiled_time_ms": round(compiled_time, 2),
        "speedup": round(speedup, 2),
        "compile_time_s": round(compile_time, 2),
        "improvement_pct": round(improvement_pct, 2),
    }
```
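A call to this helper might look like the following; the model and input sizes are illustrative, not the exact benchmark configurations:

```python
import torch
import torch.nn as nn

# Hypothetical run of the smallest benchmark case.
model = nn.Linear(512, 512)
data = torch.randn(32, 512)

results = benchmark_model(model, data, "Simple Linear (Batch=32)")
print(results)
```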
The key part is incredibly simple:
```python
compiled_model = torch.compile(model)
```
This single line triggers a complex compilation process:
- TorchDynamo captures the Python program’s control flow and creates an FX graph
- Specialized backends (like inductor) transform this graph into optimized code
- Optimized kernels are generated and cached for future use
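If you want to check whether Dynamo can capture your model as a single graph (graph breaks fall back to eager execution and shrink the optimization surface), torch.compile() accepts a fullgraph flag that raises an error instead of silently splitting the graph; the backend is also selectable. A minimal sketch:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))

# fullgraph=True asks Dynamo to capture the entire forward pass as one
# FX graph and to raise an error if it cannot (e.g. on data-dependent
# Python control flow) rather than splitting into multiple subgraphs.
compiled = torch.compile(model, backend="inductor", fullgraph=True)

out = compiled(torch.randn(32, 128))
```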
The Complex Mixed Model: Where torch.compile() Shines Brightest
The Complex Mixed Model showed an astonishing 5.23x speedup. This model likely combines:
- Multiple layer types (convolutions, attention, linear)
- Complex data transformations
- Varied tensor shapes and operations
The compiler can holistically analyze the entire computation graph rather than optimizing individual operations in isolation. This “global view” enables optimizations that aren’t possible when executing operations sequentially.
Let’s look at a conceptual example of the Complex Mixed Model:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexMixedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional features
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        # Attention mechanism
        self.query = nn.Linear(64, 64)
        self.key = nn.Linear(64, 64)
        self.value = nn.Linear(64, 64)
        # Final classification
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Conv path
        x = F.relu(self.bn(self.conv(x)))
        x = F.max_pool2d(x, 2)

        # Store the shape so the attention output can be folded back
        batch_size, C, H, W = x.shape

        # Reshape for attention: treat each spatial position as a token
        x_flat = x.view(batch_size, C, -1).permute(0, 2, 1)  # [B, H*W, C]

        # Self-attention
        q = self.query(x_flat)
        k = self.key(x_flat)
        v = self.value(x_flat)

        # Attention weights and context
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(64)
        attn = F.softmax(scores, dim=-1)
        context = torch.bmm(attn, v)

        # Reshape back (reshape rather than view, since the permuted
        # tensor is no longer contiguous)
        x_out = context.permute(0, 2, 1).reshape(batch_size, C, H, W)

        # Skip connection
        x = x + x_out

        # Final classification
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = x.view(batch_size, -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
```
In this model, torch.compile() can do all of the following (a brief compile-and-run sketch follows the list):
- Fuse the convolution, batch norm, and ReLU operations
- Optimize the reshape and permute operations to minimize memory movement
- Specialize the attention computation for the specific tensor shapes
- Eliminate intermediate buffers where possible
- Parallelize independent operations
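Continuing from the class definition above, here is roughly how the model would be compiled and exercised; the input resolution is an assumption for illustration, since the benchmark doesn’t state it:

```python
model = ComplexMixedModel().eval()
compiled = torch.compile(model)

x = torch.randn(8, 3, 64, 64)  # assumed input shape, for illustration only
with torch.no_grad():
    _ = compiled(x)    # first call: Dynamo capture + Inductor code generation
    out = compiled(x)  # later calls reuse the compiled kernels
print(out.shape)  # torch.Size([8, 10])
```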
When Should You Use torch.compile()?
Based on the benchmark results, here are practical recommendations:
Do use torch.compile() for:
- Complex models with diverse operations
- Computationally intensive models
- Models with convolutions and custom operations
- Production deployment where compilation overhead is amortized
Consider alternatives for:
- Simple models with few operations
- Models that are already optimized
- Development workflows where quick iteration matters more than runtime
Be strategic about compilation time:
- The one-time compilation cost (3-15s in our benchmarks) may not be worth it for models that run briefly or infrequently
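One practical lever here is torch.compile()’s mode argument: "reduce-overhead" targets exactly the small-model case, trading extra memory to cut per-call framework overhead (on CUDA it typically uses CUDA graphs), while "max-autotune" spends the longest time compiling in exchange for the most aggressively tuned kernels. A brief sketch, with behavior as documented for recent PyTorch 2.x releases:

```python
import torch
import torch.nn as nn

model = nn.Linear(256, 256)  # stands in for any nn.Module or function

# Default mode: balances compile time against runtime speedup.
fast_default = torch.compile(model)

# "reduce-overhead": aims to cut per-invocation overhead, which is what
# dominates for small models like the simple linear benchmarks above.
fast_small = torch.compile(model, mode="reduce-overhead")

# "max-autotune": longest compile times, most aggressive kernel tuning;
# best when the one-time cost is amortized over many production calls.
fast_prod = torch.compile(model, mode="max-autotune")
```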
Conclusion
torch.compile() represents a significant advancement in PyTorch’s performance capabilities, but its benefits are not uniform across all model types. The most dramatic improvements occur in complex models where the compiler can identify and implement optimizations that would be infeasible by hand.
As these benchmarks demonstrate, compilation speedups range from marginal or even negative for simple models to extraordinary (5.23x) for complex ones. Understanding when and why these speedups occur allows you to make informed decisions about incorporating torch.compile() into your deep learning workflows.
For complex models in production environments where every millisecond matters, torch.compile() is becoming an essential tool in the PyTorch ecosystem, delivering substantial performance gains with minimal code changes.
The complete benchmark code is available in this Kaggle notebook.