JAX Examples¶
Basic JAX Benchmarking¶
ZeroPyBench automatically detects JAX arrays in your code and wraps the benchmarked expression in a JIT-compiled function.
Note
For JAX benchmarking, ZeroPyBench provides additional measurements; see "JAX-Specific Report Fields" below.
import jax
import jax.numpy as jnp
import jax.random as jr
from zeropybench import Benchmark
bench = Benchmark(repeat=20)
x = jnp.ones(1000)
y = jnp.ones(1000)
with bench():
    x + y
5.880 µs ± 0.81% (median of 20 runs, 50000 loops each)
Verbose Mode¶
Use verbose=True to see the setup code (JIT-compiled function) and the benchmarked code:
bench = Benchmark(verbose=True)
with bench():
    x + y
Setup code:
@jax.jit
def __bench_func(x, y):
    return x + y
Benchmarked code:
__bench_func(x, y).block_until_ready()
5.932 µs ± 0.32% (median of 7 runs, 50000 loops each)
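The setup and benchmarked code printed by verbose mode can be reproduced by hand with the standard `timeit` module. The following is only a rough sketch of such a measurement, not ZeroPyBench's implementation; the name `bench_func` and the loop count are illustrative assumptions.

```python
import timeit

import jax
import jax.numpy as jnp

x = jnp.ones(1000)
y = jnp.ones(1000)


# Setup: a JIT-compiled wrapper around the benchmarked expression.
@jax.jit
def bench_func(x, y):
    return x + y


# Warm-up call so that tracing and compilation are not part of the timing.
bench_func(x, y).block_until_ready()

# Timed statement: block_until_ready() waits for the asynchronously
# dispatched computation to finish, so the full execution is measured.
loops = 1000  # illustrative value, not ZeroPyBench's loop count
runs = timeit.repeat(
    lambda: bench_func(x, y).block_until_ready(),
    repeat=7,
    number=loops,
)
per_loop_us = [1e6 * t / loops for t in runs]
print(f"fastest run: {min(per_loop_us):.3f} µs per loop")
```

Without the call to `block_until_ready()`, the timing would mostly capture the cost of dispatching the computation rather than executing it.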
For functions returning a PyTree (e.g., a tuple of Arrays), the slightly slower jax.block_until_ready is used instead.
from dataclasses import dataclass

@jax.tree_util.register_dataclass
@dataclass
class Vector:
    x: jax.Array
    y: jax.Array

    @classmethod
    def normal(cls, key, shape):
        key_x, key_y = jr.split(key)
        return Vector(jr.normal(key_x, shape), jr.normal(key_y, shape))

    def __add__(self, other):
        if not isinstance(other, Vector):
            return NotImplemented
        return Vector(self.x + other.x, self.y + other.y)

key1, key2 = jr.split(jr.key(42))
shape = (100,)
v1 = Vector.normal(key1, shape)
v2 = Vector.normal(key2, shape)
bench = Benchmark(verbose=True)
with bench():
    v1 + v2
Setup code:
@jax.jit
def __bench_func(v1, v2):
    return v1 + v2
Benchmarked code:
jax.block_until_ready(__bench_func(v1, v2))
9.488 µs ± 0.42% (median of 7 runs, 50000 loops each)
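The difference between the two synchronization calls can be seen with a plain tuple. This is a minimal sketch using only JAX itself; the function `sum_and_product` is a hypothetical example, not part of ZeroPyBench.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_and_product(x, y):
    # Returns a tuple, i.e. a PyTree with two array leaves.
    return x + y, x * y


x = jnp.ones(4)
y = jnp.full(4, 2.0)

out = sum_and_product(x, y)

# A tuple has no .block_until_ready() method of its own ...
has_method = hasattr(out, "block_until_ready")

# ... but jax.block_until_ready accepts any PyTree, waits for every
# array leaf, and returns the PyTree itself.
total, product = jax.block_until_ready(out)
```

The same applies to the registered `Vector` dataclass: its leaves are arrays, but the container itself is not.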
Multiple Bare Expressions¶
When the benchmarked code contains several bare expressions (without assignment), they are each captured into temporary variables so that block_until_ready can synchronize all computations. Calls to functions known to return None (such as print or functions annotated with -> None) are left as-is:
bench = Benchmark(verbose=True)
with bench():
    print(x)
    x + y
    x * y
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
JitTracer(float32[1000])
Setup code:
@jax.jit
def __bench_func(x, y):
    print(x)
    __expr1 = x + y
    __expr2 = x * y
    return (__expr1, __expr2)
Benchmarked code:
jax.block_until_ready(__bench_func(x, y))
8.547 µs ± 1.17% (median of 7 runs, 50000 loops each)
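The output above shows two printouts: the concrete array, presumably from the first (non-jitted) execution, and a tracer, printed while the JIT-compiled function is traced. The sketch below illustrates the general JAX behavior behind this: Python side effects inside a jitted function run during tracing only, not on every call. The function `add_and_mul` and the list `seen` are illustrative names.

```python
import jax
import jax.numpy as jnp

seen = []  # records what the Python body receives each time it runs


@jax.jit
def add_and_mul(x, y):
    seen.append(type(x).__name__)  # Python side effect, like print(x)
    return x + y, x * y


x = jnp.ones(1000)
y = jnp.ones(1000)

for _ in range(3):
    jax.block_until_ready(add_and_mul(x, y))

# The body ran once, during tracing, and received a tracer rather than a
# concrete array. The two later calls reused the compiled program.
print(seen)
```

As a consequence, a `print` call inside the benchmarked block does not contribute to the measured time of the compiled function.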
JAX-Specific Report Fields¶
When JAX code is detected, the benchmark report includes additional fields:
first_execution_time: the execution time of the code inside the context manager, which usually corresponds to the non-jitted version of the benchmarked code.
compilation_time: the lowering and compilation time.
generated_code_size: the total size of the generated machine code in bytes, including embedded constants.
temp_size: the size of the preallocated temporary buffer arena in bytes. This accounts for intermediate buffers needed during execution, excluding input arguments, outputs, and constants.
Warning
The generated_code_size and temp_size values may not be reliable on the CPU backend.
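Quantities similar to compilation_time, generated_code_size and temp_size can be obtained by hand with JAX's ahead-of-time API. This is a sketch under the assumption that these fields correspond to the values below; how ZeroPyBench actually collects them is not stated here.

```python
import time

import jax
import jax.numpy as jnp


def add_and_mul(x, y):
    return x + y, x * y


x = jnp.ones(1000)
y = jnp.ones(1000)

# Time lowering and compilation together.
start = time.perf_counter()
compiled = jax.jit(add_and_mul).lower(x, y).compile()
compilation_time = time.perf_counter() - start

# memory_analysis() may return None or incomplete data on some backends,
# hence the defensive getattr calls.
stats = compiled.memory_analysis()
generated_code_size = getattr(stats, "generated_code_size_in_bytes", None)
temp_size = getattr(stats, "temp_size_in_bytes", None)

print(f"compilation_time: {compilation_time * 1e3:.3f} ms")
print(f"generated_code_size: {generated_code_size} B")
print(f"temp_size: {temp_size} B")
```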
print(bench)
run = bench[0]
print('Execution times [μs]:', sorted(t * 1e6 for t in run['execution_times']))
┌───┬────────────────┬──────────┬────────────────┬────────────────┬────────────────┬───────────────┐
│ ┆ median_executi ┆ ± (%) ┆ first_executio ┆ compilation_ti ┆ generated_code ┆ temp_size (B) │
│ ┆ on_time (µs) ┆ ┆ n_time (µs) ┆ me (µs) ┆ _size (B) ┆ │
╞═══╪════════════════╪══════════╪════════════════╪════════════════╪════════════════╪═══════════════╡
│ 0 ┆ 8.547132 ┆ 1.173406 ┆ 57_029.436 ┆ 29_734.152999 ┆ 0 ┆ 0 │
└───┴────────────────┴──────────┴────────────────┴────────────────┴────────────────┴───────────────┘
Execution times [μs]: [8.479386939980031, 8.479485339994426, 8.500881360014318, 8.547131759987678, 8.550357340027404, 8.624570760002825, 8.70624356000917]
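The median_execution_time column can be checked against the raw execution times printed above. The values below are copied from that output; with seven runs, the median is the fourth value of the sorted list.

```python
import statistics

# Per-run execution times in µs, copied from the output above.
execution_times_us = [
    8.479386939980031,
    8.479485339994426,
    8.500881360014318,
    8.547131759987678,
    8.550357340027404,
    8.624570760002825,
    8.70624356000917,
]

median_us = statistics.median(execution_times_us)
print(f"{median_us:.6f}")  # prints 8.547132, matching the table
```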
Visualizing the HLO¶
The hlo field contains the StableHLO representation of the compiled function. You can visualize it using the library visu-hlo:
from visu_hlo import show
show(run['hlo'])
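If visu-hlo is not available, a textual form of the lowered program can be inspected with JAX alone. This sketch assumes that the text produced by `Lowered.as_text()` is comparable to what the hlo field holds; the function `add_and_mul` is an illustrative stand-in for the benchmarked code.

```python
import jax
import jax.numpy as jnp


def add_and_mul(x, y):
    return x + y, x * y


x = jnp.ones(1000)
y = jnp.ones(1000)

# Lower the jitted function for these argument shapes without running it,
# then dump the resulting MLIR module as text.
lowered = jax.jit(add_and_mul).lower(x, y)
hlo_text = lowered.as_text()
print(hlo_text)
```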
Comparing Broadcasting Strategies¶
Benchmark different array operations to compare their performance:
bench = Benchmark()
for N in [100, 1000, 10000]:
    x = jnp.ones(N)
    y = jnp.ones(1000)
    with bench(method='broadcast right', N=N):
        x[:, None] + y[None, :]
    with bench(method='broadcast left', N=N):
        x[None, :] + y[:, None]
method=broadcast right, N=100: 13.746 µs ± 2.43% (median of 7 runs, 20000 loops each)
method=broadcast left, N=100: 14.150 µs ± 2.89% (median of 7 runs, 20000 loops each)
method=broadcast right, N=1000: 50.810 µs ± 1.57% (median of 7 runs, 5000 loops each)
method=broadcast left, N=1000: 50.451 µs ± 1.94% (median of 7 runs, 5000 loops each)
method=broadcast right, N=10000: 12.588 ms ± 1.70% (median of 7 runs, 20 loops each)
method=broadcast left, N=10000: 12.741 ms ± 2.50% (median of 7 runs, 20 loops each)
print(bench)
┌───┬────────────┬────────┬────────────┬──────────┬────────────┬───────────┬───────────┬───────────┐
│ ┆ method ┆ N ┆ median_exe ┆ ± (%) ┆ first_exec ┆ compilati ┆ generated ┆ temp_size │
│ ┆ ┆ ┆ cution_tim ┆ ┆ ution_time ┆ on_time ┆ _code_siz ┆ (B) │
│ ┆ ┆ ┆ e (ms) ┆ ┆ (ms) ┆ (ms) ┆ e (B) ┆ │
╞═══╪════════════╪════════╪════════════╪══════════╪════════════╪═══════════╪═══════════╪═══════════╡
│ 0 ┆ broadcast ┆ 100 ┆ 0.013746 ┆ 2.428567 ┆ 34.956574 ┆ 20.968133 ┆ 0 ┆ 0 │
│ ┆ right ┆ ┆ ┆ ┆ ┆ ┆ ┆ │
│ 1 ┆ broadcast ┆ 100 ┆ 0.01415 ┆ 2.893307 ┆ 34.258964 ┆ 19.529872 ┆ 0 ┆ 0 │
│ ┆ left ┆ ┆ ┆ ┆ ┆ ┆ ┆ │
│ 2 ┆ broadcast ┆ 1_000 ┆ 0.05081 ┆ 1.568993 ┆ 36.311284 ┆ 22.613573 ┆ 0 ┆ 0 │
│ ┆ right ┆ ┆ ┆ ┆ ┆ ┆ ┆ │
│ 3 ┆ broadcast ┆ 1_000 ┆ 0.050451 ┆ 1.939443 ┆ 37.465634 ┆ 22.188032 ┆ 0 ┆ 0 │
│ ┆ left ┆ ┆ ┆ ┆ ┆ ┆ ┆ │
│ 4 ┆ broadcast ┆ 10_000 ┆ 12.587709 ┆ 1.702945 ┆ 37.154414 ┆ 31.106143 ┆ 0 ┆ 0 │
│ ┆ right ┆ ┆ ┆ ┆ ┆ ┆ ┆ │
│ 5 ┆ broadcast ┆ 10_000 ┆ 12.740685 ┆ 2.498716 ┆ 39.017844 ┆ 30.673434 ┆ 0 ┆ 0 │
│ ┆ left ┆ ┆ ┆ ┆ ┆ ┆ ┆ │
└───┴────────────┴────────┴────────────┴──────────┴────────────┴───────────┴───────────┴───────────┘
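The two strategies compute the same values in transposed layouts, which is consistent with their near-identical timings. A quick shape check, shown here for N=100:

```python
import jax.numpy as jnp

N = 100
x = jnp.ones(N)
y = jnp.ones(1000)

right = x[:, None] + y[None, :]  # (N, 1) + (1, 1000) -> (N, 1000)
left = x[None, :] + y[:, None]   # (1, N) + (1000, 1) -> (1000, N)

print(right.shape, left.shape)

# Each result is the transpose of the other.
same = bool(jnp.array_equal(right, left.T))
```

For N=10000 each result holds 10^7 float32 values, about 40 MB, so the operation is likely dominated by memory traffic, which may explain the jump from microseconds to milliseconds.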
Benchmarking JIT-compiled Functions¶
ZeroPyBench handles both regular and JIT-compiled functions:
Note
When benchmarking an already JIT-compiled function, ZeroPyBench reuses it directly without re-wrapping, preserving any static_argnums or other JIT options you specified.
import jax
@jax.jit
def matmul(a, b):
    return a @ b
bench = Benchmark(verbose=True)
for N in [256, 512]:
    a = jnp.ones((N, N))
    b = jnp.ones((N, N))
    with bench(operation='matmul', N=N):
        matmul(a, b)
Setup code:
__bench_func = matmul
Benchmarked code:
__bench_func(a, b).block_until_ready()
operation=matmul, N=256: 202.963 µs ± 1.32% (median of 7 runs, 1000 loops each)
Setup code:
__bench_func = matmul
Benchmarked code:
__bench_func(a, b).block_until_ready()
operation=matmul, N=512: 1.234 ms ± 0.45% (median of 7 runs, 200 loops each)
print(bench)
┌───┬───────────┬─────┬─────────────┬──────────┬────────────┬────────────┬────────────┬────────────┐
│ ┆ operation ┆ N ┆ median_exec ┆ ± (%) ┆ first_exec ┆ compilatio ┆ generated_ ┆ temp_size │
│ ┆ ┆ ┆ ution_time ┆ ┆ ution_time ┆ n_time ┆ code_size ┆ (B) │
│ ┆ ┆ ┆ (ms) ┆ ┆ (ms) ┆ (ms) ┆ (B) ┆ │
╞═══╪═══════════╪═════╪═════════════╪══════════╪════════════╪════════════╪════════════╪════════════╡
│ 0 ┆ matmul ┆ 256 ┆ 0.202963 ┆ 1.315189 ┆ 8.609431 ┆ 0.11784 ┆ 0 ┆ 0 │
│ 1 ┆ matmul ┆ 512 ┆ 1.233717 ┆ 0.45166 ┆ 8.463781 ┆ 0.11857 ┆ 0 ┆ 0 │
└───┴───────────┴─────┴─────────────┴──────────┴────────────┴────────────┴────────────┴────────────┘
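To see why preserving JIT options matters, consider a function with a static argument. The sketch below uses a hypothetical `power` function and a deliberately naive re-wrapping to show what would go wrong; it does not describe ZeroPyBench internals.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def power(x, n):
    # n is static, so this Python loop is unrolled at trace time.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


x = jnp.full(4, 2.0)
direct = power(x, 3).block_until_ready()

# Hypothetical naive re-wrapping: inside the outer jit, n becomes a traced
# value, so it can no longer serve as a static Python integer.
rewrapped = jax.jit(lambda x, n: power(x, n))
try:
    rewrapped(x, 3)
    rewrap_failed = False
except Exception:
    rewrap_failed = True

print(direct, rewrap_failed)
```

Reusing the already compiled function directly, as the setup code above does with `__bench_func = matmul`, avoids this problem.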
Visualizing the Matrix Multiplication HLO¶
Display the HLO computational graph for the second run (N=512):
show(bench[1]['hlo'])
Plotting Results¶
bench.plot()