Optimizing Matrix Multiplication
“Matrix multiplication is at the heart of modern AI and computer vision.” - Yann LeCun
Faster Matrix Multiplication: Methods, Complexity, and Code
Matrix multiplication is a fundamental operation in scientific computing, machine learning, and engineering. Optimizing it can lead to significant performance gains. In this post, we explore several methods for matrix multiplication, compare their time complexities, and provide code samples for each.
1. Naive (Triple Loop) Method
Time Complexity: $O(n^3)$
Description: This is the most straightforward way to multiply two matrices, directly implementing the mathematical definition. It uses three nested loops to compute the dot product of rows and columns. While easy to understand and implement, it is extremely slow for large matrices and is mainly used for educational purposes or when working with very small matrices. It does not leverage any hardware acceleration or low-level optimizations.
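Written out, the triple loop computes each entry of $C = AB$, for an $n \times p$ matrix $A$ and a $p \times m$ matrix $B$, as a row-by-column dot product:

$$
C_{ij} = \sum_{k=1}^{p} A_{ik} B_{kj}, \qquad 1 \le i \le n,\; 1 \le j \le m
$$

That is $p$ multiply-adds for each of the $nm$ entries, or $nmp$ in total, which becomes the $O(n^3)$ above when all three dimensions equal $n$.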
def naive_matmul(A, B):
    n = len(A)
    m = len(B[0])
    p = len(B)
    result = [[0 for _ in range(m)] for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for k in range(p):
                result[i][j] += A[i][k] * B[k][j]
    return result
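A small variation that stays in pure Python is to swap the two inner loops (i-k-j order), so the innermost loop walks along a row of B and a row of the result instead of down a column of B. The sketch below uses a hypothetical name, naive_matmul_ikj, and performs exactly the same n·m·p multiply-adds as the triple loop above; whether it is measurably faster depends on the interpreter and the matrix sizes, so treat any speedup as something to measure rather than assume.

```python
def naive_matmul_ikj(A, B):
    """Triple loop in i-k-j order: same O(n*m*p) work as the naive version."""
    n, p, m = len(A), len(B), len(B[0])
    result = [[0] * m for _ in range(n)]  # independent row lists
    for i in range(n):
        row_out = result[i]   # row i of the result, updated in place
        for k in range(p):
            a_ik = A[i][k]    # reused for the whole inner loop
            row_b = B[k]      # row k of B, traversed left to right
            for j in range(m):
                row_out[j] += a_ik * row_b[j]
    return result


A = [[1, 2, 3], [4, 5, 6]]        # 2 x 3
B = [[7, 8], [9, 10], [11, 12]]   # 3 x 2
print(naive_matmul_ikj(A, B))     # [[58, 64], [139, 154]]
```

In compiled code on row-major arrays, the same reordering typically also improves cache locality, because B is read row by row.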
2. NumPy (np.matmul)
Time Complexity: $O(n^3)$ (but highly optimized via BLAS)
Description: NumPy is the de facto standard for numerical computing in Python. Its matmul function (also available as the @ operator) is a thin wrapper that, for floating-point arrays, calls into a highly optimized BLAS library such as OpenBLAS or Intel MKL. These libraries use advanced techniques such as cache blocking, SIMD instructions, and multi-threading to achieve high performance. For most practical applications, NumPy is the go-to solution for matrix multiplication on CPUs.
import numpy as np
A = np.random.rand(500, 500)
B = np.random.rand(500, 500)
C = np.matmul(A, B) # or A @ B
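To see the gap on your own machine, a quick sketch like the one below times the pure-Python triple loop against A @ B on the same inputs and checks that the two results agree. The 100 x 100 size and the single time.perf_counter measurement per method are arbitrary choices made for illustration; for serious measurements, repeat the runs (for example with the timeit module) and try several sizes.

```python
import time

import numpy as np


def naive_matmul(A, B):
    """Pure-Python triple loop, as in the naive method above."""
    n, p, m = len(A), len(B), len(B[0])
    result = [[0.0 for _ in range(m)] for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for k in range(p):
                result[i][j] += A[i][k] * B[k][j]
    return result


rng = np.random.default_rng(0)
A = rng.random((100, 100))
B = rng.random((100, 100))

start = time.perf_counter()
C_naive = np.array(naive_matmul(A.tolist(), B.tolist()))
t_naive = time.perf_counter() - start

start = time.perf_counter()
C_numpy = A @ B  # dispatches to the BLAS library NumPy was built with
t_numpy = time.perf_counter() - start

# Equal up to floating-point rounding (the summation order can differ).
print("results match:", np.allclose(C_naive, C_numpy))
print(f"naive: {t_naive:.3f} s, NumPy: {t_numpy:.5f} s")
```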
3. Numba JIT Compilation
Time Complexity: $O(n^3)$ (but much faster than pure Python)
Description: Numba is a Just-In-Time (JIT) compiler for Python that accelerates numerical code by compiling it to machine code at runtime. By decorating a function with @njit, loop-heavy numerical code like the triple loop below can often run nearly as fast as compiled C. It works especially well with NumPy arrays and explicit loops. Numba can also target CUDA GPUs, making it a flexible tool for both CPU and GPU acceleration without leaving Python.
from numba import njit
import numpy as np
@njit
def numba_matmul(A, B):
    n, m = A.shape[0], B.shape[1]
    p = B.shape[0]
    result = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            for k in range(p):
                result[i, j] += A[i, k] * B[k, j]
    return result
# Usage
A = np.random.rand(500, 500)
B = np.random.rand(500, 500)
C = numba_matmul(A, B)
4. CUDA (GPU Acceleration)
Time Complexity: $O(n^3)$ (but with massive parallelism)
Description: Modern GPUs are designed for highly parallel computations, making them ideal for matrix operations. Libraries like CuPy (a NumPy-like API for CUDA) and PyTorch can offload matrix multiplication to the GPU, often achieving speedups of 10x or more over a CPU for large matrices (the exact gain depends on the GPU, the CPU, the data type, and the matrix size). This is especially useful in deep learning, scientific simulations, and any workload where large matrix multiplications are a bottleneck. However, copying data between CPU and GPU memory adds overhead that can outweigh the gains for small matrices.
import cupy as cp
A = cp.random.rand(500, 500)
B = cp.random.rand(500, 500)
C = cp.matmul(A, B)
# To move result back to CPU: C_cpu = cp.asnumpy(C)
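The CuPy snippet above can be mirrored in PyTorch, which the paragraph also mentions. The sketch below is an illustration under stated assumptions, not a benchmark: the 2000 x 2000 size is arbitrary, it falls back to the CPU when no CUDA device is present (so it still runs anywhere PyTorch is installed), and it times the host-to-device copy separately from the multiplication. Because CUDA kernels run asynchronously, the hypothetical helper sync() calls torch.cuda.synchronize() before each clock read; without it, the measured time would mostly reflect the kernel launch rather than the work.

```python
import time

import torch

# Use the GPU when available; otherwise run on the CPU so the sketch still works.
device = "cuda" if torch.cuda.is_available() else "cpu"


def sync():
    # GPU work is queued asynchronously; wait for it before reading the clock.
    if device == "cuda":
        torch.cuda.synchronize()


A_cpu = torch.rand(2000, 2000)
B_cpu = torch.rand(2000, 2000)

start = time.perf_counter()
A = A_cpu.to(device)  # host-to-device copy (a no-op on the CPU)
B = B_cpu.to(device)
sync()
t_transfer = time.perf_counter() - start

torch.matmul(A, B)  # warm-up run, excluded from the timing
sync()

start = time.perf_counter()
C = torch.matmul(A, B)  # equivalent to A @ B
sync()
t_matmul = time.perf_counter() - start

C_cpu = C.cpu()  # copy the result back to host memory
print(f"device={device}, transfer: {t_transfer:.4f} s, matmul: {t_matmul:.4f} s")
```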
5. Strassen’s Algorithm
Time Complexity: $O(n^{\log_2 7}) \approx O(n^{2.81})$
Description: Strassen’s algorithm is a classic example of a divide-and-conquer approach to matrix multiplication. It reduces the number of multiplications required, improving the asymptotic time complexity over the naive method. However, it introduces more additions and is numerically less stable. In practice, Strassen’s algorithm only outperforms well-tuned BLAS-based multiplication for very large matrices (typically thousands by thousands, depending on the hardware), and standard libraries rarely use it because of its complexity and its overhead at small and medium sizes.
Note: Best for large matrices; overhead dominates for small sizes. Not commonly used in production code unless working with very large datasets or in specialized applications.
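The saving comes from the 2 x 2 block step. Splitting each matrix into four $n/2 \times n/2$ blocks, the block version of the standard method needs 8 block products, while Strassen’s formulas need only 7 (the products M1 to M7 in the code that follows) plus a constant number of block additions and subtractions, each costing $O(n^2)$. Applied recursively, this gives

$$
T(n) = 7\,T(n/2) + O(n^2) \quad\Longrightarrow\quad T(n) = O\!\left(n^{\log_2 7}\right) \approx O\!\left(n^{2.807}\right)
$$

compared with $T(n) = 8\,T(n/2) + O(n^2) = O(n^3)$ for the standard method.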
def strassen(A, B):
    # Assumes square matrices whose size is a power of two (1, 2, 4, 8, ...)
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    else:
        mid = n // 2
        # Partition matrices
        A11 = [row[:mid] for row in A[:mid]]
        A12 = [row[mid:] for row in A[:mid]]
        A21 = [row[:mid] for row in A[mid:]]
        A22 = [row[mid:] for row in A[mid:]]
        B11 = [row[:mid] for row in B[:mid]]
        B12 = [row[mid:] for row in B[:mid]]
        B21 = [row[:mid] for row in B[mid:]]
        B22 = [row[mid:] for row in B[mid:]]
        # Compute the seven sub-products
        M1 = strassen(add(A11, A22), add(B11, B22))
        M2 = strassen(add(A21, A22), B11)
        M3 = strassen(A11, sub(B12, B22))
        M4 = strassen(A22, sub(B21, B11))
        M5 = strassen(add(A11, A12), B22)
        M6 = strassen(sub(A21, A11), add(B11, B12))
        M7 = strassen(sub(A12, A22), add(B21, B22))
        # Combine results
        C11 = add(sub(add(M1, M4), M5), M7)
        C12 = add(M3, M5)
        C21 = add(M2, M4)
        C22 = add(sub(add(M1, M3), M2), M6)
        # Assemble new matrix
        new_matrix = []
        for i in range(mid):
            new_matrix.append(C11[i] + C12[i])
        for i in range(mid):
            new_matrix.append(C21[i] + C22[i])
        return new_matrix
def add(A, B):
    # Element-wise sum of two matrices of the same shape
    return [[A[i][j] + B[i][j] for j in range(len(A[0]))] for i in range(len(A))]
def sub(A, B):
    # Element-wise difference of two matrices of the same shape
    return [[A[i][j] - B[i][j] for j in range(len(A[0]))] for i in range(len(A))]
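The pure-Python version above recurses all the way down to 1 x 1 blocks and only accepts power-of-two sizes, which is a large part of why it is slow. A common remedy, sketched here under our own assumptions, is a hybrid: use Strassen’s seven products only while blocks are large, and hand small blocks to the regular BLAS-backed A @ B. The helper names strassen_np and strassen_matmul and the cutoff leaf_size=64 are illustrative, not tuned values; the best cutoff depends on the machine. Inputs are assumed square and are zero-padded to the next power of two.

```python
import numpy as np


def strassen_np(A, B, leaf_size=64):
    """Strassen recursion on square power-of-two NumPy arrays.

    Blocks of size <= leaf_size are multiplied with A @ B instead.
    """
    n = A.shape[0]
    if n <= leaf_size:
        return A @ B
    mid = n // 2
    A11, A12 = A[:mid, :mid], A[:mid, mid:]
    A21, A22 = A[mid:, :mid], A[mid:, mid:]
    B11, B12 = B[:mid, :mid], B[:mid, mid:]
    B21, B22 = B[mid:, :mid], B[mid:, mid:]
    # The same seven products as the pure-Python version.
    M1 = strassen_np(A11 + A22, B11 + B22, leaf_size)
    M2 = strassen_np(A21 + A22, B11, leaf_size)
    M3 = strassen_np(A11, B12 - B22, leaf_size)
    M4 = strassen_np(A22, B21 - B11, leaf_size)
    M5 = strassen_np(A11 + A12, B22, leaf_size)
    M6 = strassen_np(A21 - A11, B11 + B12, leaf_size)
    M7 = strassen_np(A12 - A22, B21 + B22, leaf_size)
    C = np.empty((n, n), dtype=M1.dtype)
    C[:mid, :mid] = M1 + M4 - M5 + M7
    C[:mid, mid:] = M3 + M5
    C[mid:, :mid] = M2 + M4
    C[mid:, mid:] = M1 - M2 + M3 + M6
    return C


def strassen_matmul(A, B, leaf_size=64):
    """Multiply two square n x n arrays, zero-padding to a power of two."""
    n = A.shape[0]
    if A.shape != (n, n) or B.shape != (n, n):
        raise ValueError("expects two square matrices of equal size")
    size = 1
    while size < n:
        size *= 2
    dtype = np.result_type(A, B)
    A_pad = np.zeros((size, size), dtype=dtype)
    B_pad = np.zeros((size, size), dtype=dtype)
    A_pad[:n, :n] = A
    B_pad[:n, :n] = B
    return strassen_np(A_pad, B_pad, leaf_size)[:n, :n]


rng = np.random.default_rng(0)
A = rng.random((300, 300))
B = rng.random((300, 300))
print(np.allclose(strassen_matmul(A, B), A @ B))  # True
```

Whether this beats a plain A @ B depends on the BLAS library, the matrix size, and the cutoff, so measure before relying on it.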
6. Cython/CPython Extensions
Time Complexity: $O(n^3)$ (but compiled, so much faster than pure Python)
Description: Cython is a superset of Python that allows you to write C extensions for Python. By adding type annotations and compiling the code, you can achieve performance close to pure C. This is useful for performance-critical sections of code where Python is too slow, but you want to maintain most of the codebase in Python. Cython is widely used in scientific libraries (including parts of NumPy and SciPy) and is a good choice when you need custom, high-performance routines.
Example (Cython):
# cython: boundscheck=False, wraparound=False
cimport cython
import numpy as np
cimport numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
def cython_matmul(np.ndarray[double, ndim=2] A, np.ndarray[double, ndim=2] B):
    cdef int n = A.shape[0]
    cdef int m = B.shape[1]
    cdef int p = B.shape[0]
    cdef np.ndarray[double, ndim=2] result = np.zeros((n, m), dtype=np.float64)
    cdef int i, j, k
    for i in range(n):
        for j in range(m):
            for k in range(p):
                result[i, j] += A[i, k] * B[k, j]
    return result
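Unlike the other examples, the Cython code cannot be run as-is: it has to be saved in a .pyx file and compiled into an extension module first. A minimal build script, assuming the code is saved as cython_matmul.pyx (a hypothetical file name) next to a setup.py, could follow the pattern used in the Cython documentation for NumPy code:

```python
# setup.py: build in place with  python setup.py build_ext --inplace
from setuptools import setup
from Cython.Build import cythonize
import numpy as np

setup(
    # Compile the .pyx source (hypothetical file name) into a C extension.
    ext_modules=cythonize("cython_matmul.pyx"),
    # `cimport numpy` needs NumPy's C headers at compile time.
    include_dirs=[np.get_include()],
)
```

After a successful build, the function can be imported with from cython_matmul import cython_matmul and called on float64 NumPy arrays, since the buffer types in its signature are declared as double.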
Summary Table
| Method | Time Complexity | Hardware/Notes |
|---|---|---|
| Naive (triple loop) | $O(n^3)$ | CPU, pure Python, slow |
| NumPy | $O(n^3)$ | CPU, optimized BLAS |
| Numba JIT | $O(n^3)$ | CPU, can also target CUDA GPUs |
| CUDA (CuPy, PyTorch) | $O(n^3)$ | NVIDIA GPU |
| Strassen’s Algorithm | $O(n^{2.81})$ | CPU, pays off only for very large matrices |
| Cython/CPython | $O(n^3)$ | Compiled, CPU |
Latest Research
Recent research focuses on hardware acceleration (e.g., CGRAs and systolic arrays), low-rank adaptation for deep learning, and efficient implementations for transformers and LLMs. No fundamentally new matrix multiplication algorithm has been announced as of July 2025; recent improvements to the bound on the exponent ω (to just below 2.372) are refinements of the existing laser method and have constant factors far too large to be practical. See an arXiv search for "matrix multiplication" for the latest papers.
References: